jerryji1993 / DNABERT

DNABERT: pre-trained Bidirectional Encoder Representations from Transformers model for DNA-language in genome

Home Page: https://doi.org/10.1093/bioinformatics/btab083


Your masking of 6 consecutive tokens lets MLM pre-training trivially learn from the adjacent tokens

rcasero opened this issue · comments

In the paper, you note "In the pre-training step, we masked contiguous k-length spans of certain k-mers, considering a token could be trivially inferred from the immediately surrounding k-mers (total ∼15% of input sequence)".

However, I think that the way it's implemented in mask_tokens in run_pretrain.py would still allow the masked tokens to be trivially inferred from the adjacent ones.

Here's an example:

Let your sequence have length 13, so the indices are 0, 1, ..., 12. The specific bases {A,C,G,T} don't matter for this example. Your 6-mer representation would be (in terms of indices)

0, 1, 2, 3, 4, 5
1, 2, 3, 4, 5, 6
2, 3, 4, 5, 6, 7
3, 4, 5, 6, 7, 8
4, 5, 6, 7, 8, 9
5, 6, 7, 8, 9, 10
6, 7, 8, 9, 10, 11
7, 8, 9, 10, 11, 12
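
For concreteness, here is a minimal sketch of the overlapping 6-mer tokenization I'm referring to (seq_to_kmers is my own helper for illustration, not the DNABERT tokenizer): a sequence of length L yields L - k + 1 tokens, so a length-13 sequence gives the 8 tokens listed above.

```python
def seq_to_kmers(seq, k=6):
    # Overlapping k-mers: one token starting at every position.
    return [seq[i:i + k] for i in range(len(seq) - k + 1)]

# A length-13 sequence gives 8 overlapping 6-mers,
# matching the index table above.
print(seq_to_kmers("ACGTACGTACGTA"))
# ['ACGTAC', 'CGTACG', 'GTACGT', 'TACGTA',
#  'ACGTAC', 'CGTACG', 'GTACGT', 'TACGTA']
```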

The way you have implemented masking is that you choose any token as an anchor with probability 2.5%, and then mask that token, the two previous ones, and the three following ones, so that 2.5% * 6 = 15% of the input is masked.

And you compute the loss from the masked tokens.
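
In pseudocode, this is the behaviour I'm describing (a simplified sketch of my reading of it, not the actual mask_tokens implementation):

```python
import random

MASK = "[MASK]"

def mask_contiguous(tokens, anchor_prob=0.025, span=(2, 3)):
    # Each token is chosen as an anchor with probability 2.5%; the anchor,
    # the 2 tokens before it and the 3 tokens after it are masked,
    # so roughly 2.5% * 6 = 15% of tokens end up masked.
    masked = list(tokens)
    labels = [None] * len(tokens)  # loss is computed only on masked positions
    for i in range(len(tokens)):
        if random.random() < anchor_prob:
            lo = max(0, i - span[0])
            hi = min(len(tokens), i + span[1] + 1)
            for j in range(lo, hi):
                labels[j] = tokens[j]
                masked[j] = MASK
    return masked, labels
```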

However, if we mask the 6 tokens in the middle in the example above

0, 1, 2, 3, 4, 5
[MASK]
[MASK]
[MASK]  <- token chosen with probability 2.5%
[MASK]
[MASK]
[MASK]
7, 8, 9, 10, 11, 12

you would still be able to trivially reconstruct most of them from the two adjacent unmasked tokens, since every masked 6-mer overlaps the unmasked flanks in 5 of its 6 positions.
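
To make the leak explicit, here is a small sketch (my own, not from the repo): with tokens 1..6 masked, the left flank reveals base indices 0-5 and the right flank reveals 7-12, so each masked 6-mer is determined up to the single base at index 6, i.e. a 1-in-4 guess.

```python
# Bases recoverable from the two unmasked flanking tokens.
k = 6
known = set()
for start in (0, 7):          # left flank covers 0-5, right flank covers 7-12
    known.update(range(start, start + k))

for start in range(1, 7):     # the six masked tokens
    unknown = [i for i in range(start, start + k) if i not in known]
    print(f"masked 6-mer at positions {start}-{start + k - 1}: unknown bases {unknown}")
# Every line prints "unknown bases [6]": each masked token is known
# except for one nucleotide.
```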