google / orbax

Orbax provides common utility libraries for JAX users.

Home Page: https://orbax.readthedocs.io/

[Bug] Asyncio error while loading Flax weights

mobley-trent opened this issue

I'm trying to run the example from the Flax docs on saving and loading model weights. My first notebook run worked fine, but subsequent runs failed with the following error:

File /opt/miniconda/envs/multienv/lib/python3.10/asyncio/locks.py:234, in Condition.__init__(self, lock, loop)
...
ValueError: loop argument must agree with lock

Here is the code:

import numpy as np
import jax
from jax import random, numpy as jnp

import flax
from flax import linen as nn
from flax.training import checkpoints, train_state
from flax import struct, serialization

import orbax.checkpoint
import optax

key1, key2 = random.split(random.key(0))
x1 = random.normal(key1, (5,))      
model = nn.Dense(features=3)
variables = model.init(key2, x1)

tx = optax.sgd(learning_rate=0.001)     
state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=variables['params'],
    tx=tx)

state = state.apply_gradients(grads=jax.tree_map(jnp.ones_like, state.params))
config = {'dimensions': np.array([5, 3])}
ckpt = {'model': state, 'config': config, 'data': [x1]}

from flax.training import orbax_utils

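orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
save_args = orbax_utils.save_args_from_target(ckpt)
# CKPT_DIR was not defined in the original snippet; this placeholder path is an
# assumption. The target directory should not already exist when saving.
CKPT_DIR = '/tmp/flax_ckpt/orbax/single_save'
orbax_checkpointer.save(CKPT_DIR, ckpt, save_args=save_args)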

raw_restored = orbax_checkpointer.restore(CKPT_DIR)

System Information:

  • OS: Linux (GitHub Codespaces)
  • Python 3.10
  • Package info:
    • flax: 0.8.0
    • jax: 0.4.23
    • jaxlib: 0.4.23+cuda12.cudnn89
    • orbax: 0.1.9
  • GPU info:
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.08              Driver Version: 545.23.08    CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  Tesla V100-PCIE-16GB           On  | 00000001:00:00.0 Off |                  Off |
| N/A   29C    P0              36W / 250W |  12440MiB / 16384MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

This issue is also open on the Flax repo: google/flax#3679

Orbax has had multiple asyncio-related fixes since your version, 0.1.9. Can you please try a more recent version?
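When comparing against the versions listed in this thread, it helps to confirm what the running interpreter actually sees, since a notebook kernel can point at a different environment than the shell used for `pip install`. Below is a minimal sketch using only the standard library's `importlib.metadata`; `report_versions` is a hypothetical helper, and the default package list is an assumption based on the packages mentioned here.

```python
import platform
from importlib import metadata

# Assumed list, based on the packages discussed in this thread.
DEFAULT_PACKAGES = ("jax", "jaxlib", "flax", "orbax", "orbax-checkpoint")


def report_versions(packages=DEFAULT_PACKAGES):
    """Map the interpreter and each package name to its version (None if absent)."""
    versions = {
        "python": platform.python_version(),
        "implementation": platform.python_implementation(),
    }
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Not installed in the environment this interpreter is using.
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in report_versions().items():
        print(f"{name:18} {version}")
```

Running it from the same kernel that raises the error shows whether an upgrade actually reached that environment.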

@mobley-trent Did you find a solution to this? I am facing the same error, using

diffrax                   0.5.0                    pypi_0    pypi
flax                      0.8.0                    pypi_0    pypi
jax                       0.4.23                   pypi_0    pypi
jaxlib                    0.4.23+cuda12.cudnn89          pypi_0    pypi
jaxopt                    0.8.3                    pypi_0    pypi
jaxtyping                 0.2.25                   pypi_0    pypi
lineax                    0.0.4                    pypi_0    pypi
optax                     0.1.9                    pypi_0    pypi
orbax-checkpoint          0.5.3                    pypi_0    pypi

on

Operating System: Rocky Linux 8.9 (Green Obsidian)
CPE OS Name: cpe:/o:rocky:rocky:8:GA
Kernel: Linux 4.18.0-513.11.1.el8_9.x86_64
Architecture: x86-64

Here is a minimal example, copied and pasted from the Orbax documentation:

import numpy as np
import orbax.checkpoint as ocp
path = ocp.test_utils.create_empty('/Users/soeren.becker/my-checkpoints/')
my_tree = {
    'a': np.arange(8),
    'b': {
        'c': 42,
        'd': np.arange(16),
    },
}
checkpointer = ocp.StandardCheckpointer()
# 'checkpoint_name' must not already exist.
checkpointer.save(path / 'checkpoint_name', my_tree)
checkpointer.restore(path / 'checkpoint_name/')

It fails with the same error described above:

Traceback (most recent call last):
  File "/Users/soeren.becker/test_orbax.py", line 17, in <module>
    checkpointer.restore(path / 'checkpoint_name/')
  File "/usr/local/Caskroom/miniforge/base/envs/test_orbax/lib/python3.10/site-packages/orbax/checkpoint/checkpointer.py", line 166, in restore
    restored = self._handler.restore(directory, args=ckpt_args)
  File "/usr/local/Caskroom/miniforge/base/envs/test_orbax/lib/python3.10/site-packages/orbax/checkpoint/standard_checkpoint_handler.py", line 165, in restore
    return super().restore(
  File "/usr/local/Caskroom/miniforge/base/envs/test_orbax/lib/python3.10/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 1031, in restore
    byte_limiter = get_byte_limiter(self._concurrent_gb)
  File "/usr/local/Caskroom/miniforge/base/envs/test_orbax/lib/python3.10/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 170, in get_byte_limiter
    return asyncio.run(_create_byte_limiter())
  File "/usr/local/Caskroom/miniforge/base/envs/test_orbax/lib/python3.10/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)
  File "/usr/local/Caskroom/miniforge/base/envs/test_orbax/lib/python3.10/asyncio/base_events.py", line 641, in run_until_complete
    return future.result()
  File "/usr/local/Caskroom/miniforge/base/envs/test_orbax/lib/python3.10/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 168, in _create_byte_limiter
    return LimitInFlightBytes(concurrent_bytes)  # pylint: disable=protected-access
  File "/usr/local/Caskroom/miniforge/base/envs/test_orbax/lib/python3.10/site-packages/jax/experimental/array_serialization/serialization.py", line 144, in __init__
    self._cv = asyncio.Condition(lock=asyncio.Lock())
  File "/usr/local/Caskroom/miniforge/base/envs/test_orbax/lib/python3.10/asyncio/locks.py", line 234, in __init__
    raise ValueError("loop argument must agree with lock")
ValueError: loop argument must agree with lock
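The last frames of the traceback show that the failure happens in JAX's `LimitInFlightBytes.__init__`, at `asyncio.Condition(lock=asyncio.Lock())`, which Orbax reaches through `asyncio.run`. The following minimal sketch isolates that single call from Orbax and JAX so the interpreter can be tested on its own. The helper names are hypothetical. Run it as a standalone script: inside a notebook kernel an event loop is already running, so `asyncio.run` would fail there for an unrelated reason.

```python
import asyncio
import sys


async def _make_condition():
    # Mirrors the failing call from the traceback: a Condition built around an
    # explicitly created Lock, both inside a running event loop.
    return asyncio.Condition(lock=asyncio.Lock())


def condition_with_explicit_lock_works():
    """Return True if this interpreter accepts the call, False if it raises."""
    try:
        asyncio.run(_make_condition())
    except ValueError as exc:
        # Affected interpreters raise "loop argument must agree with lock".
        print(f"Python {sys.version.split()[0]}: {exc}")
        return False
    return True


if __name__ == "__main__":
    ok = condition_with_explicit_lock_works()
    print("asyncio.Condition(lock=...) OK" if ok else "interpreter affected")
```

If this script alone reproduces the `ValueError`, the problem lies in the Python interpreter rather than in Orbax or JAX.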

I tried this on Linux (see comment above) and macOS (11.5.1, M1) with a fresh conda environment:

(test_orbax) ➜  ~ conda list
# packages in environment at /usr/local/Caskroom/miniforge/base/envs/test_orbax:
#
# Name                    Version                   Build  Channel
absl-py                   2.1.0                    pypi_0    pypi
bzip2                     1.0.8                h93a5062_5    conda-forge
ca-certificates           2024.2.2             hf0a4a13_0    conda-forge
etils                     1.6.0                    pypi_0    pypi
fsspec                    2024.2.0                 pypi_0    pypi
importlib-resources       6.1.1                    pypi_0    pypi
jax                       0.4.24                   pypi_0    pypi
jaxlib                    0.4.24                   pypi_0    pypi
libblas                   3.9.0           21_osxarm64_openblas    conda-forge
libcblas                  3.9.0           21_osxarm64_openblas    conda-forge
libcxx                    16.0.6               h4653b0c_0    conda-forge
libffi                    3.4.2                h3422bc3_5    conda-forge
libgfortran               5.0.0           13_2_0_hd922786_3    conda-forge
libgfortran5              13.2.0               hf226fd6_3    conda-forge
liblapack                 3.9.0           21_osxarm64_openblas    conda-forge
libopenblas               0.3.26          openmp_h6c19121_0    conda-forge
libsqlite                 3.44.2               h091b4b1_0    conda-forge
libzlib                   1.2.13               h53f4e23_5    conda-forge
llvm-openmp               17.0.6               hcd81f8e_0    conda-forge
ml-dtypes                 0.3.2                    pypi_0    pypi
msgpack                   1.0.7                    pypi_0    pypi
ncurses                   6.4                  h463b476_2    conda-forge
nest-asyncio              1.6.0                    pypi_0    pypi
numpy                     1.26.4          py310hd45542a_0    conda-forge
openssl                   3.2.1                h0d3ecfb_0    conda-forge
opt-einsum                3.3.0                    pypi_0    pypi
orbax-checkpoint          0.5.3                    pypi_0    pypi
pip                       24.0               pyhd8ed1ab_0    conda-forge
protobuf                  4.25.2                   pypi_0    pypi
python                    3.10.0          h43b31ca_3_cpython    conda-forge
python_abi                3.10                    4_cp310    conda-forge
pyyaml                    6.0.1                    pypi_0    pypi
readline                  8.2                  h92ec313_1    conda-forge
scipy                     1.12.0          py310hf4b343e_2    conda-forge
setuptools                69.0.3             pyhd8ed1ab_0    conda-forge
sqlite                    3.44.2               hf2abe2d_0    conda-forge
tensorstore               0.1.53                   pypi_0    pypi
tk                        8.6.13               h5083fa2_1    conda-forge
typing-extensions         4.9.0                    pypi_0    pypi
tzdata                    2024a                h0c530f3_0    conda-forge
wheel                     0.42.0             pyhd8ed1ab_0    conda-forge
xz                        5.2.6                h57fd34a_0    conda-forge
zipp                      3.17.0                   pypi_0    pypi

Any help to get the checkpointing working would be greatly appreciated 👍

Hello @niketkumar, I upgraded orbax-checkpoint to 0.5.3, but the same error persists.

Thanks for sharing the error stack and updates!

It seems you are using Python 3.10. However, standard CPython 3.10 received a fix in October 2021 that removed this error check; see https://bugs.python.org/issue45416. Are you using the CPython implementation?

Are you seeing this error in GitHub Codespaces? I double-checked it in Google Colab, and it worked.

My understanding is that this is a Python flavor or version issue. Can you please try Python 3.11?
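The macOS conda listing above pins `python 3.10.0` (a CPython build), which is consistent with this diagnosis. As a rough guard, a script could stop before touching Orbax on that interpreter. `is_affected_by_bpo45416` is a hypothetical helper that is not part of Orbax or JAX. It assumes that the bpo-45416 fix first shipped in CPython 3.10.1, the first 3.10 patch release after October 2021, and that earlier 3.10 releases are the only ones affected.

```python
import sys


def is_affected_by_bpo45416(version_info=None):
    """Heuristic check for the asyncio.Condition bug discussed in this thread.

    Assumption: 3.10.0 (including its pre-releases) still has the check, and
    the fix first shipped in 3.10.1.
    """
    if version_info is None:
        version_info = sys.version_info
    return tuple(version_info[:3]) == (3, 10, 0)


if __name__ == "__main__":
    if is_affected_by_bpo45416():
        # sys.exit with a string prints it to stderr and exits with status 1.
        sys.exit(
            "Python 3.10.0 detected: asyncio.Condition(lock=...) may raise "
            "'loop argument must agree with lock'. Try Python 3.11."
        )
    print("Interpreter version looks unaffected.")
```

A version check is only a heuristic. Running the failing `asyncio.Condition(lock=asyncio.Lock())` call directly gives a definitive answer for a given interpreter.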

@niketkumar, thank you for your help: updating to Python 3.11 indeed resolved the issue for me on both Linux and macOS. Thank you!