Orbax API migration questions
haohuanw opened this issue · comments
i have migrated my use cases to the new orbax api but have some questions (maybe because my setup is incorrect).
issue1: i am getting
ValueError: Distributed system is not available; please initialize it via `jax.distributed.initialize()` at the start of your program.
when calling
orbax/checkpoint/async_checkpointer.py", line 83, in _get_barrier_sync_fn
raise ValueError(
i am using a single-slice, multi-host tpu on gke, where i thought calling jax.distributed.initialize() was not needed..? but let me know if that's not the case...
issue 2: my current setup is still using pmap (i know this is not ideal 😢) with flax.jax_utils.replicate and flax.jax_utils.unreplicate, and i am getting
ValueError: Cannot serialize host local arrays. Arrays like this are typically obtained using pmap. Consider using fully_replicated_host_local_array_to_global_array in orbax/checkpoint/utils.py to convert your arrays into serializable objects.
when training on multiple tpus. just to confirm, i basically just need to convert the entire state to gda arrays? previously, i think orbax only saved weights on process 0 automatically, and i want to make sure the behavior is the same.
For 1, I'm also not always sure when jax.distributed.initialize is getting called and when it isn't, but it essentially always needs to be called. The only question is whether JAX makes the call for you behind the scenes. I'm just not quite sure why it's not getting called for you.
For 2, if all the arrays in your state were obtained from pmap, you would need to use fully_replicated_host_local_array_to_global_array on all of them to get jax.Array with the appropriate sharding. The arrays would still be fully replicated, so saving would happen on one process.
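Applying the conversion to every array in the state could look roughly like the sketch below. It assumes the function is importable from orbax.checkpoint.utils, as the error message suggests (newer Orbax releases may have moved it), and `convert_state` / `convert_leaf` are hypothetical names used only for illustration.

```python
import jax


def convert_state(state, convert_leaf=None):
    """Apply a per-array conversion to every leaf of a pytree such as a train state.

    `convert_state` and `convert_leaf` are hypothetical names for this sketch.
    """
    if convert_leaf is None:
        # Assumed import path, taken from the error message in this thread;
        # check the location in your installed Orbax version.
        from orbax.checkpoint import utils as ocp_utils

        convert_leaf = ocp_utils.fully_replicated_host_local_array_to_global_array
    # tree_map visits every leaf and rebuilds the same container structure.
    return jax.tree_util.tree_map(convert_leaf, state)
```

With the default converter, you would pass the state whose arrays came from pmap and hand the result to the checkpointer. The `convert_leaf` argument exists only so the traversal can be tried without a multi-host setup.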
For 1, I believe jax.distributed.initialize() is only called automatically inside of Google; open source users of JAX must call it explicitly. I will work on a fix for jax.distributed.initialize() by modifying how we automate arguments here. Currently it looks like this code only works for GKE on XPK, which is a user-friendly wrapper of GKE.
As an immediate workaround, you can remove the jax.distributed.initialize() call and use an orbax synchronous checkpointer instead of an asynchronous one, or you can set MEGASCALE_COORDINATOR_ADDRESS = as an env var. I think this solution has already been recommended to you. Edit: this code shows calling jax.distributed.initialize with explicit inputs.
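For readers without access to the linked code, a minimal sketch of calling jax.distributed.initialize with explicit inputs might look like this. The environment variable names (COORDINATOR_ADDRESS, NUM_PROCESSES, PROCESS_ID) and the helper names are assumptions for illustration, not something JAX, Orbax, or GKE defines. Substitute whatever your job launcher provides.

```python
import os

import jax


def distributed_kwargs(env=os.environ):
    """Build explicit arguments for jax.distributed.initialize.

    The variable names read here are hypothetical placeholders.
    """
    return {
        # "host:port" of the process acting as coordinator (usually process 0).
        "coordinator_address": env["COORDINATOR_ADDRESS"],
        # Total number of host processes taking part in the job.
        "num_processes": int(env["NUM_PROCESSES"]),
        # Index of this process, in the range [0, num_processes).
        "process_id": int(env["PROCESS_ID"]),
    }


def init_distributed(env=os.environ):
    # Call this once at program start, before any other JAX work.
    jax.distributed.initialize(**distributed_kwargs(env))
```

Each host would call `init_distributed()` at the very start of the program, before creating the async checkpointer.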
thanks! this is the current workaround i have.