google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io

Cannot pickle linen Modules

cgarciae opened this issue

I am using 0.3.4 and I am getting an error when trying to pickle Flax modules. Specifically, Dense seems to be the problem, but others might have similar issues.

Problem you have encountered:

from flax import linen
import pickle

with open("model.pkl", "wb") as f:
    model = pickle.dump(linen.Dense(10), f)

Traceback (most recent call last):
File "test.py", line 8, in
model = pickle.dump(linen.Dense(10), f)
AttributeError: Can't pickle local object 'variance_scaling..init'

While the previous example is solved by cloudpickle, this other code doesn't work:

import cloudpickle
from flax import linen
import pickle

class IdentityFlax(linen.Module):
    def __call__(self, x):
        return x

with open("mlp.pkl", "wb") as f:
    cloudpickle.dump(IdentityFlax(), f)

Traceback (most recent call last):
File "test.py", line 25, in
cloudpickle.dump(IndentityFlax(), f)
File "/data/cristian/elegy/.venv/lib/python3.8/site-packages/cloudpickle/cloudpickle_fast.py", line 55, in dump
CloudPickler(
File "/data/cristian/elegy/.venv/lib/python3.8/site-packages/cloudpickle/cloudpickle_fast.py", line 563, in dump
return Pickler.dump(self, obj)
TypeError: cannot pickle '_thread._local' object

Ok, this exact example seems to work with cloudpickle but I am getting another error serializing a flax.linen.Module object. I will try to get a reproducible example.

So I found a minimal example that doesn't work with cloudpickle, which seems to be what is affecting me in my actual problem; see the updated issue.

commented

(cloud)pickle issues are a little tricky. For some reason cloudpickle tries to serialize Flax internals. I spent some time looking into it before, but my main issue with cloudpickle is that there doesn't seem to be a specification of its algorithm, and of course the implementation is black-magic Python. I think the minimal thing we need in order to officially support a library like cloudpickle is a guide that explains what constraints we should adhere to. Perhaps something like this does exist, but I couldn't find anything last time I looked.

You could of course also raise an issue with the cloudpickle team to see if this is even expected behavior from their side in the first place.

@jheek Do you happen to know which Flax internal object it's trying to serialize? I am a bit hesitant to ping cloudpickle without a reproducible example that doesn't involve a whole library (flax).

If flax users are not using (cloud)pickle, what is the current recommended way to serialize flax models?

commented

Yeah I agree we should try to minimize the repro. I tried out your pickle example and I was able to remove Flax from the equation:

import jax
import pickle

init_fn = jax.nn.initializers.lecun_normal()
with open("model.pkl", "wb") as f:
    pickle.dump(init_fn, f)

So here it's really JAX that is producing a partial function that cannot be pickled.
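
This is just how stdlib pickle treats locally defined functions: it can only serialize a function it can re-import by its qualified name, and the initializer returned by variance_scaling (which lecun_normal wraps) is defined inside another function. A minimal sketch using a hypothetical stand-in, not JAX's actual code:

import pickle

def variance_scaling_like():
    # Stand-in for jax.nn.initializers.variance_scaling: the returned
    # function is defined locally, so it has no importable qualified name.
    def init(key, shape):
        return shape
    return init

pickle.dumps(variance_scaling_like())
# AttributeError: Can't pickle local object 'variance_scaling_like.<locals>.init'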

For cloudpickle I'm not so sure what's going on, but it essentially finds an internal ThreadLocal object and decides that it needs to serialize it. I don't think that makes sense: only library methods touch this object (which cloudpickle shouldn't be serializing), and the ThreadLocal object is itself defined at the top level of the module, so again it shouldn't try to serialize it. This pattern of keeping state in a ThreadLocal object is quite common in Python, so I think this should really be fixed in cloudpickle, but perhaps I'm overlooking some edge case in how we implemented this in Flax.

@jheek thanks for looking into this!

It's very weird, as you say, since both elegy and haiku Modules use ThreadLocal but only flax has issues with cloudpickle. I am more interested in cloudpickle than in pickle since it's generally more robust, and pickle doesn't work for haiku or elegy either, so it's not really an option.

I will send a PR to flax with a test using cloudpickle to make this effort a little more formal and maybe others can try to give it a shot if that is OK with the flax team.

I am indeed curious why other libraries that use ThreadLocal don't have this problem...

commented

@cgarciae I found the issue. cloudpickle will not serialize functions that are part of a library, but it does serialize other globals. I guess this is a Python limitation (a missing qualname, I suppose?). We use the ThreadLocal inside a decorator function, which gets serialized. All we have to do is factor out the body of the decorator into a method so cloudpickle doesn't serialize its closure variables. I'm working on a PR.
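
To illustrate the failure mode (a rough sketch, not Flax's actual internals): a decorator whose wrapper references a module-level threading.local. When cloudpickle serializes a user-defined class by value, it also serializes the wrapper and the globals it references, which drags in the un-picklable thread-local:

import threading
import cloudpickle

_context = threading.local()  # module-level state, analogous to Flax's internal context

def tracks_context(fn):
    # The wrapper references _context, so serializing the wrapper by value
    # also tries to serialize the thread-local object it depends on.
    def wrapper(self, *args, **kwargs):
        _context.depth = getattr(_context, "depth", 0) + 1
        try:
            return fn(self, *args, **kwargs)
        finally:
            _context.depth -= 1
    return wrapper

class Identity:
    @tracks_context
    def __call__(self, x):
        return x

cloudpickle.dumps(Identity())
# TypeError: cannot pickle '_thread._local' object

Moving the wrapper body into a named, top-level function in the library module lets cloudpickle pickle it by reference instead of by value, which is the idea behind factoring the decorator body out.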

I had a similar issue with pickle recently and then simply switched to dill.

from flax import linen
import dill as pickle

with open("model.pkl", "wb") as f:
    model = pickle.dump(linen.Dense(10), f)

import jax
import dill as pickle

init_fn = jax.nn.initializers.lecun_normal()
with open("model.pkl", "wb") as f:
    model = pickle.dump(init_fn, f)

Both of these examples work with Flax 0.3.4 and Jax 0.2.18.

@jheek that is great news!
@matthias-wright I'll give dill a try :)

According to this answer on SO, cloudpickle seems to pick up objects from the surrounding context more aggressively than dill, which seems to be what hurts in this case.

Just tested dill, and it seems to fail in complex cases, similar to regular pickle, for flax, haiku, and elegy modules.

If flax users are not using (cloud)pickle, what is the current recommended way to serialize flax models?

We recommend that people serialize data, not code. Our native serialization API only deals with model data, not model code.

  • For corporate users, the use of libraries like pickle is often not allowed, since pickle payloads can enable arbitrary-code-execution vulnerabilities.
  • Any pickle-like library leads to extreme brittleness in my experience, as even small changes to code can break the import of older pickled save files. So this basically imposes an indirect and ill-specified backwards-compatibility constraint on a library. That's unworkable inside a company monorepo environment, and perhaps bearable with the library versioning available in the open-source world, but still a footgun. (So, for instance, we explicitly do not promise support for old cloudpickle-serialized files in newer versions.)
  • The way we restore data from checkpoints lets us verify a "structural match" between the model the code creates and the one saved to disk, giving us an extra check that code and data are in sync.
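
As a rough sketch of the data-only approach (assuming the flax.serialization to_bytes/from_bytes helpers available in 0.3.x): the Module itself is rebuilt from code, and only its variables, a pytree of arrays, are written to and read from disk:

import jax
import jax.numpy as jnp
from flax import linen, serialization

model = linen.Dense(10)
variables = model.init(jax.random.PRNGKey(0), jnp.ones((1, 3)))

# Write only the model data (a pytree of arrays) to disk.
with open("model.msgpack", "wb") as f:
    f.write(serialization.to_bytes(variables))

# Restoring requires the model code to provide a matching structure,
# which acts as the "structural match" check mentioned above.
template = linen.Dense(10).init(jax.random.PRNGKey(0), jnp.ones((1, 3)))
with open("model.msgpack", "rb") as f:
    restored = serialization.from_bytes(template, f.read())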

Pickle was originally intended for RPC, and I know that cloudpickle is needed to support things like Python in Spark. For that use case, wanting to serialize code for distributed-engine support is understandable... but even there it's better to define remote calls that exchange data rather than serialized code objects.

commented

+1 to the comments by @levskaya

Remote code execution in particular is really overlooked. Pickle is of course intended for RPC protocols, but we really don't want a protocol with remote code execution as the standard way of distributing Flax models. I'm also curious how the PyTorch world thinks about this; the remote code execution problem is not mentioned in the torch.save/load docs.