`pmap` state
yuanqing-wang opened this issue
Yuanqing Wang commented
How should I broadcast a training state to multiple devices and `pmap` over it? I tried to follow this example and had:
import jax
import jax.numpy as jnp
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
from flax import linen as nn
import optax
model = nn.Dense(1)
x = jnp.ones(8)
params = model.init(jax.random.PRNGKey(2666), x)
from flax.training.train_state import TrainState
tx = optax.adam(learning_rate=1e-3)
state = TrainState.create(
    apply_fn=model.apply, params=params, tx=tx,
)
def loss_fn(state, x):
    return (model.apply(state.params, x) ** 2.0).mean()
jax.pmap(loss_fn)(state, x)
But got:
ValueError: pmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())
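The error can be reproduced with JAX alone. This is a minimal sketch that assumes the cause is the scalar leaves in the state. `TrainState.create` starts the `step` counter at `0`, and `pmap` tries to split every leaf of every argument along axis 0, so any rank-0 leaf triggers this error. The dict below is a hypothetical stand-in for `TrainState` and does not reproduce Flax's actual structure.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for an unreplicated TrainState: a scalar step counter
# plus parameters whose leading axis is NOT a device axis.
state = {"step": 0, "params": {"kernel": jnp.ones((3, 1)), "bias": jnp.zeros((1,))}}
x = jnp.ones((jax.device_count(), 3))

def loss_fn(state, x):
    y = x @ state["params"]["kernel"] + state["params"]["bias"]
    return (y ** 2).mean()

try:
    # pmap maps every leaf of `state` along axis 0, but "step" has shape ().
    jax.pmap(loss_fn)(state, x)
    error = None
except ValueError as e:
    error = e

print(error)
```

Even without the scalar, the unreplicated `kernel` would be split along its first axis (size 3) instead of being copied to each device. That is why the whole state needs a leading device axis.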
Marc van Zee commented
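You should use `flax.jax_utils.replicate` to give every leaf of the state a leading device axis before calling `pmap`. The error comes from `pmap` trying to split each leaf along axis 0, and some leaves, such as the `step` counter, are scalars. Also, it is safer to use `jax.device_count()` than to hardcode `8` in your array. Finally, `pmap` strips the leading axis of `x`, so each device receives only one slice of it. `Dense` needs at least a feature dimension in that slice, so `x` should have at least two dimensions (devices, features). This code should work: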
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax import jax_utils
import optax
from flax.training.train_state import TrainState
model = nn.Dense(1)
x = jnp.ones((jax.device_count(), 3))
params = model.init(jax.random.PRNGKey(0), x)
tx = optax.adam(learning_rate=1e-3)
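state = TrainState.create(
    apply_fn=model.apply, params=params, tx=tx,
)
# Add a leading device axis to every leaf (params, optimizer state, step).
state = jax_utils.replicate(state)
def loss_fn(state, x):
    # Inside pmap, each device sees one copy of the state and a (3,) slice of x.
    return (model.apply(state.params, x) ** 2.0).mean()
# Returns one loss per device, with shape (jax.device_count(),).
jax.pmap(loss_fn)(state, x)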
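A training step usually goes further than computing a loss per device. The common pattern is to give `pmap` an `axis_name`, average gradients across devices with `jax.lax.pmean`, and read back a single copy of the state by indexing the device axis (roughly what `flax.jax_utils.unreplicate` does). The sketch below is not the thread's code. It assumes plain SGD with a hypothetical `LEARNING_RATE` instead of `optax.adam`, and a dict of parameters instead of `TrainState`, so that it runs with JAX alone. The stacking step stands in for `jax_utils.replicate`.

```python
from functools import partial

import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # hypothetical value; the thread uses optax.adam(learning_rate=1e-3)
n = jax.device_count()

def loss_fn(params, x):
    y = x @ params["kernel"] + params["bias"]
    return (y ** 2).mean()

@partial(jax.pmap, axis_name="devices")
def train_step(params, x):
    loss, grads = jax.value_and_grad(loss_fn)(params, x)
    # Average gradients over devices so every replica applies the same update.
    grads = jax.lax.pmean(grads, axis_name="devices")
    params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return params, jax.lax.pmean(loss, axis_name="devices")

params = {"kernel": jnp.ones((3, 1)), "bias": jnp.zeros((1,))}
# Stand-in for replicate: stack one copy of every leaf along a new device axis.
params = jax.tree_util.tree_map(lambda a: jnp.stack([a] * n), params)
x = jnp.ones((n, 3))

for _ in range(5):
    params, loss = train_step(params, x)

# Stand-in for unreplicate: the replicas are identical, so keep device 0's copy.
single = jax.tree_util.tree_map(lambda a: a[0], params)
print(single["kernel"].shape, loss[0])
```

Because the gradients are averaged with `pmean` before the update, all replicas stay identical. Taking index 0 is therefore enough to recover the trained parameters.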