Clarification: how does Llama-2-7b fit on a v4-8 when using Adam?
rodrigo-f-nogueira opened this issue · comments
My understanding is that we need 20 bytes per parameter when using the Adam optimizer with weights and accumulators stored in float32:
maxtext/MaxText/configs/base.yml
Line 72 in f52e6f7
Llama-2-7b thus needs (at least) 6.7B * 20 = 134 GB of memory. However, we were able to train it on a v4-8, which has only 128 GB.
Am I missing something?
Thanks!
12 bytes per parameter, right? Should just barely fit.
12 * 6.7B = 80.4GB, all is well
Closing
Great, thank you for the answer. But, for instance, the Megatron-Turing NLG 530B paper (https://arxiv.org/pdf/2201.11990) mentions that we need 20 bytes per parameter:
Perhaps the MaxText implementation doesn't need the extra fp16 copies of the weights and gradients (2 + 2 bytes) because it casts 32 -> 16 bit every time it does a matmul?
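For concreteness, here is a sketch of the two accountings discussed above. The 12-byte figure assumes pure-fp32 training with Adam (no fp16 copies); the 20-byte figure is one plausible breakdown matching the Megatron-Turing paper's number, which adds fp16 copies plus an fp32 gradient (the exact split of the extra 8 bytes is my assumption, not stated here):

```python
# Byte-per-parameter accounting for Adam (sketch; figures from the discussion above)

PARAMS = 6.7e9  # Llama-2-7b parameter count used in this thread

# Pure-fp32 training state (what fits on a v4-8):
#   fp32 weights (4) + Adam first moment (4) + Adam second moment (4)
bytes_fp32_adam = 4 + 4 + 4  # = 12

# One plausible mixed-precision breakdown summing to the paper's 20 bytes:
#   fp16 weights (2) + fp16 gradients (2)
#   + fp32 master weights (4) + fp32 first moment (4) + fp32 second moment (4)
#   + fp32 gradient copy (4)
bytes_mixed_adam = 2 + 2 + 4 + 4 + 4 + 4  # = 20

print(f"fp32 Adam : {PARAMS * bytes_fp32_adam / 1e9:.1f} GB")   # 80.4 GB
print(f"mixed Adam: {PARAMS * bytes_mixed_adam / 1e9:.1f} GB")  # 134.0 GB
```

At 80.4 GB the fp32-only state sits under the v4-8's 128 GB, while the mixed-precision accounting (134 GB) would not, which is consistent with the answer above.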