google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io

asyncio error while loading weights

mobley-trent opened this issue

I'm trying to run the example from the docs on saving/loading model weights. The first run in my notebook worked fine, but subsequent runs failed with the following error:

File /opt/miniconda/envs/multienv/lib/python3.10/asyncio/locks.py:234, in Condition.__init__(self, lock, loop)
...
ValueError: loop argument must agree with lock

Here is my code:

import numpy as np
import jax
from jax import random, numpy as jnp

import flax
from flax import linen as nn
from flax.training import checkpoints, train_state
from flax import struct, serialization

import orbax.checkpoint
import optax

key1, key2 = random.split(random.key(0))
x1 = random.normal(key1, (5,))      
model = nn.Dense(features=3)
variables = model.init(key2, x1)

tx = optax.sgd(learning_rate=0.001)     
state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=variables['params'],
    tx=tx)

state = state.apply_gradients(grads=jax.tree_map(jnp.ones_like, state.params))
config = {'dimensions': np.array([5, 3])}
ckpt = {'model': state, 'config': config, 'data': [x1]}

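from pathlib import Path

from flax.training import orbax_utils

# CKPT_DIR is not defined in the original report; an absolute path is
# assumed here (hypothetical location).
CKPT_DIR = Path('CKPT_DIR').absolute()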

orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
save_args = orbax_utils.save_args_from_target(ckpt)
orbax_checkpointer.save(CKPT_DIR, ckpt, save_args=save_args)

raw_restored = orbax_checkpointer.restore(CKPT_DIR)

System information

  • OS Platform and Distribution: Linux (Github Codespaces)
  • Flax, jax, jaxlib versions (obtained with pip show flax jax jaxlib):
    • flax: 0.8.0
    • jax: 0.4.23
    • jaxlib: 0.4.23+cuda12.cudnn89
  • Python version: 3.10
  • GPU/TPU model and memory:
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.08              Driver Version: 545.23.08    CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  Tesla V100-PCIE-16GB           On  | 00000001:00:00.0 Off |                  Off |
| N/A   29C    P0              36W / 250W |  12440MiB / 16384MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
  • CUDA version (if applicable): 12.3

Hey! I ran this code locally and it works:

import jax
import numpy as np
import optax
import orbax.checkpoint
from jax import numpy as jnp
from jax import random

from flax import linen as nn
from flax.training import train_state

key1, key2 = random.split(random.key(0))
x1 = random.normal(key1, (5,))
model = nn.Dense(features=3)
variables = model.init(key2, x1)

tx = optax.sgd(learning_rate=0.001)
state = train_state.TrainState.create(
  apply_fn=model.apply, params=variables['params'], tx=tx
)

state = state.apply_gradients(grads=jax.tree_map(jnp.ones_like, state.params))
config = {'dimensions': np.array([5, 3])}
ckpt = {'model': state, 'config': config, 'data': [x1]}

from pathlib import Path

from flax.training import orbax_utils

CKPT_DIR = Path('CKPT_DIR').absolute()

orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
save_args = orbax_utils.save_args_from_target(ckpt)
orbax_checkpointer.save(CKPT_DIR, ckpt, save_args=save_args)

raw_restored = orbax_checkpointer.restore(CKPT_DIR)

Are you running this on Colab?

No, I'm using Codespaces, @cgarciae. Your code raises the same error for me as well.

I remember that you used to monkey-patch asyncio so that Orbax worked in Jupyter/Colab; maybe something similar is happening here? For what it's worth, this code runs fine in Colab.
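For context on the monkey-patching remark: Jupyter kernels (Colab included) typically already have an asyncio event loop running in the main thread, and asyncio.run() raises RuntimeError when called from a running loop. Patching asyncio (for example with nest_asyncio.apply()) is one common workaround. The sketch below shows another, under that assumption: run the coroutine on a fresh event loop in a worker thread. run_sync is a hypothetical helper, not Orbax or Flax code, and it does not by itself explain the lock/loop mismatch in the traceback above.

```python
import asyncio
import concurrent.futures
from typing import Any, Coroutine


def run_sync(coro: Coroutine[Any, Any, Any]) -> Any:
    """Run a coroutine to completion from synchronous code (hypothetical helper).

    With no event loop running in this thread, this is just asyncio.run().
    Inside a running loop (e.g. a Jupyter kernel), asyncio.run() would raise,
    so the coroutine is run on a fresh event loop in a worker thread instead.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running in this thread, so it is safe to create one.
        return asyncio.run(coro)
    # A loop is already running here: give the coroutine its own loop in
    # another thread and block until it finishes.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, coro).result()
```

Note that the thread path blocks the notebook's own loop while it waits, and any asyncio objects the coroutine creates stay bound to the worker's loop rather than being shared with the kernel's.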

You should post an issue on the orbax repo.