Loop in a pmapped function
AkashGanesan opened this issue
We have the following `step` function that we are pmapping. It iterates over the optimizers specified for different branches of the model and updates the parameters. However, in the code below, the call `_update, _state = _opt.update(_grads, _state, _params)` fails with the following error:
```
E ValueError: Custom node type mismatch: expected type: <class 'flax.core.frozen_dict.FrozenDict'>, value: (Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=0/1)>, FrozenDict({
E     params: { ...
```
Is the `Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=0/1)>` coming from tracing the loop? If so, is it acceptable to index into the value and pass the `FrozenDict` to the update function?
```python
def step(self, batch, params, states):
    """
    Compute gradients, loss, accuracy per batch
    """
    step = states[0].count
    # TODO: correctly get step from the state
    grad_fn = jax.value_and_grad(self.loss, has_aux=False)
    grad = grad_fn(params, batch)
    print("grad is ", grad)
    grad = jax.lax.pmean(grad, axis_name="batch")

    def update_fn(_opt, _grads, _state, _params):
        _update, _state = _opt.update(_grads, _state, _params)
        print("update_fn and state", _update[0], _state)
        if isinstance(_opt, GradientTransformation):
            _params = optax.apply_updates(_params, _update)
        elif isinstance(_opt, ParameterTransformation):
            _params = _update
        return _params, _state

    for idx, opt in self.task.optimizers.items():
        print("types of state and params", states, type(params))
        update, states[idx] = update_fn(opt, grad, states[idx], params)
        # TODO: call meter (aux)
    return params, states
```
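For reference, `jax.value_and_grad` with `has_aux=False` returns a `(value, grads)` pair rather than the gradients alone, and that pair has the same `(scalar, FrozenDict)` shape as the value reported in the error. A minimal standalone sketch of the return structure (the loss, params, and batch below are illustrative, not the project's code):

```python
import jax
import jax.numpy as jnp

# Toy loss and inputs, just to show the return structure of value_and_grad.
def loss(params, batch):
    return jnp.mean((params["w"] * batch) ** 2)

params = {"w": jnp.ones(())}
batch = jnp.arange(4.0)

grad_fn = jax.value_and_grad(loss, has_aux=False)
out = grad_fn(params, batch)
# out is a (loss_value, grads) tuple, i.e. (scalar, pytree-of-params).
loss_value, grads = out
print(type(out), loss_value.shape, jax.tree_util.tree_structure(grads))
```

If `grad_fn(params, batch)` is bound to a single variable, that variable holds the whole pair, so it may be worth checking that unpacking before looking at the loop tracing.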
I don't see anything obviously wrong in this snippet. My guess is that `_grads`, `_state`, or `_params` has a different structure than it should (perhaps it is wrongly initialized or passed during `step`).

By the way, the canonical way of using different optimizers for subsets of params in optax is `multi_transform`. This should avoid the complexity of using multiple optimizers in a loop altogether.
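A minimal sketch of that approach, assuming two hypothetical parameter branches routed to different optimizers (the trees, labels, and learning rates below are illustrative, not taken from the issue):

```python
import jax
import jax.numpy as jnp
import optax

# Hypothetical two-branch parameter tree and a label tree of the same shape;
# each label selects which transformation handles that subtree.
params = {"encoder": {"w": jnp.ones((3, 3))}, "head": {"w": jnp.ones((3,))}}
labels = {"encoder": {"w": "adam"}, "head": {"w": "sgd"}}

tx = optax.multi_transform(
    {"adam": optax.adam(1e-3), "sgd": optax.sgd(1e-2)},
    labels,
)
state = tx.init(params)

def loss(p):
    return jnp.sum(p["encoder"]["w"]) + jnp.sum(p["head"]["w"])

grads = jax.grad(loss)(params)
updates, state = tx.update(grads, state, params)  # one update call for all branches
params = optax.apply_updates(params, updates)
```

This keeps a single `GradientTransformation` and a single optimizer state, so the per-branch loop and the per-branch `states[idx]` bookkeeping go away.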