deepglint / unicom

[ICLR 2023] Unicom: Universal and Compact Representation Learning for Image Retrieval

Home Page: https://arxiv.org/pdf/2304.05884.pdf

Bug in CombinedMarginLoss implementation

MaroonAmor opened this issue

Hi @anxiangsir,

Thanks for sharing your work.

I have a question about the forward pass in CombinedMarginLoss when running sop_vit_b_16.sh as an example. In this case, self.m1 = 1.0, self.m2 = 0.25, and self.m3 = 0.0. But I think that under torch.no_grad() the gradients won't be propagated correctly, right?

It also seems that the implementation of CombinedMarginLoss is adapted from the insightface repo, and its previous version (without torch.no_grad()) makes more sense here: deepinsight/insightface@657ae30

Several issues have been raised there with the same question: deepinsight/insightface#2218, deepinsight/insightface#2255, deepinsight/insightface#2309

Why do we need torch.no_grad() here?

Here we mainly followed the opensphere implementation, and we found that it makes ArcFace more stable when training ViT.

@anxiangsir Thanks for getting back to me.

But it is not technically correct, right? The gradients won't be propagated back through those lines under torch.no_grad() (e.g., logits.arccos_()).
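To make the question concrete, here is a minimal sketch of how the two variants differ in the backward pass. It is not the repo's code: `arcface_logit` is a hypothetical helper that applies only the additive angular margin (m2) to a single target cosine, and it assumes the in-place `arccos_`/`cos_` calls are the only part wrapped in torch.no_grad(). Under that assumption the forward values are identical, but the no_grad variant passes the gradient straight through to the cosine (factor 1), while the plain variant applies the true derivative sin(theta + m2) / sin(theta).

```python
import math
import torch


def arcface_logit(cos_theta: torch.Tensor, m2: float, use_no_grad: bool) -> torch.Tensor:
    """Return cos(theta + m2) for a target cosine (hypothetical helper)."""
    if use_no_grad:
        # clone() gives a non-leaf tensor that is still attached to the graph.
        out = cos_theta.clone()
        with torch.no_grad():
            # Values are overwritten in place, but autograd does not record
            # these ops, so backward treats the margin step as an identity.
            out.arccos_().add_(m2).cos_()
        return out
    # Fully differentiable version: autograd tracks arccos, add and cos.
    return torch.cos(torch.arccos(cos_theta) + m2)


c_ng = torch.tensor([0.5], requires_grad=True)
c_full = torch.tensor([0.5], requires_grad=True)

y_ng = arcface_logit(c_ng, 0.25, use_no_grad=True)
y_full = arcface_logit(c_full, 0.25, use_no_grad=False)
y_ng.sum().backward()
y_full.sum().backward()

print("forward:", y_ng.item(), y_full.item())      # same value in both variants
print("grad with no_grad:   ", c_ng.grad.item())   # straight-through gradient
print("grad without no_grad:", c_full.grad.item()) # sin(theta + m2) / sin(theta)
```

So gradients still reach the logits in the no_grad variant; what is dropped is the derivative of the margin function itself. The same behavior can be written as `cos_theta + (torch.cos(torch.arccos(cos_theta) + m2) - cos_theta).detach()`.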

Also, I ran a comparison experiment (with torch.no_grad() vs. without torch.no_grad()) on the SOP dataset using an A100 GPU. Performance was actually better without torch.no_grad().

Is there any theory or math to support adding torch.no_grad()? This really confused me for a while. Thanks.