mosaicml / streaming

A Data Streaming Library for Efficient Neural Network Training

Home Page: https://streaming.docs.mosaicml.com

How does an LRU local cache help with multi-epoch training?

liangjuf opened this issue

This is not technically a feature request, but rather a question about the local cache behavior of StreamingDataset.

I am using StreamingDataset to load a dataset from S3 in a streaming manner. The dataset is large and can't be cached entirely on each node, so at some point StreamingDataset needs to evict some shards according to its LRU eviction strategy.
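For reference, here is roughly how I'm constructing the dataset (paths are placeholders, and I'm capping local disk usage with the cache_limit argument; argument names are from the docs as I understand them and may differ by version):

```python
from streaming import StreamingDataset

# Placeholder paths; cache_limit caps local shard storage, so shards
# beyond this budget get evicted by the LRU policy as training proceeds.
dataset = StreamingDataset(
    remote='s3://my-bucket/my-dataset',  # where the shards live
    local='/tmp/streaming_cache',        # local cache directory on each node
    cache_limit='100gb',                 # smaller than the full dataset size
    shuffle=True,
)
```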

Now let's assume that I have a training job that runs for two epochs. During the first epoch, StreamingDataset keeps downloading shards from S3 and caching them locally. As training continues, because local space is limited, StreamingDataset has to evict some of the shards that were loaded during the early stage of the first epoch. (Since each shard is accessed exactly once, the LRU eviction strategy effectively degenerates into a time-based one.) At the end of the first epoch, what's left in the cache are the shards downloaded toward the end.

Now, during the second epoch, StreamingDataset needs to download the early shards again because they have already been evicted. Also, as training goes on, the late shards downloaded toward the end of the first epoch are also evicted from the cache, replaced by the shards downloaded during the early stage of the second epoch. Consequently, the second epoch downloads the whole dataset again and the local cache doesn't help at all.

I am wondering if there is any improvement we can make in this scenario. With the LRU eviction strategy, it seems like the hit rate for the local cache could be as low as 0%.
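To make that concrete, here is a tiny standalone simulation (not StreamingDataset code) of an LRU shard cache where every shard is read exactly once per epoch, in the same order each epoch:

```python
from collections import OrderedDict

def lru_hit_rate(epochs: int, num_shards: int, cache_size: int) -> float:
    """Simulate an LRU shard cache when each epoch reads every shard once,
    in the same order, and return the overall cache hit rate."""
    cache: 'OrderedDict[int, None]' = OrderedDict()
    hits = misses = 0
    for _ in range(epochs):
        for shard in range(num_shards):
            if shard in cache:
                hits += 1
                cache.move_to_end(shard)       # mark as most recently used
            else:
                misses += 1
                cache[shard] = None            # "download" the shard
                if len(cache) > cache_size:
                    cache.popitem(last=False)  # evict least recently used
    return hits / (hits + misses)

# 100 shards, room for only 40: every access in both epochs is a miss.
print(lru_hit_rate(epochs=2, num_shards=100, cache_size=40))  # 0.0
```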

This is a great observation @liangjuf -- you're exactly right. We designed the LRU cache for scenarios where the dataset is so large that it can't all fit onto disk, which usually also means you're not training for many epochs. Because of this, while the LRU cache is great for limiting disk usage within an epoch, it's not great when moving between epochs.

Maybe what we could do is have a sort of "snake" approach to the shard ordering during training. Let's say you access shards in a particular order during your first epoch. Since you are using the LRU cache, the shards that are needed at the end of that epoch are still on disk. You could then draw samples for your second epoch by reversing the shard access order from the first epoch, meaning you are able to make use of the existing shards in the cache without having to evict them immediately. This would likely be a modification of the StreamingDataset.resample_streams() function here, which is what is called to get the order of shard parts as well as the mapping from original sample ids to training sample ids. We could probably change the behavior of the function depending on whether the epoch arg is even or odd :)
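As a rough sketch of the idea (a hypothetical helper, not the actual resample_streams() implementation):

```python
def snake_shard_order(shard_order: list, epoch: int) -> list:
    """Return the shard access order for a given epoch under the "snake"
    scheme: even epochs use the original order, odd epochs reverse it so the
    tail of the previous epoch (still cached) is reused first."""
    if epoch % 2 == 0:
        return list(shard_order)
    return list(reversed(shard_order))

# Example with 6 shards and a cache that holds the last 2 from epoch 0:
print(snake_shard_order([0, 1, 2, 3, 4, 5], epoch=0))  # [0, 1, 2, 3, 4, 5]
print(snake_shard_order([0, 1, 2, 3, 4, 5], epoch=1))  # [5, 4, 3, 2, 1, 0]
```

With a cache that holds the last k shards of an epoch, the first k shard accesses of the next epoch would then be hits instead of misses.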

That being said, we likely won't get to this soon since there are some other higher priority items right now. If you wanted to take a stab at this, that would be amazing! Also, it would be great to better understand your use case and dataset size -- are you concerned about egress fees? And is your dataset very, very large?

Thank you for your quick response and pointer. That's helpful.

Our dataset is not crazily large -- a couple of TB. I am exploring several dataloaders as well as the locations where the data can be stored. I want to benchmark throughput to make sure that dataloading can keep the GPU busy most of the time.

My current finding with StreamingDataset is that loading from S3 is significantly slower than loading from EFS or from the host filesystem. That somewhat makes sense, as downloading is expensive. I also find that if a shard is already cached, dataloading is much faster (up to 10x). This leads me to my question: what can we do to keep the cache hit rate as high as possible over a whole training run with multiple epochs?

Are you benchmarking throughput along with model training, or just iterating through the dataloader? Downloading from a remote location while iterating through just the dataloader will definitely be slower than loading from local disk. You can also allow StreamingDataset to download farther ahead by increasing the predownload arg -- this currently defaults to 8*batch_size. Also, make sure that you're setting the batch_size in your StreamingDataset to the same value as in your DataLoader -- this batch_size is the per-device batch size.
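For example, something along these lines (placeholder values, and the exact defaults may differ by version):

```python
from torch.utils.data import DataLoader
from streaming import StreamingDataset

per_device_batch_size = 32  # placeholder value

dataset = StreamingDataset(
    remote='s3://my-bucket/my-dataset',      # hypothetical remote
    local='/tmp/streaming_cache',
    batch_size=per_device_batch_size,        # per-device, same as the DataLoader below
    predownload=16 * per_device_batch_size,  # fetch farther ahead than the default
)

dataloader = DataLoader(dataset, batch_size=per_device_batch_size)
```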

Happy to take a look at some example code you have if you'd like as well.

Currently I am just iterating through the dataset without involving the model, since model training speed can vary a lot, and for now I'd like to assume that the whole system is bottlenecked on the CPU side rather than the device side.
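Something like this minimal loop is what I mean (placeholder paths and batch size):

```python
import time

from torch.utils.data import DataLoader
from streaming import StreamingDataset

batch_size = 32  # placeholder per-device batch size
dataset = StreamingDataset(remote='s3://my-bucket/my-dataset',  # placeholder paths
                           local='/tmp/streaming_cache',
                           batch_size=batch_size)
dataloader = DataLoader(dataset, batch_size=batch_size)

start = time.perf_counter()
num_batches = 0
for _ in dataloader:  # no model involved, just pull batches
    num_batches += 1
elapsed = time.perf_counter() - start

# Approximate: the final batch may be smaller than batch_size.
print(f'~{num_batches * batch_size / elapsed:.1f} samples/sec')
```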

Thanks for your answer again and feel free to close this issue.