google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io

Exceptions are not pickle-able

sanderland opened this issue

Problem you have encountered:

Flax exceptions cannot be pickled. This makes it impossible to trace where or why errors occur when calling Flax via Ray, because Ray pickles exceptions to pass them between processes.

System information

Name: flax
Version: 0.8.3
Summary: Flax: A neural network library for JAX designed for flexibility
Home-page:
Author:
Author-email: Flax team <flax-dev@google.com>
License:
Location: /home/sander_cohere_com/miniconda3/envs/ct/lib/python3.10/site-packages
Requires: jax, msgpack, numpy, optax, orbax-checkpoint, PyYAML, rich, tensorstore, typing-extensions
Required-by: dm-haiku, fax
---
Name: jax
Version: 0.4.28
Summary: Differentiate, compile, and transform Numpy code.
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /home/sander_cohere_com/miniconda3/envs/ct/lib/python3.10/site-packages
Requires: ml-dtypes, numpy, opt-einsum, scipy
Required-by: chex, fax, flax, optax, orbax-checkpoint, rax
---
Name: jaxlib
Version: 0.4.28
Summary: XLA library for JAX
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /home/sander_cohere_com/miniconda3/envs/ct/lib/python3.10/site-packages
Requires: ml-dtypes, numpy, scipy
Required-by: chex, fax, optax, orbax-checkpoint, rax

What you expected to happen:

Flax exceptions should survive a pickle round trip (cf. this guide).

Logs, error messages, etc:

In production code:

site-packages/ray/exceptions.py", line 49, in from_ray_exception
    return pickle.loads(ray_exception.serialized_exception)
TypeError: ScopeParamShapeError.__init__() missing 3 required positional arguments: 'scope_path', 'value_shape', and 'init_shape'
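
The TypeError above is consistent with how Python pickles exceptions by default: `BaseException.__reduce__` rebuilds the object by calling the class with `self.args`, which holds only what was passed to `super().__init__()`. The sketch below reproduces that failure without Flax and shows one common fix. `BrokenShapeError`, `PicklableShapeError`, and `roundtrip` are hypothetical names for illustration; this is not Flax's actual code, and the real `ScopeParamShapeError` may be structured differently.

```python
import pickle


class BrokenShapeError(Exception):
    """Hypothetical stand-in: takes several arguments but forwards only a message."""

    def __init__(self, scope_path, value_shape, init_shape):
        # self.args ends up as (message,), so unpickling calls
        # BrokenShapeError(message) and the other arguments are missing.
        super().__init__(
            f"Shape mismatch at {scope_path}: {value_shape} vs {init_shape}"
        )


class PicklableShapeError(Exception):
    """Hypothetical fix: keep the constructor arguments and return them from __reduce__."""

    def __init__(self, scope_path, value_shape, init_shape):
        super().__init__(
            f"Shape mismatch at {scope_path}: {value_shape} vs {init_shape}"
        )
        self.scope_path = scope_path
        self.value_shape = value_shape
        self.init_shape = init_shape

    def __reduce__(self):
        # Tell pickle to rebuild the exception from the original arguments.
        return (type(self), (self.scope_path, self.value_shape, self.init_shape))


def roundtrip(exc):
    """Serialize and deserialize an exception, as a remote-execution layer would."""
    return pickle.loads(pickle.dumps(exc))


if __name__ == "__main__":
    try:
        roundtrip(BrokenShapeError("/kernel", (1, 8), (5, 8)))
    except TypeError as err:
        print("default pickling failed:", err)

    restored = roundtrip(PicklableShapeError("/kernel", (1, 8), (5, 8)))
    print("restored:", restored)
```

Note that `pickle.dumps` succeeds in the broken case; the error only appears in `pickle.loads`, which matches the traceback pointing at `from_ray_exception`.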

Steps to reproduce:

import flax.linen as nn
import jax.numpy as jnp  # needed for jnp.ones below
from jax import random
from flax.linen.initializers import lecun_normal
from jax import lax
import pickle

class NoBiasDense(nn.Module):
  features: int = 8

  @nn.compact
  def __call__(self, x):
    kernel = self.param('kernel',
                        lecun_normal(),
                        (x.shape[-1], self.features))  # <--- Exception from flax docs example
    y = lax.dot_general(x, kernel,
                        (((x.ndim - 1,), (0,)), ((), ())))
    return y

variables = NoBiasDense().init(random.key(0), jnp.ones((5, 5, 1)))
try:
    _ = NoBiasDense().apply(variables, jnp.ones((5, 5)))
except Exception as e:
    payload = pickle.dumps(e)  # renamed from `str` to avoid shadowing the builtin
    obj = pickle.loads(payload) # <--- pickle exception
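
Until the exceptions themselves can be pickled, one possible caller-side workaround is to convert any exception that fails a pickle round trip into a plain `RuntimeError` before it leaves the worker. The following is a minimal sketch using only the standard library; the decorator name `picklable_errors` is hypothetical, and whether it fits a given Ray setup is an assumption that would need checking.

```python
import functools
import pickle
import traceback


def picklable_errors(fn):
    """Hypothetical decorator: re-raise unpicklable exceptions as RuntimeError."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception as exc:
            try:
                # Check that the exception survives a full round trip.
                pickle.loads(pickle.dumps(exc))
            except Exception:
                # Preserve the type name, message, and traceback as text.
                details = "".join(
                    traceback.format_exception(type(exc), exc, exc.__traceback__)
                )
                raise RuntimeError(
                    f"{type(exc).__name__}: {exc}\n{details}"
                ) from None
            # The exception pickles fine, so let it propagate unchanged.
            raise

    return wrapper
```

Wrapping the function that calls `apply` with this decorator keeps the original error message and traceback text visible on the driver side, at the cost of losing the original exception type.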