oscarknagg / few-shot

Repository for few-shot learning machine learning projects

Implementation of 2nd order gradient in MAML

vivkul opened this issue

In line 126 of few_shot/maml.py:

meta_batch_loss.backward()

However, the graph for meta_batch_loss does not use the model parameters: each loss that goes into meta_batch_loss was computed using fast_weights. So when calling backward on meta_batch_loss, the gradients would be computed with respect to fast_weights, not the model parameters.

It may be corrected by separately calling autograd.grad over model.named_parameters().values() for each of the losses, summing the results, and then using a hook to update the model parameters.

What do you think?

I use the hook method for 1st order MAML, i.e. I use the sum of the gradients on the fast weights to update the slow weights.

hooks = []
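In outline, that first-order hook trick looks something like the sketch below (replace_grad, task_gradients and dummy_batch are assumed, illustrative names, not the exact code in few_shot/maml.py):

import torch

def replace_grad(summed_grads, name):
    # Return a hook that discards the incoming gradient and substitutes the
    # pre-computed sum of the fast-weight gradients for this parameter.
    def hook(grad):
        return summed_grads[name]
    return hook

def first_order_meta_update(model, task_gradients, meta_optimiser, dummy_batch):
    # task_gradients: one dict per task, mapping parameter names to gradients
    # taken on that task's fast weights (same shapes as the slow weights).
    summed_grads = {k: torch.stack([g[k] for g in task_gradients]).sum(dim=0)
                    for k in task_gradients[0]}

    # Register a hook on every slow weight so that the next backward pass
    # writes the summed fast-weight gradient into its .grad field.
    hooks = []
    for name, param in model.named_parameters():
        hooks.append(param.register_hook(replace_grad(summed_grads, name)))

    # The dummy forward/backward pass exists only to trigger the hooks; the
    # actual gradient values come from summed_grads.
    meta_optimiser.zero_grad()
    model(dummy_batch).sum().backward()
    meta_optimiser.step()

    for h in hooks:
        h.remove()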

However, I do not believe this is the correct way to produce the 2nd order update. When calling meta_batch_loss.backward() we are in fact calculating gradients with respect to the model parameters, as PyTorch is capable of running autograd through the gradient update steps on fast_weights back to the original model weights. This is a double backwards operation (as it backpropagates through another loss.backward() operation) and I have a test that explicitly checks for the existence of this operation in the autograd graph.

def test_second_order(self):
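To make the double backwards point concrete, here is a minimal standalone example (toy model and data, not the repository's code) showing that when the inner update is made with create_graph=True, calling backward() on a loss computed purely from the fast weights still populates gradients on the original parameters:

import torch
import torch.nn.functional as F
from collections import OrderedDict

model = torch.nn.Linear(4, 2)
x, y = torch.randn(8, 4), torch.randint(0, 2, (8,))

# Inner loop: one gradient step producing fast weights, keeping the graph.
weights = OrderedDict(model.named_parameters())
inner_loss = F.cross_entropy(F.linear(x, weights['weight'], weights['bias']), y)
grads = torch.autograd.grad(inner_loss, list(weights.values()), create_graph=True)
fast_weights = OrderedDict(
    (name, w - 0.01 * g) for (name, w), g in zip(weights.items(), grads))

# Outer loop: the meta loss only ever sees the fast weights...
meta_logits = F.linear(x, fast_weights['weight'], fast_weights['bias'])
meta_batch_loss = F.cross_entropy(meta_logits, y)
meta_batch_loss.backward()

# ...but because the fast weights are differentiable functions of the original
# parameters, backward() propagates through the inner update (a double
# backward) and populates .grad on the slow weights.
assert all(p.grad is not None for p in model.parameters())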

Another reason I'm sure of this is that if PyTorch did not calculate gradients on the original model weights, then the following line would do nothing and the model wouldn't train at all.

optimiser.step()

The optimiser object on which we call .step() in the line above is in fact the optimiser created on this line

meta_optimiser = torch.optim.Adam(meta_model.parameters(), lr=args.meta_lr)

which is tied to the original model parameters.
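As a quick sanity check (this is not something in the repository), you could assert on the slow-weight gradients immediately after meta_batch_loss.backward() and before that optimiser.step() call, using whichever name the model has in scope there (written as model below); with the 2nd order update every parameter should have a populated .grad:

for name, param in model.named_parameters():
    # If backward() had not reached the slow weights these would all be None
    # and optimiser.step() would be a no-op.
    assert param.grad is not None, f'no gradient reached slow weight {name}'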