Jax transforms and Flax models cannot be mixed
erfanzar opened this issue · comments
Hello. I'm implementing the Mixtral models with JAX and Flax, and I've run into a problem with the scan function here:
I get the error "Jax transforms and Flax models cannot be mixed."
System information
- OS Platform and Distribution Ubuntu 23.04
Name: flax
Version: 0.7.5
Summary: Flax: A neural network library for JAX designed for flexibility
Home-page:
Author:
Author-email: Flax team flax-dev@google.com
License:
Location: /home/erfan/venv/lib/python3.11/site-packages
Requires: jax, msgpack, numpy, numpy, optax, orbax-checkpoint, PyYAML, rich, tensorstore, typing-extensions
Required-by: EasyDeL, FJFormer
Name: jax
Version: 0.4.23
Summary: Differentiate, compile, and transform Numpy code.
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /home/erfan/venv/lib/python3.11/site-packages
Requires: ml-dtypes, numpy, numpy, opt-einsum, scipy
Required-by: chex, distrax, EasyDeL, FJFormer, flax, optax, orbax-checkpoint, rlax
Name: jaxlib
Version: 0.4.23
Summary: XLA library for JAX
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /home/erfan/venv/lib/python3.11/site-packages
Requires: ml-dtypes, numpy, scipy
Required-by: chex, distrax, EasyDeL, FJFormer, optax, orbax-checkpoint, rlax
- Python version: 3.11
- TPU v4
Problem you have encountered:
What you expected to happen:
Logs, error messages, etc:
File "/home/erfan/PycharmProjects/EasyDeL/lib/python/EasyDel/modules/mixtral/modelling_mixtral_flax.py", line 449, in expert_layer_forward
forward_hidden_state = nn.cond(
^^^^^^^^
File "/home/erfan/venv/lib/python3.11/site-packages/flax/linen/transforms.py", line 1353, in cond
return lift_direct_transform(
^^^^^^^^^^^^^^^^^^^^^^
File "/home/erfan/venv/lib/python3.11/site-packages/flax/linen/transforms.py", line 487, in lift_direct_transform
return decorator_lift_transform(
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/erfan/venv/lib/python3.11/site-packages/flax/linen/transforms.py", line 426, in wrapped_fn
return trafo_fn(module_scopes, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/erfan/venv/lib/python3.11/site-packages/flax/linen/transforms.py", line 1298, in _cond_wrapper
return lift.cond(
^^^^^^^^^^
File "/home/erfan/venv/lib/python3.11/site-packages/flax/core/lift.py", line 1085, in cond
return pack(inner, (variables,), (variables,), (rngs,), name='cond')(scope)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/erfan/venv/lib/python3.11/site-packages/flax/core/lift.py", line 148, in wrapper
scope._validate_trace_level()
File "/home/erfan/venv/lib/python3.11/site-packages/flax/core/scope.py", line 545, in _validate_trace_level
tracers.check_trace_level(self.trace_level)
File "/home/erfan/venv/lib/python3.11/site-packages/flax/core/tracers.py", line 36, in check_trace_level
raise errors.JaxTransformError()
flax.errors.JaxTransformError: Jax transforms and Flax models cannot be mixed. (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.JaxTransformError)
Process finished with exit code 1
Steps to reproduce:
Code to reproduce this error:
pip install git+https://github.com/erfanzar/EasyDeL.git
import copy
import os
os.environ["JAX_TRACEBACK_FILTERING"] = "off"
import jax
from EasyDel import MixtralConfig, FlaxMixtralForCausalLM
from EasyDel.transform.easydel_transform import huggingface_to_easydel
from jax import numpy as jnp
from transformers import MixtralForCausalLM
import torch
import numpy as np
def main():
    """Compare Hugging Face (PyTorch) and EasyDel (Flax) Mixtral logits.

    Builds a small ``MixtralConfig``, instantiates the PyTorch
    ``MixtralForCausalLM`` from it, converts its state dict with
    ``huggingface_to_easydel``, runs the same random token ids through both
    models, and prints both outputs plus whether they agree within
    ``atol=1e-5``.

    A ``TypeError`` raised while building or calling the Flax model is
    printed rather than propagated; any other exception propagates.
    """
    torch.manual_seed(42)
    seq_len = 128
    config = MixtralConfig(
        hidden_size=256,
        num_attention_heads=8,
        num_hidden_layers=1,
        num_key_value_heads=4,
        intermediate_size=512,
        num_local_experts=8,
        max_position_embeddings=seq_len,
    )
    # One sample per visible JAX device.
    batch_size = len(jax.devices())

    # The torch model receives a deep copy of the config; the original
    # config object is modified further below before the Flax model is built.
    torch_model = MixtralForCausalLM(
        config=copy.deepcopy(config)
    )
    params = {
        "params": huggingface_to_easydel(
            torch_model.state_dict(),
            embedding_layer_names=["embed_tokens"],
            device=jax.devices("cpu")[0],
        )
    }

    # Both frameworks are fed the same random token ids.
    np_random_input_ids = np.random.randint(
        0, config.vocab_size, (batch_size, seq_len)
    )
    input_ids = (
        torch.from_numpy(np_random_input_ids)
        .reshape(batch_size, -1)
        .to(torch.long)
    )
    flax_input_ids = jnp.asarray(
        np_random_input_ids, dtype=jnp.int32
    ).reshape(batch_size, -1)

    torch_output = torch_model(
        input_ids=input_ids
    )
    torch_output = torch_output.logits.cpu().detach().numpy()

    config.add_jax_args()
    config.add_basic_configurations(
        use_shard_map=True
    )

    try:
        flax_model = FlaxMixtralForCausalLM(
            config=config,
            dtype=jnp.float32,
            param_dtype=jnp.float32,
            _do_init=False,
            input_shape=(batch_size, seq_len),
        )
        flax_output = flax_model(
            input_ids=flax_input_ids,
            params=params,
        )
        res = jnp.allclose(torch_output, flax_output.logits, atol=1e-5)
        print("Mixtral Huggingface Predictions :\n", torch_output,
              "\nEasyDel Predictions: \n", flax_output.logits)
        if res:  # A Little Bit of humor
            print("\033[1;36mTest Passed Unfortunately 🥳")
        else:
            print("\033[1;31mTest Failed Successfully 🤕")
        # Mean *absolute* error: a signed mean lets positive and negative
        # deviations cancel and can report ~0 for badly mismatched logits.
        error = jnp.mean(jnp.abs(torch_output - flax_output.logits))
        print("Error : ", error)
    except TypeError as e:
        print(e)


if __name__ == "__main__":
    main()
I fixed that 5 or 6 weeks ago. Anyway, thanks for the help!