google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io

[Ask for Advice] Enforcing an expensive function to be called only on the host during the training loop

riven314 opened this issue · comments

I am one of the participants in the Hugging Face Community Week. I have a question about pmap and the training loop, and would really appreciate it if someone could offer me some advice:

I have a memory- and time-expensive step that needs to run every N iterations, e.g.:

for it, batch in enumerate(dataset):
    state = pmap_train(keys, model, state, batch, lr)
    if it % N == 0:
        run_expensive_step(keys, model, state, batch, clip_model, lr)

Questions

  1. Will this run_expensive_step create a big bottleneck on the pmap_train call, or will it not be blocked (maybe because of JAX's async dispatch)?
  2. Is there any way to force run_expensive_step to be called only on the host? (My host has 300+ GB of RAM and run_expensive_step is very memory hungry.)
  3. model has been passed through flax.jax_utils.replicate and batch has been passed through shard, but clip_model has not been replicated. Do I need to remove the leading device dimension from model and batch before each run_expensive_step call?
  1. Yes, this would be blocking: run_expensive_step() has to wait for the computation of state to finish before it can be transferred to host memory, and your training loop is then blocked until run_expensive_step() returns. If that's an issue, you could run run_expensive_step() in a separate thread (see the first sketch below).

  2. Inside run_expensive_step() you could use numpy for the computation. That would automatically copy the values to host memory and use the CPU for the computation. Alternatively, you could use jax.jit(run_expensive_step, backend='cpu') to compile it for the CPU backend (second sketch below).

  3. In either case you would need to flax.jax_utils.unreplicate() the pytrees first to get rid of the leading device dimension. As for batch, you would need to transform it via jax.tree_map(lambda x: x.reshape((-1,) + x.shape[2:]), batch) or similar (i.e. reshaping, not slicing); see the third sketch below.
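
For (1), a minimal sketch of the threaded variant, assuming run_expensive_step only reads its arguments and does not feed anything back into the training loop:

import threading

for it, batch in enumerate(dataset):
    state = pmap_train(keys, model, state, batch, lr)
    if it % N == 0:
        # Dispatch the expensive step on a background thread so the next
        # pmap_train call is not blocked while it runs.
        worker = threading.Thread(
            target=run_expensive_step,
            args=(keys, model, state, batch, clip_model, lr),
        )
        worker.start()

If run_expensive_step produces results the training loop later depends on, you would need to join() the thread (or otherwise synchronize) before using them.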
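
For (2), the two host-side options, using the names from your loop and assuming run_expensive_step is a pure function of its array arguments:

import jax

# Option A: pull the device arrays back to host memory (as numpy arrays) and
# do the heavy computation with numpy inside run_expensive_step.
host_state = jax.device_get(state)

# Option B: compile the function for the CPU backend so XLA runs it on the host.
run_expensive_step_cpu = jax.jit(run_expensive_step, backend='cpu')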
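
For (3), a sketch of the reshaping, assuming state and model were replicated with flax.jax_utils.replicate and batch was sharded across devices:

import jax
from flax import jax_utils

# Drop the leading device dimension from the replicated pytrees.
single_state = jax_utils.unreplicate(state)
single_model = jax_utils.unreplicate(model)

# Collapse the (num_devices, per_device_batch, ...) leading dimensions of the
# sharded batch back into a single batch dimension.
flat_batch = jax.tree_map(lambda x: x.reshape((-1,) + x.shape[2:]), batch)

run_expensive_step(keys, single_model, single_state, flat_batch, clip_model, lr)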