`save_checkpoint` fails with the most recent orbax release
apaleyes opened this issue
Hello, flax team!
Over the past few days we have observed calls to `save_checkpoint` failing with the most recent orbax release (0.5.17). After downgrading to `orbax-checkpoint==0.5.16`, everything works again.
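Because the report ties the regression to a single release, a small guard can flag it before a long training run reaches its first checkpoint. The following is a minimal sketch that uses only the standard-library `importlib.metadata` and assumes the distribution name `orbax-checkpoint`. The `KNOWN_BAD_ORBAX` set and both helper functions are hypothetical. The set contains only the version this issue reports as failing (0.5.17).

```python
from importlib import metadata
from typing import Optional

# Hypothetical guard list: only the release this issue reports as failing.
KNOWN_BAD_ORBAX = {"0.5.17"}


def installed_orbax_version() -> Optional[str]:
    """Return the installed orbax-checkpoint version, or None if it is absent."""
    try:
        return metadata.version("orbax-checkpoint")
    except metadata.PackageNotFoundError:
        return None


def is_reported_bad(version: Optional[str]) -> bool:
    """Return True if `version` is one this issue reports as breaking save_checkpoint."""
    return version is not None and version in KNOWN_BAD_ORBAX


if __name__ == "__main__":
    version = installed_orbax_version()
    if is_reported_bad(version):
        print(f"orbax-checkpoint {version} is reported to break save_checkpoint; "
              "consider pinning orbax-checkpoint==0.5.16")
    else:
        print(f"orbax-checkpoint version: {version}")
```

Pinning to 0.5.16 is only a stopgap. Later replies in this thread point to newer orbax releases as the actual fix.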
A reproducing example can be taken from the Flax docs; for convenience, it is copied below. With `orbax-checkpoint==0.5.17` this code raises an exception; with `orbax-checkpoint==0.5.16` it works.
System information
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): seen on both Ubuntu and Mac
- Flax, jax, jaxlib versions (obtain with `pip show flax jax jaxlib`): flax 0.8.4, jax 0.4.30, jaxlib 0.4.30
- Python version: 3.11
Logs, error messages, etc:
Traceback (most recent call last):
  File "/tmp/min_example.py", line 46, in <module>
    checkpoints.save_checkpoint(ckpt_dir='/tmp/flax_ckpt/flax-checkpointing',
  File "/tmp/.venv/lib/python3.11/site-packages/flax/training/checkpoints.py", line 697, in save_checkpoint
    orbax_checkpointer.save(
  File "/tmp/.venv/lib/python3.11/site-packages/orbax/checkpoint/checkpointer.py", line 165, in save
    tmpdir = utils.create_tmp_directory(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/.venv/lib/python3.11/site-packages/orbax/checkpoint/path/step.py", line 607, in create_tmp_directory
    if multihost.is_primary_host(primary_host):
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/.venv/lib/python3.11/site-packages/orbax/checkpoint/multihost/utils.py", line 246, in is_primary_host
    if primary_host is None or primary_host == process_index():
                                               ^^^^^^^^^^^^^^^
  File "/tmp/.venv/lib/python3.11/site-packages/orbax/checkpoint/multihost/utils.py", line 252, in process_index
    if EXPERIMENTAL_ORBAX_USE_DISTRIBUTED_PROCESS_ID.value:
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/.venv/lib/python3.11/site-packages/absl/flags/_flagvalues.py", line 1426, in value
    val = getattr(self._flagvalues, self._name)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/.venv/lib/python3.11/site-packages/absl/flags/_flagvalues.py", line 498, in __getattr__
    raise _exceptions.UnparsedFlagAccessError(
absl.flags._exceptions.UnparsedFlagAccessError: Trying to access flag --experimental_orbax_use_distributed_process_id before flags were parsed.
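The last frames of the traceback show the mechanism. orbax 0.5.17 reads the absl flag `--experimental_orbax_use_distributed_process_id` through `.value`. absl then raises `UnparsedFlagAccessError`, because the reproducing script is a plain Python script that never parses absl flags. The following minimal sketch reproduces that behaviour with absl alone. It assumes absl-py is installed and uses an isolated `FlagValues` registry with a hypothetical flag, `demo_experimental_flag`, which is not orbax's real flag. It shows that the same read succeeds once the registry has been parsed.

```python
from absl import flags


def make_flag_values():
    """Build an isolated FlagValues registry holding one hypothetical bool flag."""
    fv = flags.FlagValues()
    # Hypothetical stand-in for orbax's experimental flag; not a real orbax flag.
    holder = flags.DEFINE_bool(
        "demo_experimental_flag", False, "Hypothetical demo flag.", flag_values=fv
    )
    return fv, holder


def read_before_parse(holder) -> bool:
    """Return True if reading the flag's value raises UnparsedFlagAccessError."""
    try:
        _ = holder.value
    except flags.UnparsedFlagAccessError:
        return True
    return False


if __name__ == "__main__":
    fv, holder = make_flag_values()
    # Same failure mode as the traceback: .value is read before any parsing.
    print("raises before parsing:", read_before_parse(holder))
    # Parsing an argv that contains only the program name marks fv as parsed.
    fv(["prog"])
    print("value after parsing:", holder.value)
```

The same mechanism suggests a possible workaround: call `absl.flags.FLAGS.mark_as_parsed()`, or parse with `absl.flags.FLAGS(sys.argv[:1])`, before `save_checkpoint`. Either call might let the orbax flag fall back to its default on 0.5.17. This is an assumption and is not confirmed in this thread, which instead points to newer orbax releases.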
Steps to reproduce:
import os
from typing import Optional, Any
import shutil
import numpy as np
import jax
from jax import random, numpy as jnp
import flax
from flax import linen as nn
from flax.training import checkpoints, train_state
from flax import struct, serialization
import optax
ckpt_dir = '/tmp/flax_ckpt'
if os.path.exists(ckpt_dir):
    shutil.rmtree(ckpt_dir)  # Remove any existing checkpoints from the last notebook run.
key1, key2 = random.split(random.key(0))
x1 = random.normal(key1, (5,)) # A simple JAX array.
model = nn.Dense(features=3)
variables = model.init(key2, x1)
# Flax's TrainState is a pytree dataclass and is supported in checkpointing.
# Define your class with `@flax.struct.dataclass` decorator to make it compatible.
tx = optax.sgd(learning_rate=0.001) # An Optax SGD optimizer.
state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=variables['params'],
    tx=tx)
# Perform a simple gradient update similar to the one during a normal training workflow.
state = state.apply_gradients(grads=jax.tree_util.tree_map(jnp.ones_like, state.params))
# Some arbitrary nested pytree with a dictionary and a NumPy array.
config = {'dimensions': np.array([5, 3])}
# Bundle everything together.
ckpt = {'model': state, 'config': config, 'data': [x1]}
# Import Flax Checkpoints.
from flax.training import checkpoints
checkpoints.save_checkpoint(ckpt_dir='/tmp/flax_ckpt/flax-checkpointing',
                            target=ckpt,
                            step=0,
                            overwrite=True,
                            keep=2)
They pushed some new releases today; the latest is now 0.5.19, which should fix this error.
Thanks @IvyZX! It's getting quite scary... it's 0.5.20 now! I guess it's probably best to wait for things to stabilise.