google / maxtext

A simple, performant and scalable Jax LLM!


Clarification: how does Llama-2-7b fit on a v4-8 when using Adam?

rodrigo-f-nogueira opened this issue · comments

My understanding is that we need 20 bytes per parameter when using the Adam optimizer with weights and accumulators stored in float32:

weight_dtype: float32

Llama-2-7b thus needs (at least) 6.7B * 20 = 134 GB of memory. However, we were able to train it on a v4-8, which has only 128 GB.

Am I missing something?
Thanks!
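For concreteness, here is the arithmetic behind the question as a tiny script (the parameter count and the 20-bytes-per-parameter figure are the ones quoted above; nothing here is measured):

```python
# Back-of-the-envelope estimate: fp32 weights + optimizer state at 20 bytes/param
# versus the 128 GB of HBM on a v4-8.
n_params = 6.7e9
bytes_per_param = 20
hbm_bytes = 128e9

total_bytes = n_params * bytes_per_param
print(f"Estimated state: {total_bytes / 1e9:.1f} GB")  # ~134 GB
print(f"v4-8 HBM:        {hbm_bytes / 1e9:.0f} GB")     # 128 GB
print("Fits" if total_bytes <= hbm_bytes else "Does not fit at this estimate")
```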

12 bytes per parameter, right (4 for the fp32 weights plus 4 + 4 for Adam's two fp32 accumulators)? Should just barely fit.

12 * 6.7B = 80.4GB, all is well
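As a sanity check of the 12-byte figure, here is a minimal optax snippet showing that plain Adam adds exactly two accumulators in the same dtype as the parameters. The toy parameter shapes are made up; this only illustrates how Adam's state scales, not MaxText's actual setup:

```python
import jax
import jax.numpy as jnp
import optax

# Toy fp32 "model": the exact shapes don't matter, only the total byte count.
params = {"w": jnp.zeros((1024, 1024), dtype=jnp.float32),
          "b": jnp.zeros((1024,), dtype=jnp.float32)}

opt_state = optax.adam(1e-4).init(params)

param_bytes = sum(x.nbytes for x in jax.tree_util.tree_leaves(params))
state_bytes = sum(x.nbytes for x in jax.tree_util.tree_leaves(opt_state))

# Adam keeps two accumulators (mu and nu) in the same dtype as the params, so
# weights + optimizer state come out to ~3x the weight bytes, i.e. 12 B/param.
print(param_bytes, state_bytes, (param_bytes + state_bytes) / param_bytes)
```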

Closing

Great, thank you for the answer, but in the NLG 530B paper (https://arxiv.org/pdf/2201.11990) they mention that we need 20 bytes per parameter:

[Screenshot from the paper showing its 20-bytes-per-parameter memory estimate]

Perhaps in the MaxText implementation we don't need the extra fp16 copies of the weights and gradients (2 + 2 bytes) because we cast from 32-bit to 16-bit every time we do a matmul?
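If that's the case, the idea would look roughly like the sketch below: keep only an fp32 master copy of the weights and cast to bfloat16 on the fly for the matmul, so the 16-bit weights exist only as a temporary inside the step rather than as a persistent extra copy. This is my illustration of the general technique, not MaxText's actual code; the function and shapes are made up:

```python
import jax
import jax.numpy as jnp

def dense(w_fp32, x_bf16):
    # Cast the fp32 master weights to bf16 just for this op. The bf16 tensor
    # is a temporary inside the jitted step, not a persistent per-parameter
    # buffer, so it doesn't add to the bytes-per-parameter accounting above.
    return jnp.dot(x_bf16, w_fp32.astype(jnp.bfloat16))

w = jnp.zeros((4096, 4096), dtype=jnp.float32)  # fp32 master weights
x = jnp.zeros((8, 4096), dtype=jnp.bfloat16)    # bf16 activations
y = jax.jit(dense)(w, x)
print(y.dtype)  # bfloat16
```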