mosaicml / llm-foundry

LLM training code for Databricks foundation models

Home Page: https://www.databricks.com/blog/introducing-dbrx-new-state-art-open-llm


Is the FLOPs calculation correct?

lorabit110 opened this issue · comments

It seems that when llm-foundry calculates FLOPs for MPT models, n_heads is not taken into account:

def flops_per_batch(self, batch: Mapping) -> int:

Based on the PaLM paper, the attention part (attn_flops_per_seq) should be calculated as 2*2*L*H*Q*T^2 (the factor of 3 is removed because it is applied here), but in llm-foundry's implementation it is 2*2*L*Q*T^2. Why is n_heads (H) missing from the calculation? Did I miss anything?

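For reference, a minimal sketch (not Foundry's code) of the attention FLOPs term as I read it from the PaLM paper, using the same symbols as above: L layers, H heads, Q head dimension, T sequence length. The function name and argument names here are my own for illustration.

```python
def attn_flops_per_seq_palm(n_layers: int, n_heads: int, head_dim: int, seq_len: int) -> int:
    # 2*2*L*H*Q*T^2: one factor of 2 for multiply+accumulate, one for the two
    # matmuls in attention (Q @ K^T and attn_weights @ V).
    # The backward-pass factor of 3 is applied elsewhere.
    return 2 * 2 * n_layers * n_heads * head_dim * seq_len**2
```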

Hi, with the variables you are using, Foundry's implementation is not 2*2*L*Q*T^2: Q in your notation is the head dimension, but the Foundry implementation actually uses d_model, which is equal to n_heads * head_dim (H*Q). So n_heads is already accounted for.
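To make that concrete, here is a quick check with hypothetical model sizes showing that the two expressions agree once d_model = n_heads * head_dim is substituted:

```python
# Hypothetical dimensions, for illustration only.
n_layers, n_heads, head_dim, seq_len = 32, 32, 128, 2048
d_model = n_heads * head_dim  # 4096

palm_style = 2 * 2 * n_layers * n_heads * head_dim * seq_len**2  # 2*2*L*H*Q*T^2
foundry_style = 2 * 2 * n_layers * d_model * seq_len**2          # 2*2*L*d_model*T^2

assert palm_style == foundry_style  # identical, since d_model = H*Q
```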

Thanks for the explanation! It makes sense to me now.