[Feature] Refactor `ParallelContext.world_rank_matrix`
NouamaneTazi opened this issue
Nouamane Tazi commented
Currently, global ranks are stored in the `world_rank_matrix` attribute, a NumPy array of shape `(expert_parallel_size, pipeline_parallel_size, data_parallel_size, tensor_parallel_size)`.
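To make the layout concrete, here is a minimal, illustrative sketch of such a matrix and of positional indexing into it. The sizes are made up, and filling it in row-major order with `np.arange` is an assumption for illustration only; nanotron's actual rank ordering may differ. This is not the code from the test file referenced below.

```python
import numpy as np

# Hypothetical sizes, chosen only for illustration.
expert_parallel_size, pipeline_parallel_size, data_parallel_size, tensor_parallel_size = 1, 2, 2, 2
shape = (expert_parallel_size, pipeline_parallel_size, data_parallel_size, tensor_parallel_size)

# Assumption: global ranks are assigned in row-major (C) order over (ep, pp, dp, tp).
world_rank_matrix = np.arange(int(np.prod(shape))).reshape(shape)

# Positional indexing: the caller must remember the axis order (ep, pp, dp, tp).
global_rank = world_rank_matrix[0, 1, 0, 1]
print(int(global_rank))
```

With this layout the strides are 8, 4, 2, 1, so `[0, 1, 0, 1]` maps to global rank 5. Swapping two indices by mistake still returns a valid rank, which is why positional access is easy to get wrong.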
To look up a process's global rank, we currently index into `world_rank_matrix` directly, as in:
`nanotron/tests/test_parameters_accumulate_gradient_in_fp32.py` (lines 346 to 351 at commit 7c01d0f)
It would be cleaner to expose this through a method call instead, such as:
`parallel_context.get_global_rank(expert_parallel_rank=0, pipeline_parallel_rank=0, data_parallel_rank=0, tensor_parallel_rank=0)`
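A minimal sketch of what such a method might look like, using a simplified, hypothetical stand-in for `ParallelContext`. The real class sets up process groups and has a different constructor, and the row-major rank layout here is an assumption. Keyword-only arguments remove the need to remember the axis order, and explicit range checks stop negative indices from silently wrapping around as they would with plain NumPy indexing.

```python
import numpy as np


class ParallelContext:
    """Hypothetical, simplified stand-in; not nanotron's actual ParallelContext."""

    def __init__(self, expert_parallel_size, pipeline_parallel_size, data_parallel_size, tensor_parallel_size):
        shape = (expert_parallel_size, pipeline_parallel_size, data_parallel_size, tensor_parallel_size)
        # Assumption: global ranks are laid out in row-major order over (ep, pp, dp, tp).
        self.world_rank_matrix = np.arange(int(np.prod(shape))).reshape(shape)

    def get_global_rank(self, *, expert_parallel_rank, pipeline_parallel_rank, data_parallel_rank, tensor_parallel_rank):
        """Return the global rank for the given per-dimension ranks (keyword-only)."""
        index = (expert_parallel_rank, pipeline_parallel_rank, data_parallel_rank, tensor_parallel_rank)
        names = ("expert", "pipeline", "data", "tensor")
        # Reject out-of-range ranks; NumPy would otherwise accept negative indices.
        for name, rank, size in zip(names, index, self.world_rank_matrix.shape):
            if not 0 <= rank < size:
                raise ValueError(f"{name}_parallel_rank={rank} is out of range [0, {size})")
        # Return a plain Python int rather than a NumPy scalar.
        return int(self.world_rank_matrix[index])


parallel_context = ParallelContext(1, 2, 2, 2)
rank = parallel_context.get_global_rank(
    expert_parallel_rank=0, pipeline_parallel_rank=1, data_parallel_rank=0, tensor_parallel_rank=1
)
print(rank)
```

Under the assumed layout this prints 5. Callers no longer depend on the array's axis order, so the underlying storage could later change (or be replaced by a computed formula) without touching call sites.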