khurramjaved96 / mrcl

Code for the NeurIPS19 paper "Meta-Learning Representations for Continual Learning"

Home Page: https://arxiv.org/abs/1905.12588

Bug fix

jlindsey15 opened this issue

Hi! I noticed the code was recently updated to (among other things) fix a bug in the meta-gradient computation. Could you explain what the bug was, and what would be the minimal change needed to fix it, starting from the earlier version of the code? Thanks a bunch!

Hi!

The first step is to set create_graph=True in the torch.autograd.grad call. Without this flag, the second-order gradients are never computed and the code falls back to the first-order approximation, which is far from ideal.
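A minimal sketch (not the repository's code) of why the flag matters: with the default create_graph=False the returned gradient is detached, so the outer backward pass cannot differentiate through the inner update; with create_graph=True the meta-gradient includes the second-order terms. The tensors and learning rate below are placeholders for illustration only.

```python
import torch

w = torch.randn(3, requires_grad=True)          # a meta-parameter
x, y = torch.randn(3), torch.tensor(1.0)

def inner_loss(weights):
    return ((weights * x).sum() - y) ** 2

# First-order: the gradient is a plain tensor with no graph behind it.
(g_fo,) = torch.autograd.grad(inner_loss(w), w)  # create_graph defaults to False
print(g_fo.requires_grad)                        # False -> second-order terms are lost

# Second-order: keep the graph so the meta-gradient can flow through the inner update.
(g_so,) = torch.autograd.grad(inner_loss(w), w, create_graph=True)
fast_w = w - 0.01 * g_so                         # inner-loop update stays differentiable
outer_loss = inner_loss(fast_w)
outer_loss.backward()                            # w.grad now includes second-order terms
print(w.grad is not None)                        # True
```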

However, setting create_graph=True in the old code makes the experiments prohibitively slow, because the old code computed gradients of the RLN parameters in every inner step even though those gradients are not needed. To fix that, compute inner-loop gradients only for the PLN network.

The two changes boil down to changing:

grad = torch.autograd.grad(loss, fast_weights)

to

grad = torch.autograd.grad(loss, self.net.get_adaptation_parameters(fast_weights), create_graph=True)

where get_adaptation_parameters filters out the meta-parameters (the RLN) that are kept fixed in the inner loop, so gradients are only taken with respect to the PLN parameters (see the sketch below).
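A rough sketch of how the two changes fit together in an inner-loop step. The helper below and names such as NUM_RLN_LAYERS and the flat list layout of fast_weights are assumptions for illustration; the actual method in the repository may be organized differently.

```python
import torch

NUM_RLN_LAYERS = 6  # assumption: the first N parameter tensors belong to the frozen RLN

def get_adaptation_parameters(fast_weights):
    # Keep only the PLN parameters that are adapted in the inner loop.
    return fast_weights[NUM_RLN_LAYERS:]

def inner_update(loss, fast_weights, lr):
    pln_weights = get_adaptation_parameters(fast_weights)
    # create_graph=True keeps the graph so the outer loop gets true second-order gradients,
    # and restricting the inputs to the PLN keeps the extra cost manageable.
    grads = torch.autograd.grad(loss, pln_weights, create_graph=True)
    updated_pln = [w - lr * g for w, g in zip(pln_weights, grads)]
    # RLN parameters are passed through untouched.
    return fast_weights[:NUM_RLN_LAYERS] + updated_pln
```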

Computing the correct gradients significantly decreases the time to convergence on the two benchmarks.

Thanks very much!