google / maxtext

A simple, performant and scalable Jax LLM!

Question: Gradient Accumulation

thiagolaitz opened this issue.

Hello, does MaxText support gradient accumulation or microbatches like those in the T5X repository? I didn't find a parameter for this in base.yml; did I just miss it? Thank you!
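For context, here is a minimal JAX sketch of what microbatch gradient accumulation means. It is not taken from MaxText or T5X. The global batch is split into equal microbatches, per-microbatch gradients are summed with `jax.lax.scan`, and the sum is divided by the number of microbatches. The toy `loss_fn`, `accumulated_grads`, and `num_microbatches` names are hypothetical. Because each microbatch loss is a mean over an equally sized slice, the averaged gradient matches the full-batch gradient, while only one microbatch's activations need to be live at a time.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    # Toy linear-regression loss; stands in for the real model loss.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


def accumulated_grads(params, x, y, num_microbatches):
    """Average gradients over `num_microbatches` equal slices of the batch."""
    batch = x.shape[0]
    assert batch % num_microbatches == 0, "batch must divide evenly"
    mb = batch // num_microbatches
    # Reshape to (num_microbatches, mb, ...) so we can scan over microbatches.
    xs = x.reshape(num_microbatches, mb, *x.shape[1:])
    ys = y.reshape(num_microbatches, mb, *y.shape[1:])
    grad_fn = jax.grad(loss_fn)

    def body(acc, batch_slice):
        xb, yb = batch_slice
        g = grad_fn(params, xb, yb)
        # Add this microbatch's gradient to the running sum.
        return jax.tree_util.tree_map(jnp.add, acc, g), None

    zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
    total, _ = jax.lax.scan(body, zeros, (xs, ys))
    # Divide the sum by the number of microbatches to get the mean gradient.
    return jax.tree_util.tree_map(lambda t: t / num_microbatches, total)


# Usage: the accumulated gradient should match the full-batch gradient.
kx, ky = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(kx, (8, 3))
y = jax.random.normal(ky, (8,))
params = {"w": jnp.ones(3), "b": jnp.zeros(())}
g_acc = accumulated_grads(params, x, y, num_microbatches=4)
g_full = jax.grad(loss_fn)(params, x, y)
```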

We don't support that out of the box. We've found that lowering the learning rate (LR) is a better approach.

What is your use case?

I'm training bigger models than before, so I can't fit the same batch size on the same TPU. Do you have any recommended ablation studies comparing gradient accumulation with lowering the LR? Also, if I skip gradient accumulation, should I just reduce the LR linearly in proportion to the batch size? Thanks!
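The "reduce the LR linearly with the batch size" idea is the common linear scaling heuristic: new_lr = base_lr × (new_batch / base_batch). The sketch below only does that arithmetic; the function name and numbers are hypothetical, and this is a rule of thumb rather than a recommendation from the MaxText maintainers. It does not guarantee matching results, so it is still worth validating with a short ablation.

```python
def linearly_scaled_lr(base_lr: float, base_batch: int, new_batch: int) -> float:
    """Linear scaling heuristic: scale the LR in proportion to the global batch size."""
    if base_batch <= 0 or new_batch <= 0:
        raise ValueError("batch sizes must be positive")
    return base_lr * new_batch / base_batch


# Illustrative numbers: halving the global batch from 512 to 256 halves the LR.
lr = linearly_scaled_lr(base_lr=3e-4, base_batch=512, new_batch=256)
```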

+1
Adding another use case: because TPU availability varies, we sometimes train a model on a v4-128 TPU but later need to replicate the experiment on a v4-64 TPU, which has less total memory. We therefore need gradient accumulation to keep the results consistent.
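To keep the global batch the same on the smaller slice, the number of accumulation steps k must satisfy per_device_batch × num_devices × k = target_global_batch. The helper below is a hypothetical sketch of that arithmetic; the device counts and per-device batch are illustrative and only assume that the smaller slice exposes half as many devices as the larger one.

```python
def accumulation_steps(target_global_batch: int, per_device_batch: int, num_devices: int) -> int:
    """Micro-steps k such that per_device_batch * num_devices * k == target_global_batch."""
    per_step = per_device_batch * num_devices
    if target_global_batch % per_step != 0:
        raise ValueError("target global batch must be a multiple of the per-step batch")
    return target_global_batch // per_step


# Illustrative: the original run used 64 devices x 8 examples = 512 examples per step.
target = 64 * 8
# With half as many devices and the same per-device batch, accumulate over 2 micro-steps.
k = accumulation_steps(target, per_device_batch=8, num_devices=32)
```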

Simply add the following code right after the optimizer is created in optimizers.py to support gradient accumulation:

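```python
import optax
# Apply one averaged update every k calls; k comes from a new config key (e.g. in base.yml).
if config.accumulate_gradient_steps > 1:
    optimizer = optax.MultiSteps(optimizer, config.accumulate_gradient_steps)
```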