google / orbax

Orbax provides common utility libraries for JAX users.

Home page: https://orbax.readthedocs.io/

Questions about loading and writing checkpoints in distributed training

hr0nix opened this issue.

Hi! I have a few questions about how orbax handles certain distributed training scenarios.

  • Imagine I have a model that is sharded over one mesh dimension (possibly across multiple nodes) while being replicated over another. Is it true that, when writing a checkpoint, only a subset of nodes that together hold one full replica of the model weights will perform writes? If so, what determines this subset?
  • Suppose my model is replicated over a number of nodes. When loading a checkpoint, will each replica read the weights independently and in parallel?
    • If so, this might create a bottleneck when checkpoints are read from a network filesystem with limited bandwidth. One alternative would be for only one replica to read the weights and then send them to the other replicas using communication collectives, assuming the inter-node network is fast. Is this something that can be done with Orbax?
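A back-of-the-envelope model of this bottleneck, under simplifying assumptions: $R$ replicas each need the full checkpoint of size $S$, the network filesystem delivers an aggregate read bandwidth $B$ shared by all readers, and a pipelined inter-node broadcast of $S$ bytes takes roughly $S/b$ for inter-node bandwidth $b$. Latency and per-host limits are ignored.

```latex
T_{\text{independent}} \approx \frac{R\,S}{B},
\qquad
T_{\text{read+bcast}} \approx \frac{S}{B} + \frac{S}{b}
```

Under this model, reading once and broadcasting wins when $R\,S/B > S/B + S/b$, i.e. when $b > B/(R-1)$, so the larger $R$ is, the easier the condition is to satisfy.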

Thanks!

First, correct me if I'm wrong, but here is a concrete example of what you're describing:

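```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

assert len(jax.devices()) == 16
assert jax.process_count() == 4
# jax.devices() returns a list, so convert it to a NumPy array before reshaping.
mesh = Mesh(np.array(jax.devices()).reshape((4, 4)), ('x', 'y'))
# Shard dimension 0 over 'x'; the array is replicated over 'y'.
spec = PartitionSpec('x')
sharding = NamedSharding(mesh, spec)
# create array...
```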

Model arrays will be sharded such that shards A, B, C, and D are all present on each host, with each local device holding one of them. When saving, only the host holding the replica_id == 0 copy of a given shard actually writes that shard. The replica_id is assigned by JAX, and I'm not sure exactly how it gets set. So yes, only a subset of the processes actually performs I/O. In this example it would be best if replica_id 0 of each shard lived on a different host, but I'm not sure that is actually the case. We already have a TODO to distribute shard saving more evenly across the available hosts.
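To see which copies would be written under the replica_id == 0 rule described above, you can inspect an array's addressable shards. This is a sketch, not Orbax's actual save path: it simulates 8 CPU devices in one process with the `--xla_force_host_platform_device_count` XLA flag (an illustration-only setup), uses a smaller 4×2 mesh instead of 4×4, and assumes the writer for each shard is the copy with `replica_id == 0`. `shards_to_write` is a hypothetical helper name.

```python
import os

# Illustration only: simulate 8 CPU devices in one process.
# These must be set before jax is imported.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
os.environ["JAX_PLATFORMS"] = "cpu"

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Shard 4 ways over 'x' and replicate 2 ways over 'y'.
mesh = Mesh(np.array(jax.devices()).reshape((4, 2)), ("x", "y"))
sharding = NamedSharding(mesh, PartitionSpec("x"))
arr = jax.device_put(jnp.arange(16.0).reshape((8, 2)), sharding)


def shards_to_write(x: jax.Array):
    """Hypothetical helper: the local shards whose replica_id is 0.

    Each distinct shard index has exactly one copy with replica_id 0,
    so these are the copies that would be written under that rule.
    """
    return [s for s in x.addressable_shards if s.replica_id == 0]


for shard in shards_to_write(arr):
    print(f"process {shard.device.process_index}, {shard.device}: rows {shard.index[0]}")
```

In a real multi-host job each process only sees its own addressable shards, so running `shards_to_write` on every process shows how the writes are spread across hosts.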

Your second question is well aligned with a concern we have also noted recently: every host is indeed reading the entire data in parallel. We've been working with MaxText on a solution that loads on a single host only and then broadcasts to the other hosts. Eventually we envision upstreaming something like this into core Orbax. In the meantime, check with rwitten and the other MaxText owners for the definitive version of the solution I developed for them; it should be easy to adapt into your own code if that's what you want.
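The load-once-then-broadcast idea can be sketched with JAX's `jax.experimental.multihost_utils.broadcast_one_to_all`, which by default broadcasts a pytree from process 0 to all processes. This is not the MaxText implementation discussed in this thread: it assumes a plain `.npz` file rather than an Orbax checkpoint, and assumes every process already knows the shapes and dtypes (passed in as a hypothetical `template` argument). `load_on_one_host_and_broadcast` is a hypothetical helper name.

```python
import os
import tempfile

import jax
import numpy as np
from jax.experimental import multihost_utils


def load_on_one_host_and_broadcast(path, template):
    """Hypothetical helper: read an .npz file on process 0 only, then
    broadcast its arrays to every other process.

    `template` maps array names to arrays with the expected shape and dtype,
    so that non-source processes know what to allocate.
    """
    if jax.process_index() == 0:
        # Only the source process touches the (possibly slow) filesystem.
        with np.load(path) as data:
            tree = {name: data[name] for name in template}
    else:
        # Placeholders; their values are ignored and process 0's data is returned.
        tree = {name: np.zeros(t.shape, t.dtype) for name, t in template.items()}
    # By default, process 0 is the broadcast source.
    return multihost_utils.broadcast_one_to_all(tree)


# Usage (single process): round-trip a small tree through an .npz file.
params = {"w": np.arange(6, dtype=np.float32).reshape(2, 3), "b": np.ones(3, np.float32)}
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "params.npz")
    np.savez(path, **params)
    restored = load_on_one_host_and_broadcast(path, params)
```

This keeps filesystem reads to one per checkpoint, at the cost of materializing the full arrays in host memory on every process; for a model too large for that, you would likely need to broadcast shard by shard instead.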

Thanks for the answer!

Is there a MaxText PR that implements this functionality? I've found this one, but I can't find where the broadcasting happens.

Yes, there is, though I haven't heard back from them yet about what happened to it. The one you linked only reads the checkpoint steps on a single host and broadcasts them, not the actual arrays.
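For contrast with broadcasting arrays, broadcasting only the step number might look like the following sketch. It assumes (hypothetically) that each step is saved in a subdirectory named by its integer step number, and uses `jax.experimental.multihost_utils.broadcast_one_to_all` so that only process 0 lists the directory. `latest_step_broadcast` is a hypothetical helper, not the code from the linked PR.

```python
import os
import tempfile

import jax
import numpy as np
from jax.experimental import multihost_utils


def latest_step_broadcast(ckpt_dir):
    """Hypothetical helper: list checkpoint steps on process 0 only and
    broadcast the latest one, so other hosts don't hit the filesystem.

    Assumes each step lives in a subdirectory named by its integer step
    number. Returns -1 when no step directory exists.
    """
    if jax.process_index() == 0:
        steps = [
            int(name)
            for name in os.listdir(ckpt_dir)
            if name.isdigit() and os.path.isdir(os.path.join(ckpt_dir, name))
        ]
        step = max(steps, default=-1)
    else:
        step = 0  # Ignored: the broadcast returns process 0's value.
    return int(multihost_utils.broadcast_one_to_all(np.array(step, dtype=np.int32)))


# Usage (single process): three step directories plus an unrelated one.
with tempfile.TemporaryDirectory() as tmp:
    for name in ["50", "200", "100", "tmp"]:
        os.makedirs(os.path.join(tmp, name))
    latest = latest_step_broadcast(tmp)
```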

It's this commit: google/maxtext@76539a2

I think the PyTreeCheckpointHandler overrides aren't really needed; they were only there to add logging.