huggingface / nanotron

Minimalistic large language model 3D-parallelism training

[Feature] Refactor `ParallelContext.world_rank_matrix`

NouamaneTazi opened this issue

For now, we store the global ranks in the `world_rank_matrix` attribute, a NumPy array of shape `(expert_parallel_size, pipeline_parallel_size, data_parallel_size, tensor_parallel_size)`.

So to look up a process's global rank from `world_rank_matrix`, we currently index it with one rank per axis:

parallel_context.world_rank_matrix[
    dist.get_rank(parallel_context.expert_pg),
    get_pp_rank_of(target, module=mdl),
    dist.get_rank(parallel_context.dp_pg),
    dist.get_rank(parallel_context.tp_pg),
],

It would be cool to make it a function call instead, such as:

parallel_context.get_global_rank(expert_parallel_rank=0, pipeline_parallel_rank=0, data_parallel_rank=0, tensor_parallel_rank=0)
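To make the proposal concrete, here is a minimal sketch of what such a method could look like. `ToyParallelContext` is a hypothetical stand-in and not nanotron's actual `ParallelContext`. The row-major rank layout used to build the toy matrix is also an assumption for illustration. Only the attribute name, the axis order, and the keyword names come from the description above.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class ToyParallelContext:
    """Hypothetical stand-in for ParallelContext; it only holds the rank matrix."""

    # Shape: (expert_parallel_size, pipeline_parallel_size,
    #         data_parallel_size, tensor_parallel_size)
    world_rank_matrix: np.ndarray

    def get_global_rank(
        self,
        expert_parallel_rank: int,
        pipeline_parallel_rank: int,
        data_parallel_rank: int,
        tensor_parallel_rank: int,
    ) -> int:
        """Return the global rank stored at the given per-axis coordinates."""
        return int(
            self.world_rank_matrix[
                expert_parallel_rank,
                pipeline_parallel_rank,
                data_parallel_rank,
                tensor_parallel_rank,
            ]
        )


# Toy layout (an assumption for illustration): 1 expert group, 2 pipeline
# stages, 2 data-parallel replicas, 2 tensor-parallel shards, with global
# ranks 0..7 assigned in row-major order.
ctx = ToyParallelContext(np.arange(8).reshape(1, 2, 2, 2))

rank = ctx.get_global_rank(
    expert_parallel_rank=0,
    pipeline_parallel_rank=1,
    data_parallel_rank=0,
    tensor_parallel_rank=1,
)
```

Keyword arguments keep the axis order out of the call sites, so the matrix layout could change later without touching callers.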