google / orbax

Orbax provides common utility libraries for JAX users.

Home Page: https://orbax.readthedocs.io/

Orbax API migration questions

haohuanw opened this issue

I have migrated my use cases to the new Orbax API but have some questions (possibly because my setup is incorrect).

Issue 1: I am getting

ValueError: Distributed system is not available; please initialize it via `jax.distributed.initialize()` at the start of your program.

when calling

orbax/checkpoint/async_checkpointer.py", line 83, in _get_barrier_sync_fn
    raise ValueError(

I am using a single-slice, multi-host TPU on GKE, where I thought calling jax.distributed.initialize() was not needed, but let me know if that's not the case.

Issue 2: My current setup still uses pmap (I know this is not ideal 😢) with flax.jax_utils.replicate and flax.jax_utils.unreplicate, and I am getting

ValueError: Cannot serialize host local arrays. Arrays like this are typically obtained using pmap. Consider using fully_replicated_host_local_array_to_global_array in orbax/checkpoint/utils.py to convert your arrays into serializable objects.

when training on multiple TPUs. Just to confirm: do I basically just need to convert the entire state to GDA arrays? Previously, I think Orbax only saved weights on process 0 automatically, and I want to make sure the behavior stays the same.

For 1, I'm not always sure when jax.distributed.initialize gets called and when it doesn't, but it essentially always needs to be called. The only question is whether JAX makes the call for you behind the scenes. I'm just not sure why it isn't being called for you.

For 2, if all the arrays in your state were obtained from pmap, you would need to use fully_replicated_host_local_array_to_global_array on all of them to get jax.Array with the appropriate sharding. The arrays would still be fully replicated, so saving would happen on one process.
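For a pmap-based training state, the conversion described above might look like the following sketch. It assumes `state` is a pytree whose leaves are all fully replicated pmap outputs, and it takes the module path orbax/checkpoint/utils.py from the error message; that path may differ in other Orbax versions. The helper name `to_global_arrays` and its `convert_leaf` parameter are made up for illustration.

```python
def to_global_arrays(state, convert_leaf=None):
    """Map a per-leaf conversion over a pmap-replicated pytree.

    `to_global_arrays` and `convert_leaf` are illustrative names, not Orbax API.
    """
    # Local imports keep this helper definable on machines without JAX/Orbax.
    import jax

    if convert_leaf is None:
        # Module path taken from the error message in this thread; it may
        # have moved in other Orbax releases.
        from orbax.checkpoint import utils as ocp_utils
        convert_leaf = ocp_utils.fully_replicated_host_local_array_to_global_array

    # Apply the conversion to every leaf so the whole state becomes serializable.
    return jax.tree_util.tree_map(convert_leaf, state)
```

Usage would be something like `ckpt_state = to_global_arrays(replicated_state)` just before the save call, with the result passed to the checkpointer instead of the raw pmap state.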

For 1, I believe jax.distributed.initialize() is only called automatically inside Google; open-source users of JAX must call it explicitly. I will work on a fix for jax.distributed.initialize() by modifying how we automate arguments here. Currently it looks like this code only works for GKE on XPK, which is a user-friendly wrapper around GKE.
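An explicit call could look like the sketch below. The three keyword arguments are real parameters of jax.distributed.initialize; the environment variable names (COORDINATOR_ADDRESS, NUM_PROCESSES, PROCESS_ID) and both helper functions are assumptions for illustration, so substitute whatever your launcher actually exports.

```python
import os


def distributed_init_kwargs(env=None):
    """Build keyword arguments for jax.distributed.initialize.

    The environment variable names are hypothetical placeholders.
    """
    env = os.environ if env is None else env
    return {
        # "host:port" of the process acting as coordinator (usually process 0).
        "coordinator_address": env["COORDINATOR_ADDRESS"],
        # Total number of host processes taking part in the job.
        "num_processes": int(env["NUM_PROCESSES"]),
        # Index of this host process, in the range [0, num_processes).
        "process_id": int(env["PROCESS_ID"]),
    }


def init_distributed(env=None):
    """Initialize the JAX distributed system with explicit arguments."""
    # Imported here so the kwargs helper stays usable without JAX installed.
    import jax

    jax.distributed.initialize(**distributed_init_kwargs(env))
```

`init_distributed()` would be called once at the very start of the program, before any other JAX call touches the devices.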

As an immediate workaround, you can remove the jax.distributed.initialize() call and use an Orbax synchronous checkpointer instead of the asynchronous one, or you can set MEGASCALE_COORDINATOR_ADDRESS= as an env var. I think this solution has already been recommended to you. Edit: this code shows calling jax.distributed.initialize with explicit inputs
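The first workaround, swapping the asynchronous checkpointer for a synchronous one, might look roughly like this. ocp.Checkpointer, ocp.AsyncCheckpointer and ocp.PyTreeCheckpointHandler are existing Orbax classes; the `make_checkpointer` helper and its `use_async` flag are hypothetical.

```python
def make_checkpointer(use_async=False):
    """Return a synchronous Orbax checkpointer unless use_async is set.

    `make_checkpointer` and `use_async` are illustrative, not Orbax API.
    """
    # Local import so the helper can be defined without Orbax installed.
    import orbax.checkpoint as ocp

    handler = ocp.PyTreeCheckpointHandler()
    if use_async:
        # Per the traceback in this thread, the async path is the one that
        # requires the JAX distributed system to be initialized.
        return ocp.AsyncCheckpointer(handler)
    # The synchronous checkpointer blocks until the save has finished.
    return ocp.Checkpointer(handler)
```

Keeping the choice behind a single flag makes it easy to switch back to asynchronous saving once jax.distributed.initialize() works in your environment.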

Thanks! This is the workaround I am currently using.