Is the FLOPs calculation correct?
lorabit110 opened this issue
It seems that when llm-foundry calculates FLOPs for MPT models, `n_heads` is not taken into consideration:
llm-foundry/llmfoundry/models/mpt/modeling_mpt.py, line 1091 (commit 2634987)
Based on the PaLM paper, the attention part (`attn_flops_per_seq`) should be calculated as `2*2*L*H*Q*T^2`, where `L` is the number of layers, `H` the number of heads, `Q` the head dimension, and `T` the sequence length (a factor of 3 is removed because it's added here). But in llm-foundry's implementation it is `2*2*L*Q*T^2`. Why is `n_heads` (i.e. `H`) missing from the calculation? Did I miss anything?
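For concreteness, here is a small sketch of the comparison I have in mind (my own shorthand and example values, not the actual variable names in Foundry):

```python
# My shorthand: L = layers, H = heads, Q = head dimension, T = sequence length.
L, H, Q, T = 32, 32, 128, 2048

# PaLM-style attention FLOPs per sequence: 2 * 2 * L * H * Q * T^2
palm_attn_flops = 2 * 2 * L * H * Q * T**2

# How I am reading attn_flops_per_seq in modeling_mpt.py: 2 * 2 * L * Q * T^2 (no H)
foundry_attn_flops_as_i_read_it = 2 * 2 * L * Q * T**2

print(palm_attn_flops / foundry_attn_flops_as_i_read_it)  # differs by a factor of H
```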
Hi, with the variables you are using, Foundry's implementation is not `2*2*L*Q*T^2`: `Q` is the head dimension, but what the Foundry implementation actually uses is `d_model`, which is the same as num heads * head dimension.
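To spell that out, a minimal sketch of why the two expressions agree (illustrative variable names and values, not the exact ones in modeling_mpt.py):

```python
# Illustrative values; d_model = n_heads * head_dim by definition.
n_layers, n_heads, head_dim, seq_len = 32, 32, 128, 2048
d_model = n_heads * head_dim  # 4096

# PaLM-style: 2 * 2 * L * H * Q * T^2  (H = heads, Q = head dim)
palm_attn_flops = 2 * 2 * n_layers * n_heads * head_dim * seq_len**2

# Foundry-style: 2 * 2 * L * d_model * T^2  (d_model already folds in n_heads)
foundry_attn_flops = 2 * 2 * n_layers * d_model * seq_len**2

assert palm_attn_flops == foundry_attn_flops  # identical, since d_model == n_heads * head_dim
```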
Thanks for the explanation! It makes sense to me now.