Implementation of 2nd order gradient in maml
vivkul opened this issue
In line 126 of few_shot/maml.py:
meta_batch_loss.backward()
however, the graph for meta_batch_loss does not involve the model parameters directly. Each loss that contributes to meta_batch_loss was computed using fast_weights, so calling backward() on meta_batch_loss would produce gradients with respect to fast_weights, not the model parameters.
This could be corrected by separately calling autograd.grad on each loss with respect to the model parameters, summing the results, and then using a hook to update the model parameters.
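The suggested approach could be sketched roughly as follows. This is a minimal illustration, not the repository's code: the MSE regression task, the inner learning rate, and the direct `p.grad` assignment (standing in for a gradient hook) are all assumptions for the sake of a self-contained example.

```python
import torch
import torch.nn as nn
from torch import autograd

# Hypothetical setup: a linear model as the meta-learner and a few
# random regression tasks standing in for the meta-batch.
model = nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
inner_lr = 0.01
task_batches = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(3)]

meta_grads = [torch.zeros_like(p) for p in model.parameters()]
for x, y in task_batches:
    # One inner-loop step to obtain fast weights.
    loss = nn.functional.mse_loss(
        nn.functional.linear(x, model.weight, model.bias), y)
    grads = autograd.grad(loss, list(model.parameters()))
    fast = [p - inner_lr * g for p, g in zip(model.parameters(), grads)]
    # Query loss evaluated with the fast weights.
    q_loss = nn.functional.mse_loss(
        nn.functional.linear(x, fast[0], fast[1]), y)
    g_fast = autograd.grad(q_loss, fast)
    # First-order approximation: accumulate gradients taken on the
    # fast weights as if they were gradients on the slow weights.
    meta_grads = [m + g for m, g in zip(meta_grads, g_fast)]

opt.zero_grad()
for p, g in zip(model.parameters(), meta_grads):
    p.grad = g  # what a gradient hook would write onto the slow weights
opt.step()
```

Note that this only yields the first-order MAML update, since the inner autograd.grad call does not retain the graph of the update step itself.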
What do you think?
I use the hook method for 1st order MAML, i.e. I use the sum of the gradients on the fast weights to update the slow weights.
Line 100 in 672de83
However, I do not believe this is the correct way to produce the 2nd order update. When calling meta_batch_loss.backward()
we are in fact calculating gradients with respect to the model parameters, because PyTorch can differentiate through the gradient update steps on fast_weights
back to the original model weights. This is a double-backwards operation (it backpropagates through another loss.backward() operation), and I have a test that explicitly checks for the existence of this operation in the autograd graph.
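The mechanism can be demonstrated in isolation. The sketch below is not the repository's code; the model, data, and inner learning rate are placeholders. The key point is `create_graph=True` on the inner autograd.grad call, which makes the update `p - lr * grad` itself differentiable, so backward() on the outer loss traverses it and populates `.grad` on the original parameters.

```python
import torch
import torch.nn as nn
from torch import autograd

# Placeholder model and data, standing in for a meta-learning task.
model = nn.Linear(4, 1)
x, y = torch.randn(8, 4), torch.randn(8, 1)

inner_loss = nn.functional.mse_loss(model(x), y)
# create_graph=True keeps the graph of this gradient computation,
# which is what enables the 2nd order (double-backward) update.
grads = autograd.grad(inner_loss, list(model.parameters()),
                      create_graph=True)
fast = [p - 0.01 * g for p, g in zip(model.parameters(), grads)]

meta_loss = nn.functional.mse_loss(
    nn.functional.linear(x, fast[0], fast[1]), y)
meta_loss.backward()  # backpropagates through the inner update step

# The original (slow) parameters received gradients.
assert all(p.grad is not None for p in model.parameters())
```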
Line 102 in 672de83
Another reason I'm sure of this is that if PyTorch did not calculate gradients on the original model weights, then the following line would do nothing and the model wouldn't train at all.
Line 127 in 672de83
The optimiser object on which we call .step() in the line above is in fact the optimiser created on this line
Line 84 in 672de83
which is tied to the original model parameters.
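This last point can be checked directly. A minimal sketch (not the repository's code): torch.optim.SGD holds references to the original parameters and its step() skips any parameter whose `.grad` is None, so the model only moves once backward() has populated gradients on those parameters.

```python
import torch
import torch.nn as nn

# Hypothetical model; the optimiser is tied to its original parameters.
model = nn.Linear(2, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.5)

before = [p.clone() for p in model.parameters()]
opt.step()  # no grads yet: step() skips params with .grad is None
unchanged = all(torch.equal(a, b)
                for a, b in zip(before, model.parameters()))

loss = model(torch.randn(4, 2)).pow(2).mean()
loss.backward()  # populates .grad on the original parameters
opt.step()
changed = any(not torch.equal(a, b)
              for a, b in zip(before, model.parameters()))
```

If backward() on the meta loss never reached the original parameters, every step() call would behave like the first one here and training would stall.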