Exceptions are not pickle-able
sanderland opened this issue
Sander Land commented
Problem you have encountered:
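Flax exceptions cannot be pickled. Ray transports exceptions raised in remote tasks back to the caller by pickling them (see the log below), so a Flax exception raised inside a Ray task cannot be deserialized. This makes it impossible to trace where and why errors occur when calling Flax via Ray.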
System information
Name: flax
Version: 0.8.3
Summary: Flax: A neural network library for JAX designed for flexibility
Home-page:
Author:
Author-email: Flax team <flax-dev@google.com>
License:
Location: /home/sander_cohere_com/miniconda3/envs/ct/lib/python3.10/site-packages
Requires: jax, msgpack, numpy, optax, orbax-checkpoint, PyYAML, rich, tensorstore, typing-extensions
Required-by: dm-haiku, fax
---
Name: jax
Version: 0.4.28
Summary: Differentiate, compile, and transform Numpy code.
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /home/sander_cohere_com/miniconda3/envs/ct/lib/python3.10/site-packages
Requires: ml-dtypes, numpy, opt-einsum, scipy
Required-by: chex, fax, flax, optax, orbax-checkpoint, rax
---
Name: jaxlib
Version: 0.4.28
Summary: XLA library for JAX
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /home/sander_cohere_com/miniconda3/envs/ct/lib/python3.10/site-packages
Requires: ml-dtypes, numpy, scipy
Required-by: chex, fax, optax, orbax-checkpoint, rax
What you expected to happen:
Flax exceptions should support pickling, i.e. survive a pickle.dumps / pickle.loads round trip (cf. this guide).
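One common pattern for making an exception with a multi-argument constructor picklable is to keep the constructor arguments and return them from `__reduce__`. The sketch below uses a hypothetical `PicklableShapeError` whose signature mirrors `ScopeParamShapeError` from the traceback; it is an illustration under that assumption, not Flax's actual code, and the message text is made up.

```python
import pickle


class PicklableShapeError(ValueError):
    """Hypothetical exception with a multi-argument constructor that pickles cleanly."""

    def __init__(self, param_name, scope_path, value_shape, init_shape):
        # Remember the constructor arguments so unpickling can replay them.
        self._init_args = (param_name, scope_path, value_shape, init_shape)
        super().__init__(
            f"parameter {param_name!r} in {scope_path!r}: "
            f"expected shape {init_shape}, got {value_shape}"
        )

    def __reduce__(self):
        # BaseException's default __reduce__ rebuilds the object as
        # type(self)(*self.args), and self.args holds only the message.
        # Returning the original arguments makes pickle call
        # type(self)(*self._init_args) instead.
        return (type(self), self._init_args)


err = PicklableShapeError("kernel", "/", (1, 8), (5, 8))
restored = pickle.loads(pickle.dumps(err))
print(type(restored).__name__, restored)
```

The round trip preserves both the exception type and the message, so a Ray caller could still catch the specific class.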
Logs, error messages, etc:
In production code:
site-packages/ray/exceptions.py", line 49, in from_ray_exception
return pickle.loads(ray_exception.serialized_exception)
TypeError: ScopeParamShapeError.__init__() missing 3 required positional arguments: 'scope_path', 'value_shape', and 'init_shape'
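This `TypeError` is consistent with Python's default pickling of exceptions: unpickling rebuilds the object as `type(self)(*self.args)`. When `__init__` takes several arguments but forwards only one formatted message to `super().__init__`, `self.args` contains just that message and the constructor call fails. A minimal reproduction without Flax or Ray, using a hypothetical `ShapeMismatchError` that follows the same pattern:

```python
import pickle


class ShapeMismatchError(ValueError):
    """Hypothetical exception: four constructor args, one message in self.args."""

    def __init__(self, param_name, scope_path, value_shape, init_shape):
        super().__init__(
            f"parameter {param_name!r} in {scope_path!r}: "
            f"expected shape {init_shape}, got {value_shape}"
        )


err = ShapeMismatchError("kernel", "/", (1, 8), (5, 8))
print(err.args)  # only the formatted message is stored in args

payload = pickle.dumps(err)  # serializing succeeds

load_error = None
try:
    pickle.loads(payload)  # calls ShapeMismatchError(message), which is missing 3 args
except TypeError as exc:
    load_error = exc
    print("unpickling failed:", exc)
```

Serializing succeeds and only loading fails, which matches the Ray traceback above, where the error surfaces in pickle.loads.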
Steps to reproduce:
```python
import pickle

import flax.linen as nn
import jax.numpy as jnp
from flax.linen.initializers import lecun_normal
from jax import lax, random


class NoBiasDense(nn.Module):
    features: int = 8

    @nn.compact
    def __call__(self, x):
        # The kernel shape depends on the last input dimension.
        kernel = self.param('kernel',
                            lecun_normal(),
                            (x.shape[-1], self.features))  # <--- exception raised here (example from the Flax docs)
        y = lax.dot_general(x, kernel,
                            (((x.ndim - 1,), (0,)), ((), ())))
        return y


# Initialized with last dimension 1, so the stored kernel has shape (1, 8).
variables = NoBiasDense().init(random.key(0), jnp.ones((5, 5, 1)))
try:
    # Applied with last dimension 5, so the initializer expects shape (5, 8).
    _ = NoBiasDense().apply(variables, jnp.ones((5, 5)))
except Exception as e:
    data = pickle.dumps(e)    # serializing succeeds
    obj = pickle.loads(data)  # <--- pickle exception (TypeError)
```