luyug / Condenser

EMNLP 2021 - Pre-training architectures for dense retrieval

The relative weight of the MLM loss compared to the contrastive loss

hyleemindslab opened this issue

In the paper, Equation 7 indicates that both the MLM and contrastive losses are divided by the effective batch size, which would be 2 * per_device_train_batch_size * world_size. But the loss calculation code seems to divide the MLM loss by only per_device_train_batch_size * world_size (line 227), since CoCondenserDataset's __getitem__ method returns two spans belonging to the same document, making the actual batch dimension larger by a factor of 2.

I feel like I am missing something. Could you please help me out?

Condenser/modeling.py

Lines 219 to 230 in de9c257

loss = self.mlm_loss(hiddens, labels)
if self.model_args.late_mlm:
    loss += lm_out.loss

if grad_cache is None:
    co_loss = self.compute_contrastive_loss(co_cls_hiddens)
    return loss + co_loss
else:
    loss = loss * (float(hiddens.size(0)) / self.train_args.per_device_train_batch_size)
    cached_grads = grad_cache[chunk_offset: chunk_offset + co_cls_hiddens.size(0)]
    surrogate = torch.dot(cached_grads.flatten(), co_cls_hiddens.flatten())
    return loss, surrogate

Condenser/data.py

Lines 177 to 179 in de9c257

def __getitem__(self, item):
    spans = self.dataset[item]['spans']
    return random.sample(spans, 2)
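
To make the factor of 2 concrete, here is a rough sketch of the shape arithmetic I have in mind (hypothetical numbers, not repo code, and assuming the gradient-cache chunks are taken over the flattened span dimension):

# Rough sketch of the shape arithmetic (hypothetical numbers, not repo code).
per_device_train_batch_size = 8     # documents per device
spans_per_doc = 2                   # CoCondenserDataset.__getitem__ returns 2 spans
total_spans = per_device_train_batch_size * spans_per_doc   # 16 rows of hiddens

chunk_sizes = [4, 4, 4, 4]          # assumed hiddens.size(0) per gradient-cache chunk
total_scale = sum(n / per_device_train_batch_size for n in chunk_sizes)
print(total_scale)                  # 2.0 -> the accumulated MLM loss comes out as
                                    # ~2x the per-span mean suggested by Equation 7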

Line 227 is for gradient accumulation scaling, not for averaging across batch examples; check out the trainer code:

Condenser/trainer.py

Lines 161 to 185 in de9c257

for local_chunk_id, chunk in enumerate(chunked_inputs):
    device_offset = max(0, self.args.local_rank) * self.args.per_device_train_batch_size * 2
    local_offset = local_chunk_id * self.args.cache_chunk_size
    chunk_offset = device_offset + local_offset
    with rnd_states[local_chunk_id]:
        if self.use_amp:
            with autocast():
                lm_loss, surrogate = self.compute_loss(model, chunk, grad_cache, chunk_offset)
        else:
            lm_loss, surrogate = self.compute_loss(model, chunk, grad_cache, chunk_offset)

    if self.args.gradient_accumulation_steps > 1:
        raise ValueError

    ddp_no_sync = self.args.local_rank > -1 and (local_chunk_id + 1 < len(chunked_inputs))
    with model.no_sync() if ddp_no_sync else nullcontext():
        if self.use_amp:
            (self.scaler.scale(lm_loss) + surrogate).backward()
        elif self.use_apex:
            raise ValueError
        elif self.deepspeed:
            raise ValueError
        else:
            (lm_loss + surrogate).backward()

    total_loss += lm_loss
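
In case it's useful, here is a minimal, self-contained sketch of what that scaling does (not the repo's code, and treating each chunk's loss as a simple per-example mean): rescaling each chunk's mean loss by chunk_size / per_device_train_batch_size before the chunk-by-chunk backward passes makes the accumulated gradient match a single backward of (sum of per-example losses) / per_device_train_batch_size.

import torch

torch.manual_seed(0)
w = torch.randn(4, requires_grad=True)
x = torch.randn(16, 4)   # 16 "spans" on this device
B = 8                    # plays the role of per_device_train_batch_size

# Reference: one backward pass over sum(per-example loss) / B
ref = ((x @ w) ** 2).sum() / B
ref.backward()
ref_grad = w.grad.clone()
w.grad = None

# Chunked: per-chunk mean loss, rescaled as on line 227, backwarded chunk by chunk
for chunk in x.split(4):
    chunk_loss = ((chunk @ w) ** 2).mean()
    (chunk_loss * chunk.size(0) / B).backward()

print(torch.allclose(ref_grad, w.grad))  # True: chunked backwards reproduce sum / B

Since there are 2 * per_device_train_batch_size spans on the device, sum / B is twice the per-span mean, which is exactly the factor of 2 raised above.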

Yes, I just expected the MLM loss for a sub-batch to be scaled by (# of spans in the sub-batch / # of spans in the local batch), so that the final gradient is taken w.r.t. the loss averaged over all spans in the batch, i.e. loss = loss * (float(hiddens.size(0)) / (2 * self.train_args.per_device_train_batch_size)). But I'm starting to think it may not be that important.
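
For what it's worth, the two scalings differ only by a constant factor of 2 on the MLM term, so the practical effect is just a different relative weight between the MLM and contrastive losses (a rough sketch with hypothetical values, not repo code):

chunk_spans = 4                   # assumed hiddens.size(0) for one chunk
B = 8                             # per_device_train_batch_size

current  = chunk_spans / B        # scaling currently on line 227
proposed = chunk_spans / (2 * B)  # scaling suggested above
assert current == 2 * proposed    # holds for any chunk size and any B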

Right, there's a factor of 2. We didn't actually experiment much with how to interpolate; the current code seems to work fine. As training progresses and momentum in the optimizer stabilizes, I also expect that it won't be super important.