Is softmax scaling optional?
turboderp opened this issue · comments
Can you elaborate on the significance of the softmax scaling? I can't find it referenced in the paper, and it seems to be applied differently for each of the three attention methods in the HF implementation:
- Eager attention applies it whenever the dtype isn't FP32 (since `scale_attention_softmax_in_fp32`, `attention_softmax_in_fp32` and `scale_attn_weights` are all set).
- SDPA sets a scale of `None`, though it seems prepared to change it to 1 if `scale_attn_weights` were unset. (?)
- The flash-attn module has provisions for applying the scale in `_flash_attention_forward`, but that argument isn't passed, so it defaults to `None`.

Presumably the models are trained with flash-attn, so is this just not actually relevant?
`None` in SDPA or Flash attention is the same as 1 / sqrt(d).
`scale_attn_weights` is the parameter used to decide whether to use 1 / sqrt(d) or 1 otherwise.
The fp32 arguments are just for stability during training and shouldn't be needed at inference, honestly.
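As a rough sketch of what the fp32-softmax options amount to in eager attention (assumed shapes, not the actual modeling code): the scores are computed in the working low precision, upcast for the softmax, then cast back.

```python
import math
import torch

q = torch.randn(1, 4, 16, 64, dtype=torch.bfloat16)
k = torch.randn(1, 4, 16, 64, dtype=torch.bfloat16)

scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(q.size(-1))
# Upcasting before the softmax reduces rounding error in the exp/sum;
# at inference it mostly just costs extra memory/bandwidth.
probs = torch.softmax(scores.float(), dim=-1).to(q.dtype)
```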
Thank you. 👍