google / orbax

Orbax provides common utility libraries for JAX users.

Home Page: https://orbax.readthedocs.io/

AttributeError: 'Config' object has no attribute 'jax_coordination_service'

EddieCunningham opened this issue

I am using orbax to checkpoint my models, but I get the following error when I call checkpoint_manager.save:

AttributeError                            Traceback (most recent call last)
File .../lib/python3.9/site-packages/orbax/checkpoint/checkpoint_manager.py:465, in CheckpointManager.save(self, step, items, save_kwargs, metrics, force)
    400 def save(self,
    401          step: int,
    402          items: Union[Any, Mapping[str, Any]],
   (...)
    405          metrics: Optional[PyTree] = None,
    406          force: Optional[bool] = False) -> bool:
    407   """Saves the provided items.
    408 
    409   This method should be called by all hosts - process synchronization and
   (...)
    463     ValueError: if the checkpoint already exists.
    464   """
--> 465   if not force and not self.should_save(step):
    466     return False
    467   if self.reached_preemption(step):

File .../python3.9/site-packages/orbax/checkpoint/checkpoint_manager.py:351, in CheckpointManager.should_save(self, step)
...
--> 336       jax.config.jax_coordination_service
    337       and multihost_utils.reached_preemption_sync_point(step)
    338   )

AttributeError: 'Config' object has no attribute 'jax_coordination_service'

Here is code to reproduce the bug:

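from orbax.checkpoint import (CheckpointManager, CheckpointManagerOptions,
                              PyTreeCheckpointer)

path = 'tmp/my_checkpoint'
# Keep only the most recent checkpoint; create the directory if it is missing.
options = CheckpointManagerOptions(max_to_keep=1, create=True)
checkpoint_manager = CheckpointManager(directory=path,
                                       checkpointers=PyTreeCheckpointer(),
                                       options=options)
pytree = {'a': 1, 'b': 2}
step = 1
checkpoint_manager.save(step, pytree)  # fails with the AttributeError shown above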

From this JAX issue it looks like jax_coordination_service has been removed. Thanks in advance!
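For illustration only, here is a minimal sketch of why reading a removed flag directly fails and how a defensive lookup avoids the exception. StandInConfig and read_flag are hypothetical names introduced for this example; they are not part of JAX or orbax, and this is not necessarily how orbax addressed the problem.

```python
class StandInConfig:
    """Hypothetical stand-in for a config object whose flag was removed."""


def read_flag(config, name, default=False):
    """Return config.<name> if it exists, otherwise `default`."""
    # getattr with a default does not raise AttributeError for a missing name.
    return getattr(config, name, default)


cfg = StandInConfig()

# Direct attribute access raises, as in the traceback above.
try:
    cfg.jax_coordination_service
except AttributeError as err:
    print("direct access failed:", err)

# The defensive lookup falls back to the default instead.
print(read_flag(cfg, "jax_coordination_service"))
```

A fallback like this only hides the missing flag in calling code; it does not change the behavior of an installed library that still reads the flag directly.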

Hi, please make sure you have the latest version of orbax-checkpoint. That code should have been removed sometime in early August or so.

I had an outdated version, thanks!
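To check whether an environment has an outdated install, the sketch below prints the installed versions of the relevant packages using only the standard library. The helper name installed_version is hypothetical, and the distribution names are assumed to be "orbax-checkpoint" (as named in the reply above) and "jax".

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(dist_name):
    """Return the installed version string of a distribution, or None."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        # The distribution is not installed in this environment.
        return None


for dist in ("orbax-checkpoint", "jax"):
    print(dist, installed_version(dist) or "not installed")
```

If the reported orbax-checkpoint version is old, upgrading with pip install --upgrade orbax-checkpoint should pick up the release in which the jax_coordination_service reference was removed, per the reply above.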