GMvandeVen / continual-learning

PyTorch implementation of various methods for continual learning (XdG, EWC, SI, LwF, FROMP, DGR, BI-R, ER, A-GEM, iCaRL, Generative Classifier) in three different scenarios.


one little confusion about the loss_fn_kd function

libo-huang opened this issue · comments

Many thanks for your impressive project. I am a bit confused about the .detach() in the code below,

```python
targets_norm = torch.cat([targets_norm.detach(), zeros_to_add], dim=1)
```

which is defined in

```python
def loss_fn_kd(scores, target_scores, T=2.):
```
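(For context, here is a minimal sketch of what a temperature-scaled distillation loss of this shape might look like; this is my own illustration under assumptions, not the exact code from the repository:)

```python
import torch
import torch.nn.functional as F

def loss_fn_kd_sketch(scores, target_scores, T=2.):
    """Illustrative knowledge-distillation loss (hypothetical sketch).

    `scores` are logits of the current model; `target_scores` are logits
    recorded from the previous model, possibly over fewer classes.
    """
    log_scores_norm = F.log_softmax(scores / T, dim=1)
    targets_norm = F.softmax(target_scores / T, dim=1)

    # If the current model has more output units than the stored targets,
    # pad the soft targets with zeros for the new classes.
    n = scores.size(1)
    if n > target_scores.size(1):
        zeros_to_add = torch.zeros(targets_norm.size(0), n - target_scores.size(1))
        zeros_to_add = zeros_to_add.to(targets_norm.device)
        targets_norm = torch.cat([targets_norm.detach(), zeros_to_add], dim=1)

    # Soft-target cross-entropy, averaged over the batch.
    kd_loss = -(targets_norm * log_scores_norm).sum(dim=1).mean()
    return kd_loss
```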

According to the blog post PyTorch .detach() method, .detach() treats targets_norm as a fixed constant in the KD loss, so backpropagation will not update any parameters along the branch that produced targets_norm.
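To illustrate what .detach() does (a toy example of my own, not from the repository):

```python
import torch

w = torch.tensor(2.0, requires_grad=True)
y = 3.0 * w                    # y depends on w via the autograd graph
loss = y.detach() * w          # y.detach() is treated as a fixed constant
loss.backward()
print(w.grad)                  # tensor(6.): gradient flows only through the
                               # direct use of w; without .detach() it would be 12.
```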

However, in your other project, brain-inspired-replay, the same loss function loss_fn_kd uses,

```python
targets_norm = torch.cat([targets_norm, zeros_to_add], dim=1)
```

as shown on line 29, where no .detach() is applied.

Although I have tested both versions and they give the same results, I am still confused about how the second version works.

Hi @HLBayes, I'm sorry for the very late reply!

The difference between the two repositories is indeed confusing. In both repositories, the .detach() operation is actually not needed. As you indicate in your comment, .detach() stops backpropagation by detaching the tensor from the computation graph, so no gradients flow back through it. However, in both repositories the targets_norm variable already had no gradients being tracked, because it was computed from target_scores obtained under with torch.no_grad(): (as for example here: link). Sorry for the confusion!
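To make this concrete, here is a small self-contained example (my own illustration, not code taken from either repository) showing that a tensor computed under torch.no_grad() tracks no gradients, which makes the subsequent .detach() a no-op:

```python
import torch
import torch.nn.functional as F

model = torch.nn.Linear(4, 3)
x = torch.randn(2, 4)

# Targets computed under torch.no_grad() carry no gradient history:
with torch.no_grad():
    target_scores = model(x)
print(target_scores.requires_grad)  # False

targets_norm = F.softmax(target_scores / 2., dim=1)
print(targets_norm.requires_grad)   # False: still no gradients tracked

# Because no gradients are tracked, these two lines are equivalent:
zeros_to_add = torch.zeros(2, 1)
a = torch.cat([targets_norm.detach(), zeros_to_add], dim=1)
b = torch.cat([targets_norm, zeros_to_add], dim=1)
print(torch.equal(a, b), a.requires_grad, b.requires_grad)  # True False False
```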
This difference between the two repositories is indeed confusing. In both repositories, the .detach() operation is actually not needed. As you indicate in your comment, the .detach() operation stops backpropagation as it resets any gradients being tracked. However, in both repositories, the target_norm variable already did not have any gradients being tracked, as that variable was computed using with torch.no_grad(): (as for example here: link). Sorry for the confusion!