Enable Sharding to Equal Sized Shards
yuvalkirstain opened this issue · comments
Feature request
Add an option, when sharding a dataset, to make all shards the same size. It would be good to support both approaches: equalizing by duplication and equalizing by truncation.
Motivation
Currently the behavior of sharding is "If n % i == l, then the first l shards will have length (n // i) + 1, and the remaining shards will have length (n // i)." However, when using FSDP we want the shards to have the same size, so that every rank processes the same number of examples. Right now the user has to handle this manually; it would be nice to have an option to shard the dataset into equally sized shards.
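To make the quoted rule concrete, here is a small pure-Python sketch of the shard lengths it implies. The helper name `shard_lengths` is hypothetical and is not part of the `datasets` API; it only reproduces the arithmetic from the quote.

```python
def shard_lengths(n, num_shards):
    """Lengths of each shard under the quoted rule (hypothetical helper).

    If n % num_shards == l, the first l shards get (n // num_shards) + 1
    examples and the remaining shards get n // num_shards.
    """
    base, extra = divmod(n, num_shards)
    return [base + 1 if i < extra else base for i in range(num_shards)]


# 10 examples over 4 shards: the first two shards are one example longer.
print(shard_lengths(10, 4))
```

With 10 examples and 4 shards the lengths are uneven, which is exactly the case the request is about.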
Your contribution
For now just a PR. I can also add code that does what is needed, but it is probably not efficient.
Shard to equal size by duplication:
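from datasets import concatenate_datasets

# Pad with copies of the first examples so len(dataset) is divisible by num_shards.
remainder = len(dataset) % num_shards
num_missing_examples = (num_shards - remainder) % num_shards  # 0 if already divisible
if num_missing_examples > 0:
    duplicated = dataset.select(range(num_missing_examples))
    dataset = concatenate_datasets([dataset, duplicated])
shard = dataset.shard(num_shards, shard_idx)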
Or by truncation:
# Shard first, then cut every shard down to the length of the shortest one.
shard = dataset.shard(num_shards, shard_idx)
num_examples_per_shard = len(dataset) // num_shards
shard = shard.select(range(num_examples_per_shard))
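Both strategies can also be expressed as plain index arithmetic, which may be easier to reason about than the snippets above. The following is only a sketch: `equal_shard_indices` and its `mode` argument are hypothetical names, not an existing or proposed `datasets` API. It assumes contiguous shards, and in truncation mode it drops the last `n % num_shards` examples of the dataset rather than trimming each shard individually.

```python
def equal_shard_indices(n, num_shards, shard_idx, mode="duplicate"):
    """Row indices for one equally sized, contiguous shard (hypothetical helper).

    mode="duplicate": pad by wrapping around to the first examples.
    mode="truncate":  drop the trailing n % num_shards examples.
    """
    if n <= 0 or num_shards <= 0:
        raise ValueError("n and num_shards must be positive")
    if not 0 <= shard_idx < num_shards:
        raise ValueError("shard_idx must be in [0, num_shards)")

    if mode == "duplicate":
        # Round n up to the next multiple of num_shards, wrapping indices.
        total = -(-n // num_shards) * num_shards
        indices = [i % n for i in range(total)]
    elif mode == "truncate":
        # Round n down to the previous multiple of num_shards.
        total = (n // num_shards) * num_shards
        indices = list(range(total))
    else:
        raise ValueError(f"unknown mode: {mode!r}")

    per_shard = total // num_shards
    start = shard_idx * per_shard
    return indices[start:start + per_shard]
```

The returned list could then be passed to `dataset.select(...)` to materialize the shard, for example `dataset.select(equal_shard_indices(len(dataset), num_shards, shard_idx))`.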