A small confusion about the loss_fn_kd function
libo-huang opened this issue · comments
Many thanks for your impressive project. I am a bit confused about the `.detach()`
in the code below,
Line 35 in a02db26
which is defined in
Line 18 in a02db26
According to the blog post "PyTorch .detach() method", `.detach()`
treats `targets_norm`
as a fixed constant in the `KD_loss`,
so backpropagation will not update any parameters along the branch
that produced `targets_norm`.
However, in your other project, brain-inspired-replay, the same loss function, `loss_fn_kd`,
uses
targets_norm = torch.cat([targets_norm, zeros_to_add], dim=1)
as shown in line 29, where no `.detach()`
is attached.
Although I have tested both versions and they produce the same results, I am still confused about how the second version works.
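For reference, my understanding of what `.detach()` does can be sketched like this (hypothetical tensors, not the code from either repository):

```python
import torch

# .detach() cuts a tensor off from the computational graph, so no
# gradient flows back along the detached branch during backward().
a = torch.randn(4, requires_grad=True)
b = (a * 2).detach()   # detached branch: contributes no gradient to a
c = a * 3              # tracked branch: gradient flows through here

loss = (b * c).sum()
loss.backward()

# d(loss)/d(a) comes only from the c = a * 3 branch, so a.grad == 3 * b.
assert not b.requires_grad
assert torch.allclose(a.grad, 3 * b)
```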
Hi @HLBayes, I'm sorry for the very late reply!
This difference between the two repositories is indeed confusing. In both repositories, the `.detach()`
operation is actually not needed. As you indicate in your comment, `.detach()`
stops backpropagation by returning a tensor that is cut off from the computational graph. However, in both repositories, the `targets_norm`
variable already had no gradients being tracked, because it was computed inside a `with torch.no_grad():` block
(as for example here: link). Sorry for the confusion!
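A minimal check of this, using hypothetical tensors rather than the actual models from either repository: anything computed inside `torch.no_grad()` is already excluded from the graph, so a later `.detach()` would change nothing.

```python
import torch

# Hypothetical "teacher" parameters and input (stand-ins for the real code).
teacher_w = torch.randn(3, 4, requires_grad=True)
x = torch.randn(2, 3)

with torch.no_grad():
    targets_norm = x @ teacher_w   # no graph is recorded for this computation

# Hypothetical "student" parameters trained against the fixed targets.
student_w = torch.randn(3, 4, requires_grad=True)
kd_loss = ((x @ student_w - targets_norm) ** 2).mean()
kd_loss.backward()

# Gradients reach the student but never the teacher branch, even though
# targets_norm was never explicitly detached.
assert targets_norm.requires_grad is False
assert student_w.grad is not None
assert teacher_w.grad is None
```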