google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io

Loop in a pmapped function

AkashGanesan opened this issue · comments

We have the following step function that we are pmapping. It iterates over the optimizers specified for the different branches of the model and updates the parameters. However, in the code below, when we call _update, _state = _opt.update(_grads, _state, _params), we get the following type for _update:

E     ValueError: Custom node type mismatch: expected type: <class 'flax.core.frozen_dict.FrozenDict'>, value: (Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=0/1)>, FrozenDict({
E         params: { ... 

Is the Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=0/1)> coming from tracing the loop? If so, is it acceptable to index and get the FrozenDict for the update function?

    def step(self, batch, params, states):
        """
        Compute gradients, loss, accuracy per batch
        """
        step = states[0].count
        # TODO: correctly get step from the state
        grad_fn = jax.value_and_grad(self.loss, has_aux=False)
        grad = grad_fn(params, batch)
        print("grad is ", grad)
        grad = jax.lax.pmean(grad, axis_name="batch")

        def update_fn(_opt, _grads, _state, _params):
            _update, _state = _opt.update(_grads, _state, _params)
            print("update_fn and state", _update[0], _state)
            if isinstance(_opt, GradientTransformation):
                _params = optax.apply_updates(_params, _update)
            elif isinstance(_opt, ParameterTransformation):
                _params = _update
            return _params, _state

        for idx, opt in self.task.optimizers.items():
            print("types of state and params", states, type(params))
            update, states[idx] = update_fn(opt, grad, states[idx], params)
        # TODO: call meter (aux)
        return params, states
commented

I don't see anything obviously wrong in this snippet. My guess is that _grads, _state, or _params has a different structure than it should (perhaps it was initialized or passed incorrectly during step).
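
One quick way to narrow this down is to print the pytree structures of what is actually passed into _opt.update. Note that jax.value_and_grad returns a (value, grad) pair, which looks consistent with the (Traced scalar, FrozenDict) tuple in the error message. A minimal sketch, reusing the names from the snippet above:

    # Sketch only: inspect what is being fed to the optimizer.
    # jax.value_and_grad returns a (loss, grads) pair, so unpack it explicitly.
    loss, grads = grad_fn(params, batch)

    # The optimizer expects a pytree with the same structure as params.
    print(jax.tree_util.tree_structure(grads))
    print(jax.tree_util.tree_structure(params))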

Btw, the canonical way to use different optimizers for subsets of params in optax is multi_transform. This should avoid the complexity of looping over multiple optimizers altogether.
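
For example, a minimal sketch of multi_transform (the parameter names and learning rates here are made up):

    import jax
    import jax.numpy as jnp
    import optax

    # Hypothetical params with two branches that should use different optimizers.
    params = {"encoder": jnp.ones((3, 3)), "head": jnp.ones((3,))}

    # Label each branch; the labels pytree mirrors the params pytree.
    param_labels = {"encoder": "adam", "head": "sgd"}

    # One GradientTransformation per label.
    tx = optax.multi_transform(
        {"adam": optax.adam(1e-3), "sgd": optax.sgd(1e-2)},
        param_labels,
    )

    state = tx.init(params)
    grads = jax.tree_util.tree_map(jnp.ones_like, params)  # stand-in gradients
    updates, state = tx.update(grads, state, params)
    params = optax.apply_updates(params, updates)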