[Ask for Advice] Enforcing an expensive function to be called only on the host during the training loop
riven314 opened this issue
I am one of the participants in the Hugging Face Community Week. I have a question about pmap and the training loop, and I would really appreciate it if someone could offer me some advice:
I have a memory- and time-expensive step that needs to run every N iterations, e.g.:
```python
for it, batch in enumerate(dataset):
    state = pmap_train(keys, model, state, batch, lr)
    if it % N == 0:
        run_expensive_step(keys, model, state, batch, clip_model, lr)
```
Questions
- Will this `run_expensive_step` call create a big bottleneck on the `pmap_train` call, or will it not be blocked (maybe because of JAX's async dispatch)?
- Is there any way to enforce that `run_expensive_step` runs only on the host (my host has 300+ GB of memory and `run_expensive_step` is very memory hungry)? `model` has been passed through `flax.jax_utils.replicate` and `batch` has been passed through `shard`, but `clip_model` hasn't been passed through `flax.jax_utils.replicate` (the setup is sketched below for reference). Do I need to reduce the `model` and `batch` dimensions before each `run_expensive_step` call?
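For reference, a minimal sketch of the setup described above, assuming the common Flax pattern of `flax.jax_utils.replicate` plus `shard` from `flax.training.common_utils` (the exact import location of `shard` is an assumption):

```python
from flax import jax_utils
from flax.training.common_utils import shard  # assumed source of `shard`

# `model` is replicated once, adding a leading device axis to its pytree leaves;
# `clip_model` is left un-replicated.
model = jax_utils.replicate(model)

for it, batch in enumerate(dataset):
    # Each batch is reshaped to (n_devices, per_device_batch, ...) before pmap.
    batch = shard(batch)
    state = pmap_train(keys, model, state, batch, lr)
```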
- Yes, this would be blocking: `run_expensive_step()` needs to wait for the computation of `state` to finish before it can be transferred to host memory, and then your training loop is blocked until `run_expensive_step()` finishes. If that's an issue, you could run `run_expensive_step()` in a separate thread.
- Inside `run_expensive_step()` you could use `numpy` for the computation. That would automatically copy the values to host memory and use the CPU for the computation. Alternatively, you could `jax.jit(run_expensive_step, backend='cpu')`.
- In either case you would need to `flax.jax_utils.unreplicate()` the pytrees first to get rid of the leading device dimension. As for `batch`, you would need to transform it via `jax.tree_map(lambda x: x.reshape((-1,) + x.shape[2:]), batch)` or similar (i.e. reshaping, not slicing). A sketch combining these steps is below.
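A minimal sketch putting the above together, assuming `run_expensive_step` only needs host/CPU resources once it receives host-side arrays (the helper name `run_expensive_step_on_host` and the use of `threading` are illustrative, not a prescribed API):

```python
import threading

import jax
from flax import jax_utils


def run_expensive_step_on_host(keys, model, state, batch, clip_model, lr):
    # Drop the leading device axis that flax.jax_utils.replicate() added.
    host_state = jax_utils.unreplicate(state)
    host_model = jax_utils.unreplicate(model)  # if `model` is a replicated pytree
    # `batch` was sharded per device, so merge the device axis back by reshaping.
    host_batch = jax.tree_map(lambda x: x.reshape((-1,) + x.shape[2:]), batch)
    # Inside run_expensive_step you can work on numpy arrays (host memory / CPU),
    # or wrap it as jax.jit(run_expensive_step, backend='cpu') to force CPU execution.
    run_expensive_step(keys, host_model, host_state, host_batch, clip_model, lr)


for it, batch in enumerate(dataset):
    state = pmap_train(keys, model, state, batch, lr)
    if it % N == 0:
        # Launching a background thread keeps the training loop from blocking
        # while the expensive step runs on the host; it still has to wait for
        # `state` to be materialized before copying it off the devices.
        threading.Thread(
            target=run_expensive_step_on_host,
            args=(keys, model, state, batch, clip_model, lr),
        ).start()
```

Depending on how long the expensive step takes relative to N training iterations, you may want to keep a handle to the thread and `join()` it (or skip launching a new one while the previous is still running) so that multiple expensive steps don't pile up and exhaust host memory.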