Unpickled modules with constructor arguments cannot be initialized
eseraygun (Eser Aygün) opened this issue.
System information
- OS Platform and Distribution: Linux-6.1.58+-x86_64-with-glibc2.35
- Flax, jax, jaxlib versions: flax==0.8.3, jax==0.4.26, jaxlib==0.4.26+cuda12.cudnn89
- Python version: 3.10.12
- GPU/TPU model and memory: N/A
- CUDA version (if applicable): N/A
Problem you have encountered:
After a round trip through `cloudpickle.dumps()` and `cloudpickle.loads()`, Flax modules with constructor arguments fail at `.init()` with a `TypeError`.
What you expected to happen:
The unpickled Flax module should initialize as normal.
Logs, error messages, etc:
```
TypeError: MyModule.__init__() missing 1 required positional argument: 'arg'
```
Steps to reproduce:
Colab: https://colab.research.google.com/drive/1Ct4-19Mn-Vexr260511pi1sUblKsRAn6
```python
import cloudpickle
import flax.linen as nn
import jax


class MyModule(nn.Module):
    arg: int

    @nn.compact
    def __call__(self):
        return None


UnpickledMyModule = cloudpickle.loads(cloudpickle.dumps(MyModule))

# Fails with: TypeError: MyModule.__init__() missing 1 required positional argument: 'arg'
UnpickledMyModule(arg=1).init(jax.random.key(42))
```