google / maxtext

A simple, performant and scalable Jax LLM!

DEFAULT_MASK_VALUE causes gradient explosion and nan loss on deep models

logicchains opened this issue

I was training a llama model on GPU, with a custom embedding. It worked fine with 12 layers, dim 1024, seq length 256, but the loss would become nan after the first step when num_layers was set to more than 17. I debugged the gradients and found that after each layer their magnitude increased by around 100x, until they hit float32_max at around the 18th layer and became inf, leading to nan loss.
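For anyone wanting to reproduce this kind of debugging, here is a rough sketch (not part of MaxText) of how per-layer gradient norms can be inspected in JAX; the `fake_grads` pytree and layer names below are made up purely for illustration, assuming gradients come back as a flax-style nested dict keyed by layer.

```python
import jax
import jax.numpy as jnp

def print_layer_grad_norms(grads):
    # `grads` is assumed to be a nested dict keyed by layer name, as produced
    # by jax.grad over a flax-style params pytree. Printing each layer's L2
    # norm makes layer-over-layer growth (~100x here) easy to spot.
    for layer_name, layer_grads in grads.items():
        leaves = jax.tree_util.tree_leaves(layer_grads)
        norm = jnp.sqrt(sum(jnp.sum(jnp.square(g)) for g in leaves))
        print(f"{layer_name}: grad_norm={float(norm):.3e}")

# Toy usage with a made-up gradient pytree.
fake_grads = {
    "layer_0": {"w": jnp.ones((4, 4)), "b": jnp.ones((4,))},
    "layer_1": {"w": 100.0 * jnp.ones((4, 4)), "b": 100.0 * jnp.ones((4,))},
}
print_layer_grad_norms(fake_grads)
```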

The gradient explosion seemed to be coming from
local_exps = jnp.exp(attn_weights - local_max)
in attentions.py.
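For context, here is a minimal, self-contained sketch of a numerically stabilized masked softmax containing that same `attn_weights - local_max` step. It is not MaxText's actual attention kernel and does not reproduce the deep-model explosion; it only shows where the mask value enters the computation. The shapes, loss, and random inputs are made up for illustration, and either mask value can be swapped in to compare gradients.

```python
import jax
import jax.numpy as jnp

HUGE_FINITE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max)  # current DEFAULT_MASK_VALUE
NEG_INF = -jnp.inf                                               # proposed replacement

def attention_block(scores, values, mask_value):
    # Causal mask: disallowed positions are overwritten with mask_value.
    seq = scores.shape[-1]
    causal = jnp.arange(seq)[:, None] >= jnp.arange(seq)[None, :]
    masked = jnp.where(causal, scores, mask_value)
    # Stabilized softmax, mirroring the
    # `local_exps = jnp.exp(attn_weights - local_max)` step in attentions.py.
    local_max = jnp.max(masked, axis=-1, keepdims=True)
    local_exps = jnp.exp(masked - local_max)
    probs = local_exps / jnp.sum(local_exps, axis=-1, keepdims=True)
    return probs @ values

def loss_fn(scores, values, mask_value):
    return jnp.sum(attention_block(scores, values, mask_value) ** 2)

scores = jax.random.normal(jax.random.PRNGKey(0), (8, 8), dtype=jnp.float32)
values = jax.random.normal(jax.random.PRNGKey(1), (8, 4), dtype=jnp.float32)

for name, mv in [("-0.7 * float32_max", HUGE_FINITE), ("-inf", NEG_INF)]:
    grads = jax.grad(loss_fn)(scores, values, mv)
    print(name, "max |grad| =", float(jnp.max(jnp.abs(grads))))
```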

Changing

DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max)
to
DEFAULT_MASK_VALUE = -jnp.inf
fixed the issue, and the gradients' magnitude stopped increasing after each layer.

Presumably the issue wasn't noticed during TPU training, as that uses a separate code path.

@logicchains thanks for the tips on GPU convergence! We will experiment with this as we set up convergence regimes for GPUs.

@anfals please be aware of this as you do convergence testing on GPU