google / maxtext

A simple, performant and scalable Jax LLM!

sharding options with grain

bgyoon opened this issue

I think the following changes need to be made to work around google/grain#351
[image: screenshot of the proposed changes]

Thanks for reporting the issue, @bgyoon. I see the confusion about where shard_options should be provided. I'm checking with the grain team on this.

I was wrong on the grain issue, but I would still encourage making the change, since ShardLazyDataset is deprecated and sharding is being moved from the Sampler into the DataLoader proper.

Also, if you shard using IndexSampler, it's not a global shuffle: records are shuffled only within each shard, which may not be an issue if your ArrayRecord files are already shuffled (ours weren't):

from grain._src.core import sharding
from grain._src.python import data_loader as data_loader_lib
from grain._src.python import samplers
from grain._src.python.data_sources import ArrayRecordDataSource

if __name__ == "__main__":
    data_source = ArrayRecordDataSource([
        "grain/grain/_src/python/testdata/digits.array_record-00000-of-00002",
        "grain/grain/_src/python/testdata/digits.array_record-00001-of-00002",
    ])

    # Touch the data source once up front.
    data_source[0]  # pylint: disable=pointless-statement

    print("Shard in Sampler")
    for shard in range(3):
        print("Shard: ", shard)

        # Sharding in the sampler gives each shard a contiguous slice of the
        # records, shuffled only within that slice.
        sampler = samplers.IndexSampler(
            num_records=len(data_source),
            shard_options=sharding.ShardOptions(
                shard_index=shard, shard_count=3, drop_remainder=True),
            shuffle=True,
            seed=0,
            num_epochs=1,
        )

        num_workers = 2
        data_loader = data_loader_lib.DataLoader(
            data_source=data_source, sampler=sampler, worker_count=num_workers
        )
        for x in data_loader:
            print(x)

    print("Shard in Dataloader")
    for shard in range(3):
        print("Shard: ", shard)

        # The sampler shuffles globally; the DataLoader then takes every
        # third element of that single global permutation.
        sampler = samplers.IndexSampler(
            num_records=len(data_source),
            shard_options=sharding.NoSharding(),
            shuffle=True,
            seed=0,
            num_epochs=1,
        )

        num_workers = 2
        data_loader = data_loader_lib.DataLoader(
            data_source=data_source,
            sampler=sampler,
            worker_count=num_workers,
            shard_options=sharding.ShardOptions(
                shard_index=shard, shard_count=3, drop_remainder=True),
        )
        for x in data_loader:
            print(x)
Shard in Sampler
Shard:  0
WARNING:absl:Dropping 1 examples of 10 examples (shard 3).
b'0'
b'2'
b'1'
Shard:  1
WARNING:absl:Dropping 1 examples of 10 examples (shard 3).
b'3'
b'5'
b'4'
Shard:  2
WARNING:absl:Dropping 1 examples of 10 examples (shard 3).
b'6'
b'8'
b'7'
Shard in Dataloader
Shard:  0
b'0'
b'7'
b'1'
b'2'
Shard:  1
b'4'
b'8'
b'9'
Shard:  2
b'6'
b'3'
b'5'

Sharding in the Sampler messed up our training; moving it into the DataLoader seems to have fixed the issue. As the output above shows, "Shard in Sampler" gives each shard a contiguous block of records shuffled only within that block, while "Shard in Dataloader" splits a single global shuffle across the shards. Since the change resolved my problem, it is no longer an issue for me, but I figured this information may be useful to others.

Hi @bgyoon, thanks for sharing the example. It looks like the main difference between "Shard in Sampler" and "Shard in Dataloader" in the above example is data order. Can you help me understand why the order produced by "Shard in Sampler" messed up your training?

I experienced numerous grad_norm spikes before the change, which disappeared after it.
[image: grad_norm plot]

I suspected double sharding, which turned out to be incorrect. But the global shuffle vs. shuffle-within-a-shard issue still stands. I conjecture that because our ArrayRecord files are not shuffled, shuffling only within a shard causes sudden drift in the data distribution, which produces grad_norm spikes and eventually messes up training.

For example, say we have a list of ArrayRecords whose first portion is English and whose later portion is Korean. If the records are split evenly across two shards and shuffled within each shard, as IndexSampler does, then in a batch of size 1024 the first 512 elements will always be English and the last 512 will always be Korean. I'd guess that wouldn't be considered well shuffled.
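
Here is a toy simulation of that effect in plain Python (no grain; the corpus, shard count, and batch size are made up for illustration). With contiguous shards shuffled per shard, every global batch has a fixed language layout; with one global shuffle split across shards, the languages mix:

import random

rng = random.Random(0)

# Hypothetical unshuffled source: first half English, second half Korean.
records = ["en"] * 512 + ["ko"] * 512
num_shards = 2
per_shard_batch = 4  # each shard contributes 4 examples per global batch

def first_batch(shard_fn):
    # A global batch concatenates one sub-batch from every shard.
    batch = []
    for i in range(num_shards):
        batch.extend(shard_fn(i)[:per_shard_batch])
    return batch

def contiguous_shard(i):
    # "Shard in Sampler": a contiguous slice, shuffled only within itself.
    n = len(records) // num_shards
    shard = records[i * n:(i + 1) * n]
    rng.shuffle(shard)
    return shard

global_perm = records[:]
rng.shuffle(global_perm)  # one global shuffle shared by all shards

def global_shard(i):
    # "Shard in Dataloader": every num_shards-th element of the permutation.
    return global_perm[i::num_shards]

print(first_batch(contiguous_shard))  # always ['en', 'en', 'en', 'en', 'ko', 'ko', 'ko', 'ko']
print(first_batch(global_shard))      # a random mix of 'en' and 'ko'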

FYI, I also tried shuffling (with a fixed seed, of course) the list of ArrayRecord files that feeds into ArrayRecordDataSource, but the issue persisted. I guess even though the file order was shuffled, it was not globally shuffled enough. Also, our ArrayRecord files have different sizes (for example, the English file is much larger than the Japanese file). Since shuffling the list of files alone won't shuffle the individual records within each file, it was not good enough. I could be wrong on this again, though.
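
A similar sketch shows why shuffling the file list alone falls short (the file names and sizes here are hypothetical): it only permutes contiguous blocks, so records from the same file stay adjacent, and the larger file still dominates long stretches of the stream:

import random

rng = random.Random(0)

# Hypothetical files of very different sizes: 6 English records, 2 Japanese.
files = {"en.array_record": ["en"] * 6, "ja.array_record": ["ja"] * 2}

# Shuffling only the file list permutes contiguous blocks of records,
# so each file's records stay adjacent in the stream.
order = list(files)
rng.shuffle(order)
file_level = [r for name in order for r in files[name]]
print(file_level)

# A record-level global shuffle actually interleaves the two files.
record_level = [r for recs in files.values() for r in recs]
rng.shuffle(record_level)
print(record_level)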

Considering that 1) this change fixed our issue, 2) the theory sounds plausible, and 3) your documentation suggests moving away from sharding in the sampler, I'd recommend making the same change.
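
In terms of the example above, the recommended change is just moving shard_options from the IndexSampler to the DataLoader (a minimal before/after sketch reusing the same grain modules; shard_index, shard_count, and num_workers stand in for per-host values):

# Before: shard in the sampler (contiguous shard, per-shard shuffle).
sampler = samplers.IndexSampler(
    num_records=len(data_source),
    shard_options=sharding.ShardOptions(shard_index, shard_count, drop_remainder=True),
    shuffle=True, seed=0, num_epochs=1,
)
data_loader = data_loader_lib.DataLoader(
    data_source=data_source, sampler=sampler, worker_count=num_workers
)

# After: global shuffle in the sampler, shard in the DataLoader.
sampler = samplers.IndexSampler(
    num_records=len(data_source),
    shard_options=sharding.NoSharding(),
    shuffle=True, seed=0, num_epochs=1,
)
data_loader = data_loader_lib.DataLoader(
    data_source=data_source, sampler=sampler, worker_count=num_workers,
    shard_options=sharding.ShardOptions(shard_index, shard_count, drop_remainder=True),
)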

Thanks for the details! It does look like sharding in the DataLoader is the way forward. I'll run more tests on my side and make the change. I also noticed that drop_remainder=True doesn't work when passed into the DataLoader; will fix that.