Problematic rank after calling `split_dataset_by_node` twice
yzhangcs opened this issue · comments
Describe the bug
I'm trying to split an `IterableDataset` with `split_dataset_by_node`.
But when splitting an already split dataset, the resulting `rank` is greater than `world_size`.
Steps to reproduce the bug
Here is the minimal code for reproduction:
>>> from datasets import load_dataset
>>> from datasets.distributed import split_dataset_by_node
>>> dataset = load_dataset('fla-hub/slimpajama-test', split='train', streaming=True)
>>> dataset = split_dataset_by_node(dataset, 1, 32)
>>> dataset._distributed
DistributedConfig(rank=1, world_size=32)
>>> dataset = split_dataset_by_node(dataset, 1, 15)
>>> dataset._distributed
DistributedConfig(rank=481, world_size=480)
As you can see, after the second split the rank (481) is greater than the world size (480), which is invalid: a rank should always be smaller than `world_size`.
Expected behavior
I think this error comes from these lines, @lhoestq:
`datasets/src/datasets/iterable_dataset.py`
Lines 2943 to 2944 in a6ccf94
We may need to compute the new rank first, before `world_size` is updated. With that change, the code above gives
>>> dataset._distributed
DistributedConfig(rank=16, world_size=480)
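The reported numbers are consistent with an order-of-operations problem when the two splits are combined. The sketch below is inferred from the values in this report and is not copied from the library: `Cfg`, `nest_buggy`, and `nest_fixed` are hypothetical names, and `Cfg` is only a stand-in for the real `DistributedConfig`.

```python
from typing import NamedTuple


class Cfg(NamedTuple):
    """Hypothetical stand-in for the library's DistributedConfig."""
    rank: int
    world_size: int


def nest_buggy(prev: Cfg, rank: int, world_size: int) -> Cfg:
    # Assumed buggy order: world_size is overwritten first, so the
    # rank below is computed with the already multiplied value.
    world_size = world_size * prev.world_size
    rank = world_size * prev.rank + rank
    return Cfg(rank, world_size)


def nest_fixed(prev: Cfg, rank: int, world_size: int) -> Cfg:
    # Suggested order: compute the rank with the new split's own
    # world_size, then combine the world sizes.
    rank = world_size * prev.rank + rank
    world_size = world_size * prev.world_size
    return Cfg(rank, world_size)


first = Cfg(rank=1, world_size=32)       # after the first split
print(tuple(nest_buggy(first, 1, 15)))   # (481, 480), as in the report
print(tuple(nest_fixed(first, 1, 15)))   # (16, 480), the expected result
```

With the rank computed first, the combined rank stays below the combined world size for any valid pair of splits, since `world_size * prev.rank + rank` is at most `world_size * prev.world_size - 1`.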
Environment info
datasets==2.20.0
Ah yes, good catch! Feel free to open a PR with your suggested fix.