oscarknagg / few-shot

Repository for few-shot learning machine learning projects

loss.backward(retain_graph=True) at line 131 of maml.py, is it required?

sudarshan1994 opened this issue · comments

It seems to me that loss.backward(retain_graph=True) already computes the gradients, but the code computes them again a few lines below using gradients = torch.autograd.grad(loss, fast_weights.values(), create_graph=create_graph), so some work appears to be duplicated here.
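To make the duplication concrete, here is a rough sketch of the pattern I mean (a paraphrase, not the exact code from maml.py; functional_forward and inner_lr are placeholder names):

import torch
import torch.nn.functional as F

def inner_update(functional_forward, fast_weights, x, y, inner_lr, create_graph):
    # Support-set loss with the current fast weights.
    loss = F.cross_entropy(functional_forward(x, fast_weights), y)

    # The line in question: a full backward pass that accumulates gradients
    # into the leaf parameters' .grad attributes.
    loss.backward(retain_graph=True)

    # A second pass over the same (retained) graph that returns the gradients
    # directly; if the .grad values from backward() are never read, this looks
    # like duplicated work.
    gradients = torch.autograd.grad(loss, fast_weights.values(),
                                    create_graph=create_graph)

    # Gradient-descent step on the fast weights.
    return {name: w - inner_lr * g
            for (name, w), g in zip(fast_weights.items(), gradients)}

If that's right, dropping one of the two calls would avoid a second traversal of the graph in every inner-loop step.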

Again, I am not sure if I am correct; any response would be appreciated!

Thanks for your time

I believe you're correct here. In fact, while looking at that piece of code I realised that the following snippet is unnecessary during 2nd-order MAML:

gradients = torch.autograd.grad(loss, fast_weights.values(), create_graph=create_graph)
named_grads = {name: g for ((name, _), g) in zip(fast_weights.items(), gradients)}
task_gradients.append(named_grads)

I've made a branch with these changes that I'll merge to master when I'm sure they're correct (https://github.com/oscarknagg/few-shot/tree/maml-efficiency)
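For anyone following along, here is a simplified sketch of why that collection isn't needed in the 2nd-order case: with create_graph=True the query-set loss can be backpropagated straight through the inner-loop updates to the original parameters. This is an illustration rather than the code in that branch; functional_forward and meta_optimiser are placeholder names.

import torch
import torch.nn.functional as F
from collections import OrderedDict

def meta_step(model, functional_forward, tasks, inner_lr, meta_optimiser,
              inner_steps=1, order=2):
    meta_optimiser.zero_grad()
    meta_loss = 0.0
    for x_support, y_support, x_query, y_query in tasks:
        # Each task starts from the current meta-parameters.
        fast_weights = OrderedDict(model.named_parameters())
        for _ in range(inner_steps):
            loss = F.cross_entropy(
                functional_forward(x_support, fast_weights), y_support)
            # create_graph=True keeps the graph of this update so the
            # meta-gradient can flow through it (2nd order); with order=1 the
            # inner gradients are treated as constants (first-order MAML).
            grads = torch.autograd.grad(loss, fast_weights.values(),
                                        create_graph=(order == 2))
            fast_weights = OrderedDict(
                (name, w - inner_lr * g)
                for (name, w), g in zip(fast_weights.items(), grads))
        # Query-set loss evaluated with the adapted weights.
        meta_loss = meta_loss + F.cross_entropy(
            functional_forward(x_query, fast_weights), y_query)
    # A single backward on the summed query losses reaches the meta-parameters
    # directly, so no per-task dict of named gradients has to be assembled.
    meta_loss.backward()
    meta_optimiser.step()

This only covers the 2nd-order path; as said above, the explicit gradient collection may still be needed for how the 1st-order update is applied.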

Oh cool, thanks for getting back. I ran it with and without those lines; at least for first-order MAML it seems to work!

Hey @oscarknagg, it seems that you never merged that branch (https://github.com/oscarknagg/few-shot/tree/maml-efficiency). Could you tell me whether that's because the idea turned out to be wrong, or because you simply haven't got around to it? Just knowing whether it works or not would help me. Thanks.

I'm also wondering about this problem; the line should be removed.