asyncio error while loading weights
mobley-trent opened this issue
Eddy Oyieko commented
I'm trying to run the example from the docs on saving/loading model weights. My first notebook run worked fine, but subsequent runs failed with the following error:
```
File /opt/miniconda/envs/multienv/lib/python3.10/asyncio/locks.py:234, in Condition.__init__(self, lock, loop)
...
ValueError: loop argument must agree with lock
```
Here is my code:
```python
import numpy as np
import jax
from jax import random, numpy as jnp
import flax
from flax import linen as nn
from flax.training import checkpoints, train_state
from flax import struct, serialization
import orbax.checkpoint
import optax
from pathlib import Path

# Build a tiny Dense model and a TrainState to checkpoint.
key1, key2 = random.split(random.key(0))
x1 = random.normal(key1, (5,))
model = nn.Dense(features=3)
variables = model.init(key2, x1)
tx = optax.sgd(learning_rate=0.001)
state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=variables['params'],
    tx=tx)
state = state.apply_gradients(grads=jax.tree_util.tree_map(jnp.ones_like, state.params))
config = {'dimensions': np.array([5, 3])}
ckpt = {'model': state, 'config': config, 'data': [x1]}

from flax.training import orbax_utils

# CKPT_DIR is not defined in the snippet as posted; this absolute path is a hypothetical stand-in.
CKPT_DIR = Path('/tmp/flax_ckpt')
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
save_args = orbax_utils.save_args_from_target(ckpt)
orbax_checkpointer.save(CKPT_DIR, ckpt, save_args=save_args)
raw_restored = orbax_checkpointer.restore(CKPT_DIR)
```
System information
- OS Platform and Distribution: Linux (GitHub Codespaces)
- Flax, jax, jaxlib versions (obtain with `pip show flax jax jaxlib`):
  - flax: 0.8.0
  - jax: 0.4.23
  - jaxlib: 0.4.23+cuda12.cudnn89
- Python version: 3.10
- GPU/TPU model and memory:
```
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.08              Driver Version: 545.23.08    CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  Tesla V100-PCIE-16GB           On  | 00000001:00:00.0 Off |                  Off |
| N/A   29C    P0              36W / 250W |  12440MiB / 16384MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
```
- CUDA version (if applicable): 12.3
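The versions listed above can also be collected from inside the notebook itself. The sketch below uses the standard-library `importlib.metadata`; the package list is an assumption and includes `orbax-checkpoint` (the PyPI distribution behind `orbax.checkpoint`), since the Orbax version is not listed in the report. `collect_versions` is a hypothetical helper, not part of any library.

```python
import platform
from importlib.metadata import PackageNotFoundError, version

# Assumed package list (PyPI distribution names); adjust as needed.
DEFAULT_PACKAGES = ("flax", "jax", "jaxlib", "orbax-checkpoint", "optax")


def collect_versions(packages: tuple[str, ...] = DEFAULT_PACKAGES) -> dict[str, str]:
    """Map each distribution name to its installed version, plus the Python version."""
    info = {"python": platform.python_version()}
    for name in packages:
        try:
            info[name] = version(name)
        except PackageNotFoundError:
            info[name] = "not installed"
    return info


if __name__ == "__main__":
    for name, ver in collect_versions().items():
        print(f"- {name}: {ver}")
```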
Cristian Garcia commented
Hey! I ran this code locally and it works:
```python
import jax
import numpy as np
import optax
import orbax.checkpoint
from jax import numpy as jnp
from jax import random

from flax import linen as nn
from flax.training import train_state

key1, key2 = random.split(random.key(0))
x1 = random.normal(key1, (5,))
model = nn.Dense(features=3)
variables = model.init(key2, x1)
tx = optax.sgd(learning_rate=0.001)
state = train_state.TrainState.create(
    apply_fn=model.apply, params=variables['params'], tx=tx
)
state = state.apply_gradients(grads=jax.tree_util.tree_map(jnp.ones_like, state.params))
config = {'dimensions': np.array([5, 3])}
ckpt = {'model': state, 'config': config, 'data': [x1]}

from pathlib import Path

from flax.training import orbax_utils

# Orbax is given an absolute checkpoint path.
CKPT_DIR = Path('CKPT_DIR').absolute()
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
save_args = orbax_utils.save_args_from_target(ckpt)
orbax_checkpointer.save(CKPT_DIR, ckpt, save_args=save_args)
raw_restored = orbax_checkpointer.restore(CKPT_DIR)
```
Are you running this on Colab?
Eddy Oyieko commented
No, I'm using Codespaces, @cgarciae. Your code raises the same error as well.
Cristian Garcia commented
I remember that you used to have to monkey-patch asyncio for Orbax to work in Jupyter/Colab; maybe something similar is happening here? For what it's worth, this code runs fine in Colab.
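Some background on the monkey-patching remark: a Jupyter/IPython kernel already runs an asyncio event loop, and `asyncio.run()` refuses to start while another loop is running in the same thread, which is the usual motivation for patching `asyncio` in notebooks. The sketch below is a hypothetical diagnostic, not part of Orbax or Flax: it reports whether a loop is already running and shows the error a nested `asyncio.run()` produces. Run it as a plain script; in an unpatched notebook cell, the final `asyncio.run()` call itself would raise the same `RuntimeError`.

```python
import asyncio


def event_loop_is_running() -> bool:
    """Return True if an asyncio event loop is already running in this thread.

    Inside a Jupyter/IPython kernel cell this is typically True, because the
    kernel runs on an asyncio loop; in a plain `python script.py` run it is False.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        return False
    return True


async def nested_run_attempt() -> str:
    """Call asyncio.run() from inside a running loop and return the error message."""
    coro = asyncio.sleep(0)
    try:
        asyncio.run(coro)
    except RuntimeError as exc:
        coro.close()  # avoid a "coroutine was never awaited" warning
        return str(exc)
    return "no error"


if __name__ == "__main__":
    print("event loop running:", event_loop_is_running())
    print("nested asyncio.run():", asyncio.run(nested_run_attempt()))
```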
Cristian Garcia commented
You should post an issue on the orbax repo.