karpathy / minGPT

A minimal PyTorch re-implementation of the OpenAI GPT (Generative Pretrained Transformer) training

Perfect training and evaluation loss, but terrible test-time performance

micahcarroll opened this issue

I encountered a pretty ridiculous failure mode in which:

  • I was getting almost 0 training and validation loss
  • I was getting very bad performance when feeding the model incomplete sequences (e.g. for test-time word-generation)

After much debugging, I found that the issue was that the value of self.masked_bias (currently set to -1e4) is not negative enough: this large negative value is supposed to implement the "mask" of causal masked attention.

For high enough learning rates, the network is able to find a hack that copies the input to the output, getting around the causal masking: it just drives the attention logits below -1e4, at which point the causal mask effectively does nothing! The model can then attend to the whole input (including future tokens) when generating the output, so it can simply copy it at training and validation time (getting ~0 loss).
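To make the mechanism concrete, here is a minimal, self-contained sketch (not the actual library code; shapes, names, and values are made up) of how masking with a finite value stops working once the legitimate logits drop below it:

import torch

# Illustrative only: causal "masking" done by overwriting future positions
# with a finite value, roughly in the style of the transformers attention.
masked_bias = torch.tensor(-1e4)
T = 4
causal_mask = torch.tril(torch.ones(T, T, dtype=torch.bool))

# Suppose training has pushed the legitimate (visible) logits below -1e4...
attn_logits = torch.full((T, T), -1.5e4)

# ...then after "masking", the future positions hold the *largest* logits in
# each row, so the softmax will put most of its weight on them.
masked_logits = torch.where(causal_mask, attn_logits, masked_bias)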

What I found while debugging this line:

>>> attn_weights[0,0,:4,:4]
tensor([[-15258.9805, -10000.0000, -10000.0000, -10000.0000],
        [-15044.7910, -16940.4766, -10000.0000, -10000.0000],
        [-11722.1553, -13301.4287,  -1438.0649, -10000.0000],
        [ -9711.6445, -11315.6006,  -1065.3066, -12052.6035]])

As one can see, the attention logits at the allowed (non-masked) positions can be even more negative than this fixed masking value, which leads the softmax to return non-zero weights on all positions, including the masked future ones.
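For concreteness, taking the softmax of the first row above puts essentially all of the probability mass on the three "masked" positions:

>>> import torch
>>> import torch.nn.functional as F
>>> F.softmax(torch.tensor([-15258.9805, -10000., -10000., -10000.]), dim=-1)
tensor([0.0000, 0.3333, 0.3333, 0.3333])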

Documenting this in case it could be useful to others!

In terms of fixes, there should probably be an assertion that the attention logits never go below self.masked_bias, or the bias should be set to an even more negative value.
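A sketch of what the first option could look like (a hypothetical helper, not existing library code), run on the raw logits before the mask is applied:

import torch

def check_attn_logits(attn_logits: torch.Tensor, masked_bias: float = -1e4) -> None:
    # Hypothetical sanity check: if any legitimate logit drops below the
    # masking value, the causal mask can leak, so fail loudly instead.
    assert attn_logits.min().item() > masked_bias, (
        "attention logits fell below masked_bias; the causal mask may leak"
    )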

(I encountered this issue when using this code from the transformers library, but I'm guessing it also affects this library)

Actually, I realized that only the transformers library suffers from this. The equivalent code here doesn't have the same issue.
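For reference, minGPT's attention masks with float('-inf') via masked_fill instead of a finite bias, so the softmax assigns exactly zero probability to future positions no matter how negative the legitimate logits get. A minimal illustration of that behaviour (not the repo code verbatim):

import torch
import torch.nn.functional as F

T = 4
mask = torch.tril(torch.ones(T, T))
logits = torch.full((T, T), -1.5e4)   # legitimate logits far below -1e4

# masked_fill with -inf: future positions get exactly zero probability,
# regardless of how negative the legitimate logits are.
att = F.softmax(logits.masked_fill(mask == 0, float('-inf')), dim=-1)
print(att[0])   # tensor([1., 0., 0., 0.])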