google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io

Unpickled modules with constructor arguments cannot be initialized

eseraygun opened this issue

System information

  • OS Platform and Distribution: Linux-6.1.58+-x86_64-with-glibc2.35
  • Flax, jax, jaxlib versions: flax==0.8.3, jax==0.4.26, jaxlib==0.4.26+cuda12.cudnn89
  • Python version: 3.10.12
  • GPU/TPU model and memory: N/A
  • CUDA version (if applicable): N/A

Problem you have encountered:

After a round trip through cloudpickle.dumps() and cloudpickle.loads(), a Flax module class with constructor arguments fails at .init() with a TypeError.

What you expected to happen:

The unpickled Flax module should initialize as normal.

Logs, error messages, etc:

TypeError: MyModule.__init__() missing 1 required positional argument: 'arg'

Steps to reproduce:

Colab: https://colab.research.google.com/drive/1Ct4-19Mn-Vexr260511pi1sUblKsRAn6

import cloudpickle
import flax.linen as nn
import jax

class MyModule(nn.Module):
  arg: int

  @nn.compact
  def __call__(self):
    return None

# Round-trip the class itself (not an instance) through cloudpickle.
UnpickledMyModule = cloudpickle.loads(cloudpickle.dumps(MyModule))
# Fails with: TypeError: MyModule.__init__() missing 1 required positional argument: 'arg'
UnpickledMyModule(arg=1).init(jax.random.key(42))
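
One way to narrow this down (not part of the original report) might be to compare what the original and the unpickled class expose as constructor parameters and dataclass fields. The sketch below uses only the standard library; `describe_init` is a hypothetical helper name, and `Example` is a stand-in dataclass so the snippet runs without Flax installed. It does not claim to identify the cause, only to show where the two classes might differ.

```python
import dataclasses
import inspect


def describe_init(cls):
    """Return (init parameter names, dataclass field names) for a class.

    Hypothetical diagnostic helper; not part of Flax or cloudpickle.
    """
    # Parameters accepted by __init__, excluding the bound instance.
    params = [
        name
        for name in inspect.signature(cls.__init__).parameters
        if name != "self"
    ]
    # Dataclass fields, or an empty list if cls is not a dataclass.
    fields = (
        [f.name for f in dataclasses.fields(cls)]
        if dataclasses.is_dataclass(cls)
        else []
    )
    return params, fields


# Stand-in class so the sketch runs without Flax. In the report above,
# the classes to compare would be MyModule and UnpickledMyModule.
@dataclasses.dataclass
class Example:
    arg: int


params, fields = describe_init(Example)
print(params, fields)
```

Calling `describe_init(MyModule)` and `describe_init(UnpickledMyModule)` in the Colab and diffing the two results would show whether the round trip changed the generated `__init__` signature or the registered fields.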