Sense-GVT / DeCLIP

Supervision Exists Everywhere: A Data Efficient Contrastive Language-Image Pre-training Paradigm

What is all_gather for? Why use all_gather?

lyccol opened this issue · comments

```python
if (self.training and self.use_allgather) or all_gather:
    # Gather features from every DDP worker so each worker can
    # contrast its local batch against the global batch.
    gathered_image_features = self.all_gather(image_features)
    gathered_text_features = self.all_gather(text_features)
    logits_per_image = logit_scale * image_features @ gathered_text_features.t()
    logits_per_text = logit_scale * text_features @ gathered_image_features.t()
else:
    # Single-process (or eval) path: contrast only within the local batch.
    logits_per_image = logit_scale * image_features @ text_features.t()
    logits_per_text = logit_scale * text_features @ image_features.t()
return logits_per_image, logits_per_text
```

Since CLIP (Contrastive Language-Image Pre-training) requires a large batch size, we use all-gather under DDP (Distributed Data Parallel) to scale up the effective batch size by synchronising features across cards.
For example, with 128 cards and a per-card batch size of 256, image_features (resp. text_features) on each worker has shape [256, feature_dim], while gathered_image_features (resp. gathered_text_features) has shape [128*256, feature_dim]. After gradient synchronisation in the loss function, this is equivalent to training directly with a batch size of 32768.
The same reasoning applies to SLIP, FILIP, DeCLIP, and DeFILIP.
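The shape bookkeeping above can be sketched without a real DDP setup. Below is a minimal NumPy simulation (not the repo's implementation): it uses 4 hypothetical cards with a per-card batch of 8 and feature_dim of 16 in place of the 128-card / 256-batch configuration, and models all_gather as a plain concatenation of every card's features:

```python
import numpy as np

n_cards, local_bs, feature_dim = 4, 8, 16  # toy stand-ins for 128 cards x batch 256

# Each "card" holds its own batch of image/text features.
rng = np.random.default_rng(0)
image_feats = [rng.standard_normal((local_bs, feature_dim)) for _ in range(n_cards)]
text_feats = [rng.standard_normal((local_bs, feature_dim)) for _ in range(n_cards)]

# all_gather: every card ends up with the concatenation of all cards' features.
gathered_text = np.concatenate(text_feats, axis=0)  # [n_cards*local_bs, feature_dim]

# Logits on card 0: its local images contrasted against the *global* batch of texts.
logit_scale = 100.0
logits_per_image = logit_scale * image_feats[0] @ gathered_text.T

print(logits_per_image.shape)  # local batch x global batch
```

Each card thus computes a [local_bs, n_cards*local_bs] logit matrix, so the contrastive loss sees n_cards*local_bs negatives per sample rather than local_bs.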