[Bug] Asyncio error while loading Flax weights
mobley-trent opened this issue · comments
I'm trying to run the example from the Flax docs on saving/loading model weights. My first notebook run worked just fine, but subsequent runs failed with the following error:
File /opt/miniconda/envs/multienv/lib/python3.10/asyncio/locks.py:234, in Condition.__init__(self, lock, loop)
...
ValueError: loop argument must agree with lock
Here is the code:
import numpy as np
import jax
from jax import random, numpy as jnp
import flax
from flax import linen as nn
from flax.training import checkpoints, train_state
from flax import struct, serialization
import orbax.checkpoint
import optax
key1, key2 = random.split(random.key(0))
x1 = random.normal(key1, (5,))
model = nn.Dense(features=3)
variables = model.init(key2, x1)
tx = optax.sgd(learning_rate=0.001)
state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=variables['params'],
    tx=tx)
state = state.apply_gradients(grads=jax.tree_map(jnp.ones_like, state.params))
config = {'dimensions': np.array([5, 3])}
ckpt = {'model': state, 'config': config, 'data': [x1]}
from flax.training import orbax_utils
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
save_args = orbax_utils.save_args_from_target(ckpt)
CKPT_DIR = '/tmp/flax-checkpoint'  # any writable target directory
orbax_checkpointer.save(CKPT_DIR, ckpt, save_args=save_args)
raw_restored = orbax_checkpointer.restore(CKPT_DIR)
System Information:
- OS : Linux (Github Codespaces)
- Python 3.10
- Package info:
- flax : 0.8.0
- jax : 0.4.23
- jaxlib : 0.4.23+cuda12.cudnn89
- orbax : 0.1.9
- GPU info:
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.08 Driver Version: 545.23.08 CUDA Version: 12.3 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 Tesla V100-PCIE-16GB On | 00000001:00:00.0 Off | Off |
| N/A 29C P0 36W / 250W | 12440MiB / 16384MiB | 0% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
This issue is also open on the Flax repo: google/flax#3679
Orbax had multiple fixes related to asyncio after your version, 0.1.9. Can you please try a more recent version?
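When testing an upgrade, it is worth confirming which distribution actually resolves at runtime, since a stale install earlier on the path can mask the new version. A small check, using only the standard library:

```python
import importlib.metadata


def installed_version(package: str) -> str:
    """Return the installed distribution version, or 'not installed'."""
    try:
        return importlib.metadata.version(package)
    except importlib.metadata.PackageNotFoundError:
        return "not installed"


# Check what the current interpreter actually sees after the upgrade.
print("orbax-checkpoint:", installed_version("orbax-checkpoint"))
```

If this still reports 0.1.9 after upgrading, the interpreter is picking up a different environment than the one pip installed into.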
@mobley-trent Did you find a solution to this? I am facing the same error, using
diffrax 0.5.0 pypi_0 pypi
flax 0.8.0 pypi_0 pypi
jax 0.4.23 pypi_0 pypi
jaxlib 0.4.23+cuda12.cudnn89 pypi_0 pypi
jaxopt 0.8.3 pypi_0 pypi
jaxtyping 0.2.25 pypi_0 pypi
lineax 0.0.4 pypi_0 pypi
optax 0.1.9 pypi_0 pypi
orbax-checkpoint 0.5.3 pypi_0 pypi
on:
Operating System: Rocky Linux 8.9 (Green Obsidian)
CPE OS Name: cpe:/o:rocky:rocky:8:GA
Kernel: Linux 4.18.0-513.11.1.el8_9.x86_64
Architecture: x86-64
Here is a minimal example, copied and pasted from the orbax documentation:
import numpy as np
import orbax.checkpoint as ocp
path = ocp.test_utils.create_empty('/Users/soeren.becker/my-checkpoints/')
my_tree = {
    'a': np.arange(8),
    'b': {
        'c': 42,
        'd': np.arange(16),
    },
}
checkpointer = ocp.StandardCheckpointer()
# 'checkpoint_name' must not already exist.
checkpointer.save(path / 'checkpoint_name', my_tree)
checkpointer.restore(path / 'checkpoint_name/')
It fails with the same error described above:
Traceback (most recent call last):
File "/Users/soeren.becker/test_orbax.py", line 17, in <module>
checkpointer.restore(path / 'checkpoint_name/')
File "/usr/local/Caskroom/miniforge/base/envs/test_orbax/lib/python3.10/site-packages/orbax/checkpoint/checkpointer.py", line 166, in restore
restored = self._handler.restore(directory, args=ckpt_args)
File "/usr/local/Caskroom/miniforge/base/envs/test_orbax/lib/python3.10/site-packages/orbax/checkpoint/standard_checkpoint_handler.py", line 165, in restore
return super().restore(
File "/usr/local/Caskroom/miniforge/base/envs/test_orbax/lib/python3.10/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 1031, in restore
byte_limiter = get_byte_limiter(self._concurrent_gb)
File "/usr/local/Caskroom/miniforge/base/envs/test_orbax/lib/python3.10/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 170, in get_byte_limiter
return asyncio.run(_create_byte_limiter())
File "/usr/local/Caskroom/miniforge/base/envs/test_orbax/lib/python3.10/asyncio/runners.py", line 44, in run
return loop.run_until_complete(main)
File "/usr/local/Caskroom/miniforge/base/envs/test_orbax/lib/python3.10/asyncio/base_events.py", line 641, in run_until_complete
return future.result()
File "/usr/local/Caskroom/miniforge/base/envs/test_orbax/lib/python3.10/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 168, in _create_byte_limiter
return LimitInFlightBytes(concurrent_bytes) # pylint: disable=protected-access
File "/usr/local/Caskroom/miniforge/base/envs/test_orbax/lib/python3.10/site-packages/jax/experimental/array_serialization/serialization.py", line 144, in __init__
self._cv = asyncio.Condition(lock=asyncio.Lock())
File "/usr/local/Caskroom/miniforge/base/envs/test_orbax/lib/python3.10/asyncio/locks.py", line 234, in __init__
raise ValueError("loop argument must agree with lock")
ValueError: loop argument must agree with lock
I tried this on Linux (see comment above) and macOS (11.5.1, M1) with a fresh conda environment:
(test_orbax) ➜ ~ conda list
# packages in environment at /usr/local/Caskroom/miniforge/base/envs/test_orbax:
#
# Name Version Build Channel
absl-py 2.1.0 pypi_0 pypi
bzip2 1.0.8 h93a5062_5 conda-forge
ca-certificates 2024.2.2 hf0a4a13_0 conda-forge
etils 1.6.0 pypi_0 pypi
fsspec 2024.2.0 pypi_0 pypi
importlib-resources 6.1.1 pypi_0 pypi
jax 0.4.24 pypi_0 pypi
jaxlib 0.4.24 pypi_0 pypi
libblas 3.9.0 21_osxarm64_openblas conda-forge
libcblas 3.9.0 21_osxarm64_openblas conda-forge
libcxx 16.0.6 h4653b0c_0 conda-forge
libffi 3.4.2 h3422bc3_5 conda-forge
libgfortran 5.0.0 13_2_0_hd922786_3 conda-forge
libgfortran5 13.2.0 hf226fd6_3 conda-forge
liblapack 3.9.0 21_osxarm64_openblas conda-forge
libopenblas 0.3.26 openmp_h6c19121_0 conda-forge
libsqlite 3.44.2 h091b4b1_0 conda-forge
libzlib 1.2.13 h53f4e23_5 conda-forge
llvm-openmp 17.0.6 hcd81f8e_0 conda-forge
ml-dtypes 0.3.2 pypi_0 pypi
msgpack 1.0.7 pypi_0 pypi
ncurses 6.4 h463b476_2 conda-forge
nest-asyncio 1.6.0 pypi_0 pypi
numpy 1.26.4 py310hd45542a_0 conda-forge
openssl 3.2.1 h0d3ecfb_0 conda-forge
opt-einsum 3.3.0 pypi_0 pypi
orbax-checkpoint 0.5.3 pypi_0 pypi
pip 24.0 pyhd8ed1ab_0 conda-forge
protobuf 4.25.2 pypi_0 pypi
python 3.10.0 h43b31ca_3_cpython conda-forge
python_abi 3.10 4_cp310 conda-forge
pyyaml 6.0.1 pypi_0 pypi
readline 8.2 h92ec313_1 conda-forge
scipy 1.12.0 py310hf4b343e_2 conda-forge
setuptools 69.0.3 pyhd8ed1ab_0 conda-forge
sqlite 3.44.2 hf2abe2d_0 conda-forge
tensorstore 0.1.53 pypi_0 pypi
tk 8.6.13 h5083fa2_1 conda-forge
typing-extensions 4.9.0 pypi_0 pypi
tzdata 2024a h0c530f3_0 conda-forge
wheel 0.42.0 pyhd8ed1ab_0 conda-forge
xz 5.2.6 h57fd34a_0 conda-forge
zipp 3.17.0 pypi_0 pypi
Any help to get the checkpointing working would be greatly appreciated 👍
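For what it's worth, the failing call can be reproduced without orbax at all. The traceback bottoms out in a `Condition` built around an externally created `Lock`, so a minimal probe of the interpreter is just:

```python
import asyncio


async def make_condition() -> asyncio.Condition:
    # This mirrors the call in jax's array_serialization code:
    # on an affected interpreter it raises
    # "ValueError: loop argument must agree with lock",
    # while patched interpreters accept the external lock.
    return asyncio.Condition(lock=asyncio.Lock())


cond = asyncio.run(make_condition())
print(type(cond).__name__)  # prints "Condition" on an unaffected interpreter
```

If this two-line probe fails, the problem is in the interpreter's asyncio, not in orbax or flax.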
Hello @niketkumar, I upgraded orbax-checkpoint to 0.5.3, but the same error persists.
Thanks for sharing the error stack and updates!
It seems you are using Python 3.10, but standard CPython 3.10 received a fix that removed this error check in October 2021; please see https://bugs.python.org/issue45416. Are you using standard CPython, or another flavor?
Are you seeing this error in GitHub Codespaces? I double-checked in Google Colab and it worked there.
My understanding is that this is a Python flavor or Python version issue. Can you please try Python 3.11?
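The version suggestion can be encoded as a small guard. The cutoff below is an assumption on my part: the bpo-45416 fix landed in October 2021, so I take CPython 3.10.1 as the first patch release carrying it (the conda list above shows python 3.10.0, which would predate it):

```python
import sys


def has_asyncio_lock_fix(version_info=None) -> bool:
    # bpo-45416 removed the "loop argument must agree with lock" check;
    # the (3, 10, 1) cutoff assumes that was the first release with the fix.
    vi = sys.version_info if version_info is None else version_info
    return tuple(vi[:3]) >= (3, 10, 1)


print(has_asyncio_lock_fix())  # True on Python 3.11, as suggested above
```

Anything at or above the cutoff, including 3.11, should not hit this check at all.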
@niketkumar thank you for your help: indeed, updating to Python 3.11 resolved the issue for me on both Linux and macOS. Thank you!