uta-smile / TCL

code for TCL: Vision-Language Pre-Training with Triple Contrastive Learning, CVPR 2022

Question about the MLM masking

longkukuhi opened this issue

Hi,

    # 10% of the time, we replace masked input tokens with random word
    indices_random = torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool() & masked_indices & ~indices_replaced

Here is the code you use to replace a token with a random word. Is it correct to use 0.5 as the parameter here?
Thank you for your answer.

commented

Thanks for your interest in our work.
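Yes, 0.5 is correct. The line also contains ~indices_replaced, which excludes the masked tokens that were already replaced with [MASK] (80% of the masked tokens). The 0.5 Bernoulli draw therefore applies only to the remaining 20%, so 0.5 x 20% = 10% of the masked tokens are replaced by a random word.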
You can reorder this line as:

    indices_random = torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool() & ~indices_replaced & masked_indices
Please let me know if you have any other questions.
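
For anyone landing here later, the argument can be checked empirically with a minimal sketch like the one below. It is not the repository's code: the batch shape, the 15% masking probability, and the 0.8 draw for indices_replaced are assumptions based on the usual BERT-style recipe, and the names coin and reordered are made up for illustration. It counts how the masked positions split up and confirms that reordering the & operands gives the same mask.

```python
import torch

torch.manual_seed(0)  # fixed seed so the sketch is repeatable

# Hypothetical batch: token values are irrelevant here, only the shape matters.
input_ids = torch.zeros(1000, 1000, dtype=torch.long)

# Assumed upstream steps (standard BERT-style recipe, not quoted in this thread):
# select 15% of positions for MLM, then send 80% of those to [MASK].
masked_indices = torch.bernoulli(torch.full(input_ids.shape, 0.15)).bool()
indices_replaced = torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool() & masked_indices

# One fair coin per position. The line from the issue keeps the coin only where
# the position is masked AND was not already sent to [MASK].
coin = torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool()
indices_random = coin & masked_indices & ~indices_replaced

# Reordering the operands of & does not change the result.
reordered = coin & ~indices_replaced & masked_indices
same = torch.equal(indices_random, reordered)

# Fractions are taken relative to the number of masked positions.
n_masked = masked_indices.sum().item()
frac_replaced = indices_replaced.sum().item() / n_masked
frac_random = indices_random.sum().item() / n_masked
frac_unchanged = 1.0 - frac_replaced - frac_random

print(f"[MASK]: {frac_replaced:.3f}  random: {frac_random:.3f}  unchanged: {frac_unchanged:.3f}")
print("reordered line gives the same mask:", same)
```

The three fractions should come out close to 0.8, 0.1, and 0.1. Using 0.1 instead of 0.5 in the Bernoulli draw would select only about 2% of the masked tokens for random replacement, because the draw is applied after the 80% have been excluded.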