google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io


`DynamicScale` behaves unexpectedly when computing per-sample gradients with `vmap`.

hlzl opened this issue

When running jax.vmap, e.g. to compute per-sample gradients, the fin_steps and scale attributes of DynamicScale can become arrays, which leads to an error in the next training step if not handled manually. The TypeError that is eventually raised does not point to the actual problem, namely a non-scalar scale attribute.
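(For context, this is just vmap's normal behaviour, shown here as a minimal Flax-independent sketch: every output leaf is stacked along the mapped axis, so the per-call scalar state returned by value_and_grad comes back as a (batch,)-shaped array.)

import jax
import jax.numpy as jnp

def per_sample_scalar(x):
    return jnp.sum(x)  # a scalar for each individual sample

# vmap stacks the 32 per-call scalars into a (32,) array; the DynamicScale
# pytree returned by a vmapped value_and_grad is batched the same way.
print(jax.vmap(per_sample_scalar)(jnp.ones((32, 10))).shape)  # (32,)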

System information

  • jax==0.4.28 and flax==0.8.5

Problem you have encountered:

Because self.scale becomes an array in the output of the first vmap call, the loss_wrapper inside DynamicScale also starts to return an array instead of a scalar.
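The resulting failure can be reproduced without Flax (a minimal sketch of the mechanism, assuming the wrapped loss gets multiplied/divided by the now array-valued scale): jax.grad only accepts scalar-output functions.

import jax
import jax.numpy as jnp

scale = jnp.ones((32,))  # what ds.scale looks like after the first vmapped call

def scaled_loss(w):
    return jnp.sum(w ** 2) * scale  # scalar loss times an array -> shape (32,)

# Raises: TypeError: Gradient only defined for scalar-output functions.
#         Output had shape: (32,).
jax.grad(scaled_loss)(jnp.ones(3))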

What you expected to happen:

The scale and fin_steps attributes should either be averaged or enforced to be scalars, so that they do not cause the TypeError.
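As a sketch of what such enforcement could look like (ensure_scalar is hypothetical and not part of the Flax API; the mean mirrors the manual workaround at the bottom of this issue):

import jax.numpy as jnp
from flax.training import dynamic_scale

def ensure_scalar(ds: dynamic_scale.DynamicScale) -> dynamic_scale.DynamicScale:
    # Collapse per-sample scale/fin_steps leaves back to scalars.
    return ds.replace(
        scale=jnp.mean(ds.scale),
        fin_steps=jnp.mean(ds.fin_steps),
    )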

Logs, error messages, etc:

File ~/miniforge3/envs/test/lib/python3.12/site-packages/flax/training/dynamic_scale.py:132, in DynamicScale.value_and_grad.<locals>.grad_fn_wrapper(*args)
    131 def grad_fn_wrapper(*args):
--> 132   aux, grad = grad_fn(*args)
    133   aux = (aux[0] / self.scale, aux[1]) if has_aux else aux / self.scale
    135   grad = jax.tree_util.tree_map(
    136     lambda g: jnp.asarray(g, jnp.float32) / self.scale, grad
    137   )
TypeError: Gradient only defined for scalar-output functions. Output had shape: (32,).

Steps to reproduce:

from typing import Sequence

import jax
import jax.numpy as jnp
import flax.linen as nn
import optax

from flax.training import dynamic_scale


class MLP(nn.Module):
    features: Sequence[int]

    @nn.compact
    def __call__(self, x):
        for feat in self.features[:-1]:
            x = nn.relu(nn.Dense(feat)(x))
        x = nn.Dense(self.features[-1])(x)
        return x


def cross_entropy_loss(params, model, image, label):
    """Loss function used for training."""
    logits = model.apply({"params": params}, image)
    loss = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(logits, label))
    return loss, logits

model = MLP([12, 8, 4])
input = jnp.ones((32, 10))
labels = jnp.ones((32,), dtype=int)
variables = model.init(jax.random.key(0), input)
output = model.apply(variables, input)

ds = dynamic_scale.DynamicScale()

# 1st batch
ds, is_fin, (loss, logits), per_sample_grads = jax.vmap(
    ds.value_and_grad(cross_entropy_loss, has_aux=True),
    in_axes=(None, None, 0, 0),
)(variables["params"], model, input, labels)
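# At this point every leaf of ds (scale, fin_steps) has shape (32,):
# vmap has stacked the DynamicScale returned for each sample.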

# 2nd batch: fails with the TypeError shown above, because ds.scale is now an array
ds, is_fin, (loss, logits), per_sample_grads = jax.vmap(
    ds.value_and_grad(cross_entropy_loss, has_aux=True),
    in_axes=(None, None, 0, 0),
)(variables["params"], model, input, labels)

This can be fixed manually with ds = ds.replace(fin_steps=ds.fin_steps.mean(), scale=ds.scale.mean()) after each step.
It should be handled automatically / enforced within DynamicScale, in my opinion.
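For completeness, a sketch of where the workaround goes in the repro above (it continues that snippet; note that .mean() turns the integer fin_steps into a float, so one may prefer to cast it back):

# After the 1st vmapped batch, collapse the per-sample leaves back to scalars.
ds = ds.replace(fin_steps=ds.fin_steps.mean(), scale=ds.scale.mean())

# The 2nd batch then runs without the TypeError.
ds, is_fin, (loss, logits), per_sample_grads = jax.vmap(
    ds.value_and_grad(cross_entropy_loss, has_aux=True),
    in_axes=(None, None, 0, 0),
)(variables["params"], model, input, labels)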