Bug fix
jlindsey15 opened this issue
Hi! I noticed the code was recently updated to (among other things) fix a bug in the meta-gradient computation. Could you explain what the bug was, and what would be the minimal change needed to fix it, starting from the earlier version of the code? Thanks a bunch!
Hi!
The first step is to set create_graph=True in the torch.autograd.grad call. Without this flag, the inner-loop gradients are treated as constants. The second-order terms of the meta-gradient are never computed, and the code silently falls back to the first-order approximation, which is far from ideal.
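To illustrate what the flag changes, here is a toy sketch (not the repository's code) with a single scalar weight, one inner SGD step on an assumed inner loss w², and an assumed meta-loss equal to the squared fast weight. The only difference between the two calls is create_graph:

```python
import torch


def meta_gradient(create_graph: bool, lr: float = 0.1) -> float:
    """Meta-gradient of a toy one-step inner update, with or without second-order terms."""
    w = torch.tensor(1.0, requires_grad=True)  # slow (meta) weight

    # Inner step on the toy loss L_in(w) = w^2, so dL_in/dw = 2w.
    inner_loss = w ** 2
    # With create_graph=False the returned gradient is detached from the graph,
    # so the meta-update cannot differentiate through it (first-order approximation).
    (grad,) = torch.autograd.grad(inner_loss, w, create_graph=create_graph)
    fast = w - lr * grad  # fast weight after one SGD step

    # Meta (outer) loss evaluated at the fast weight.
    outer_loss = fast ** 2
    (meta_grad,) = torch.autograd.grad(outer_loss, w)
    return meta_grad.item()


if __name__ == "__main__":
    print("second-order:", meta_gradient(create_graph=True))
    print("first-order: ", meta_gradient(create_graph=False))
```

Analytically, the fast weight is (1 - 2·lr)·w, so the full meta-gradient is 2(1 - 2·lr)²·w = 1.28 at w = 1, lr = 0.1. The first-order version drops the (1 - 2·lr) factor contributed by differentiating through the inner gradient and returns 2(1 - 2·lr)·w = 1.6 instead.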
However, setting create_graph=True in the old code makes the experiments prohibitively slow. This is because the old code computed gradients of the RLN parameters at every inner step, even though those parameters are fixed in the inner loop and their gradients are never needed there. To fix that, compute inner-loop gradients only for the PLN parameters.
Together, the two changes boil down to replacing:
grad = torch.autograd.grad(loss, fast_weights)
with:
grad = torch.autograd.grad(loss, self.net.get_adaptation_parameters(fast_weights), create_graph=True)
where get_adaptation_parameters filters out the meta-parameters that are fixed in the inner loop (the RLN parameters), so only the PLN parameters are differentiated at each inner step.
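A minimal sketch of how the two changes fit together in an inner loop, assuming a functional network whose parameter list stores the RLN weights first and the PLN weights after them. TwoPartNet, n_rln, inner_loop, and this particular get_adaptation_parameters are hypothetical stand-ins for illustration, not the repository's actual implementation:

```python
import torch
import torch.nn.functional as F


class TwoPartNet:
    """Hypothetical functional net: an RLN (fixed in the inner loop) feeding a PLN (adapted)."""

    def __init__(self, in_dim=4, hidden=8, out_dim=2):
        torch.manual_seed(0)
        self.params = [
            (torch.randn(hidden, in_dim) * 0.1).requires_grad_(),   # RLN weight
            torch.zeros(hidden, requires_grad=True),                # RLN bias
            (torch.randn(out_dim, hidden) * 0.1).requires_grad_(),  # PLN weight
            torch.zeros(out_dim, requires_grad=True),               # PLN bias
        ]
        self.n_rln = 2  # number of leading entries that belong to the RLN

    def forward(self, x, weights):
        h = F.relu(F.linear(x, weights[0], weights[1]))  # RLN features
        return F.linear(h, weights[2], weights[3])        # PLN head

    def get_adaptation_parameters(self, weights):
        # Hypothetical: drop the RLN weights, keep only those updated in the inner loop.
        return weights[self.n_rln:]


def inner_loop(net, x, y, steps=3, lr=0.1):
    fast_weights = list(net.params)
    for _ in range(steps):
        loss = F.cross_entropy(net.forward(x, fast_weights), y)
        adapt = net.get_adaptation_parameters(fast_weights)
        # Differentiate only w.r.t. the PLN weights, keeping the graph for the meta-update.
        grads = torch.autograd.grad(loss, adapt, create_graph=True)
        # RLN weights pass through unchanged; only the PLN weights take an SGD step.
        fast_weights = fast_weights[:net.n_rln] + [w - lr * g for w, g in zip(adapt, grads)]
    return fast_weights


if __name__ == "__main__":
    net = TwoPartNet()
    x, y = torch.randn(16, 4), torch.randint(0, 2, (16,))
    fast = inner_loop(net, x, y)
    meta_loss = F.cross_entropy(net.forward(x, fast), y)
    meta_loss.backward()  # second-order meta-gradients reach both RLN and PLN params
```

Because torch.autograd.grad is asked only for the PLN weights, autograd skips the backward pass into the RLN layers during the inner loop, while meta_loss.backward() still produces meta-gradients for every parameter.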
Computing the correct gradients significantly reduces the time to convergence on both benchmarks.
Thanks very much!