google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io

`pmap` state

yuanqing-wang opened this issue · comments

How should I broadcast a training state to multiple devices and `pmap` it? I tried to follow this example and ended up with:

import jax
import jax.numpy as jnp
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()

from flax import linen as nn
import optax
model = nn.Dense(1)
x = jnp.ones(8)
params = model.init(jax.random.PRNGKey(2666), x)
from flax.training.train_state import TrainState
tx = optax.adam(learning_rate=1e-3)
state = TrainState.create(
    apply_fn=model.apply, params=params, tx=tx,
)
def loss_fn(state, x):
    return (model.apply(state.params, x) ** 2.0).mean()
jax.pmap(loss_fn)(state, x)

But I got:

ValueError: pmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())

You should use `flax.jax_utils.replicate`. `pmap` maps every argument over its leading axis, so each leaf of the `TrainState` pytree needs a leading device axis; the scalar `step` counter has shape `()`, which is what triggers the error, and `replicate` adds that axis to every leaf. Also, it is safer to use `jax.device_count()` than hardcoding 8 in your array. Finally, `Dense` expects a feature dimension, so after `pmap` strips the leading device axis each per-device shard must still be at least rank 1. This code should work:

import jax
import jax.numpy as jnp

from flax import linen as nn
from flax import jax_utils
import optax
from flax.training.train_state import TrainState

model = nn.Dense(1)
# Leading axis = number of devices, trailing axis = feature dimension.
x = jnp.ones((jax.device_count(), 3))
params = model.init(jax.random.PRNGKey(0), x)
tx = optax.adam(learning_rate=1e-3)
state = TrainState.create(
    apply_fn=model.apply, params=params, tx=tx,
)
# Copy every leaf of the state pytree to each device, adding a leading device axis.
state = jax_utils.replicate(state)

def loss_fn(state, x):
    return (model.apply(state.params, x) ** 2.0).mean()

jax.pmap(loss_fn)(state, x)
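
You can confirm what `replicate` did with `jax.tree_util.tree_map(jnp.shape, state)`: every leaf of the replicated state, including the scalar `step` counter, gains a leading axis of size `jax.device_count()`. If you go on to train with `pmap`, each device computes its own gradients, so they have to be averaged across devices before the update. Here is a minimal sketch of how that usually looks, continuing from the snippet above; the `train_step` helper and the `'devices'` axis name are illustrative choices of mine, not Flax API:

def train_step(state, x):
    def loss(params):
        return (state.apply_fn(params, x) ** 2.0).mean()
    loss_value, grads = jax.value_and_grad(loss)(state.params)
    # Average the gradients across devices so every replica applies the same update.
    grads = jax.lax.pmean(grads, axis_name='devices')
    return state.apply_gradients(grads=grads), loss_value

p_train_step = jax.pmap(train_step, axis_name='devices')
state, loss_value = p_train_step(state, x)

# Pull a single copy of the trained state back to the host when you are done.
state = jax_utils.unreplicate(state)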