Zasder3 / train-CLIP

A PyTorch Lightning solution to training OpenAI's CLIP from scratch.

Looks like loss is wrong

a1526772 opened this issue

Isn't the similarity matrix sharded? If so, wouldn't you gather after `image_logits = torch.cat(ims) @ torch.cat(txt).t()` (line 66 of wrapper.py) rather than before?

This is a super useful question! For future viewers, I’ll quote the section of the paper under discussion: “The calculation of embedding similarities was also sharded with individual GPUs computing only the subset of the pairwise similarities necessary for their local batch of embeddings.”
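The quoted sharding can be illustrated with a minimal single-process sketch. It assumes equal-sized local batches, simulates `world_size` GPUs as a list of embedding shards, and uses `torch.cat` as a stand-in for an all-gather of the embeddings (in a real multi-GPU run this would be a collective such as `torch.distributed.all_gather`). Each rank builds only its own `local_bs x global_bs` rows of the similarity matrix. The sizes, the fixed logit scale (CLIP's initial temperature of 0.07, which the real model learns), and the `sharded_loss` helper are illustrative assumptions, not this repo's code.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

world_size, local_bs, dim = 4, 8, 16  # hypothetical sizes for a simulated run
global_bs = world_size * local_bs
scale = 1 / 0.07  # assumption: fixed logit scale (CLIP learns this value)

# Simulated per-GPU embeddings, L2-normalized before the dot product.
img_shards = [F.normalize(torch.randn(local_bs, dim), dim=1) for _ in range(world_size)]
txt_shards = [F.normalize(torch.randn(local_bs, dim), dim=1) for _ in range(world_size)]

# Stand-in for an all-gather of *embeddings* (not logits) across GPUs.
all_img = torch.cat(img_shards)
all_txt = torch.cat(txt_shards)


def sharded_loss(rank: int) -> torch.Tensor:
    """Symmetric loss for one rank, using only the similarity rows it owns."""
    # Positives for this rank sit at these global column indices.
    labels = torch.arange(local_bs) + rank * local_bs
    logits_img = scale * img_shards[rank] @ all_txt.t()  # (local_bs, global_bs)
    logits_txt = scale * txt_shards[rank] @ all_img.t()  # (local_bs, global_bs)
    return (F.cross_entropy(logits_img, labels) + F.cross_entropy(logits_txt, labels)) / 2


# Reference: the "naive" full global_bs x global_bs matrix on a single device.
full_logits = scale * all_img @ all_txt.t()
labels = torch.arange(global_bs)
full_loss = (F.cross_entropy(full_logits, labels) + F.cross_entropy(full_logits.t(), labels)) / 2

# Averaging the per-rank losses recovers the full-matrix loss (equal shard sizes).
sharded = torch.stack([sharded_loss(r) for r in range(world_size)]).mean()
print(torch.allclose(sharded, full_loss, atol=1e-4))  # True
```

Because each image-to-text cross-entropy depends only on one row of the matrix, and each text-to-image one only on one column, the per-rank losses average to the full-matrix loss while no single device holds all `global_bs**2` entries.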

Let’s start with your question! The full matrix is not sharded because we don’t use it to calculate gradients. That matrix is super heavy (over 4 GB in fp32 alone for the 32,768 batch size described in the paper 🤯), and each local batch only needs a subset of it for its updates. When we calculate the full matrix, we only use it for logging, under `torch.no_grad`. A “naive” implementation would calculate the loss and backprop right there on the spot. As stated in the quote, the authors instead shard the computation so that each GPU computes only the similarities it needs. This saves A LOT of memory and keeps you from waiting a very long time.
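The memory figure can be checked with some quick arithmetic. This sketch assumes a dense fp32 matrix and the 32,768 batch size from the CLIP paper; the 256-GPU split is a hypothetical world size used only to show how small one rank's shard becomes.

```python
# Back-of-the-envelope sizes for the CLIP similarity matrix (dense, fp32).
BATCH_SIZE = 32_768  # minibatch size reported in the CLIP paper
BYTES_PER_FP32 = 4


def full_matrix_bytes(batch_size: int) -> int:
    """Bytes for the whole batch_size x batch_size similarity matrix."""
    return batch_size * batch_size * BYTES_PER_FP32


def shard_bytes(batch_size: int, world_size: int) -> int:
    """Bytes for one GPU's rows (local batch x global batch) of that matrix."""
    local_batch = batch_size // world_size
    return local_batch * batch_size * BYTES_PER_FP32


full = full_matrix_bytes(BATCH_SIZE)
shard = shard_bytes(BATCH_SIZE, world_size=256)  # hypothetical GPU count

print(full, full / 2**30)    # 4294967296 4.0  -> 4 GiB, about 4.3 GB
print(shard, shard / 2**20)  # 16777216 16.0   -> 16 MiB per direction on each GPU
```

These counts cover only the stored logits; they ignore the gradients and other activations kept around for backprop.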

That paragraph seems to have been interpreted differently by me and by another paper. My reading was that the GPUs gathered the logits and then calculated the losses. However, a recent paper that trains a more data-efficient CLIP (https://arxiv.org/abs/2104.08945) states that the authors of the original CLIP paper did not gather logits across GPUs. Their work does state, though, that gathering the logits was important for their own data-efficient training! I’m actively working on adding their implementation to this repository, so hopefully it can be used here in the future too.
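To make the two readings concrete, here is a hedged single-process sketch of one rank's image-to-text logits with and without gathering. It assumes equal local batches and simulates the gather with `torch.cat`; what it actually gathers is the text embeddings, from which the wider set of logits is computed. The `per_rank_logits` helper, the sizes, and the fixed scale are hypothetical, not the implementation from that paper or from this repository.

```python
import torch
import torch.nn.functional as F


def per_rank_logits(img_shards, txt_shards, rank: int, gather: bool, scale: float = 1 / 0.07):
    """Image-to-text logits and labels for one simulated rank.

    gather=False: contrast only against this rank's own text embeddings.
    gather=True:  contrast against text embeddings from every rank
                  (torch.cat stands in for an all-gather here).
    """
    local_img = img_shards[rank]
    local_bs = local_img.shape[0]
    if gather:
        candidates = torch.cat(txt_shards)
        labels = torch.arange(local_bs) + rank * local_bs  # global positions of positives
    else:
        candidates = txt_shards[rank]
        labels = torch.arange(local_bs)
    logits = scale * local_img @ candidates.t()
    return logits, labels


torch.manual_seed(0)
world_size, local_bs, dim = 4, 8, 16  # hypothetical sizes
img_shards = [F.normalize(torch.randn(local_bs, dim), dim=1) for _ in range(world_size)]
txt_shards = [F.normalize(torch.randn(local_bs, dim), dim=1) for _ in range(world_size)]

for gather in (False, True):
    logits, labels = per_rank_logits(img_shards, txt_shards, rank=1, gather=gather)
    loss = F.cross_entropy(logits, labels)
    # Each row has one positive and logits.shape[1] - 1 negatives.
    print(gather, tuple(logits.shape), logits.shape[1] - 1, loss.item())
```

Without gathering, each image is contrasted against `local_bs - 1` negatives from its own GPU; with gathering it sees `global_bs - 1`, which is the practical difference between the two readings of the paragraph.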