huggingface / datasets

🤗 The largest hub of ready-to-use datasets for ML models with fast, easy-to-use and efficient data manipulation tools

Home Page:https://huggingface.co/docs/datasets

Problematic rank after calling `split_dataset_by_node` twice

yzhangcs opened this issue

Describe the bug

I'm trying to split an IterableDataset with split_dataset_by_node.
But when splitting an already split dataset, the resulting rank is greater than world_size.

Steps to reproduce the bug

Here is the minimal code for reproduction:

>>> from datasets import load_dataset
>>> from datasets.distributed import split_dataset_by_node
>>> dataset = load_dataset('fla-hub/slimpajama-test', split='train', streaming=True)
>>> dataset = split_dataset_by_node(dataset, 1, 32)
>>> dataset._distributed
DistributedConfig(rank=1, world_size=32)
>>> dataset = split_dataset_by_node(dataset, 1, 15)
>>> dataset._distributed
DistributedConfig(rank=481, world_size=480)

As you can see, the rank after the second split (481) is greater than the world size (480), which is invalid: a rank should lie in the range 0 to world_size - 1.
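
The numbers can also be reproduced without loading the dataset. The following is only a sketch of the arithmetic, assuming the order of operations quoted in this issue; the variable names `prev_rank` and `prev_world_size` are made up for illustration and are not library identifiers.

```python
# Values taken from the session above.
prev_rank, prev_world_size = 1, 32  # config left by the first split
rank, world_size = 1, 15            # arguments of the second split

# Same order of operations as the two lines quoted in this issue:
# world_size is overwritten first, then reused to compute rank.
world_size = world_size * prev_world_size  # 15 * 32 = 480
rank = world_size * prev_rank + rank       # 480 * 1 + 1 = 481

print(rank, world_size)  # 481 480
```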

Expected behavior

I think this error comes from these lines, @lhoestq:

world_size = world_size * dataset._distributed.world_size
rank = world_size * dataset._distributed.rank + rank

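Here world_size is overwritten before it is used to compute rank, so the rank is scaled by the combined world size (480) instead of the new split's world size (15). We may need to compute the rank first. With that change, the reproduction above gives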

>>> dataset._distributed                           
DistributedConfig(rank=16, world_size=480)
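
A minimal sketch of the suggested reordering, outside the library. The `DistributedConfig` dataclass below is a stand-in with only the two fields printed above, and `nest_split` is a hypothetical helper name; neither is the actual `datasets` implementation.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class DistributedConfig:
    # Stand-in for the library's config object (assumption: only these two fields matter here).
    rank: int
    world_size: int


def nest_split(prev: Optional[DistributedConfig], rank: int, world_size: int) -> DistributedConfig:
    """Hypothetical helper: combine an existing split with a new (rank, world_size)."""
    if prev is not None:
        # Compute the combined rank first, while world_size is still the new split's size.
        rank = world_size * prev.rank + rank
        world_size = world_size * prev.world_size
    return DistributedConfig(rank=rank, world_size=world_size)


first = nest_split(None, 1, 32)
second = nest_split(first, 1, 15)
print(second)  # DistributedConfig(rank=16, world_size=480)
assert 0 <= second.rank < second.world_size
```

With this order, every pair of (first rank, second rank) maps to a distinct combined rank below 480, so no two nested shards collide.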

Environment info

datasets==2.20.0

Ah yes, good catch! Feel free to open a PR with your suggested fix.