EleutherAI / pythia

The hub for EleutherAI's work on interpretability and learning dynamics


[Pythia on Pile-Dedup] Training for ~1.5 epochs: how to identify the repeated sequences (i.e., the additional .5 epoch)?

pietrolesci opened this issue

Hi there,

The deduplicated dataset has fewer sequences, so to keep the token count consistent with the non-deduplicated version, the models are trained for ~1.5 epochs (as discussed in the README). Between epochs, are the data reshuffled, or does the dataloader simply start from the beginning again in the same order? If the latter, is there a way to know exactly which checkpoint is the first to see the same data twice? Put differently, is there a way to know which sequences the model sees in the additional ~half epoch?

Thanks a lot in advance for your help!

cc @haileyschoelkopf

Hey! I had similar questions a while back for a paper in which we used the Pythia suite. To the best of my understanding, the answer is that it's 1.5 epochs, where roughly the first half of the data is seen twice, in the same order. The Pythia paper gives the total number of tokens the models see and how many they see in the first pass; based on those numbers, I use the step98000 checkpoint as my full "single pass" checkpoint. I believe the checkpoints after that one start "seeing double."
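For reference, the step98000 choice can be reproduced with a back-of-the-envelope calculation. This is a minimal sketch, assuming the figures commonly quoted for Pythia (sequence length 2048, 1024 sequences per step, 143,000 steps, checkpoints every 1,000 steps) and a deduplicated Pile of roughly 207B tokens. These constants are assumptions for illustration, not values read from the training configs, and since 207B is approximate the boundary step is an estimate too.

```python
# Rough estimate of where the first pass over the deduplicated Pile ends.
# All constants are ASSUMPTIONS based on published Pythia figures,
# not values read from the actual training configs.
SEQ_LEN = 2048            # tokens per sequence
BATCH_SIZE = 1024         # sequences per optimizer step
TOTAL_STEPS = 143_000     # training steps for every Pythia model
CHECKPOINT_EVERY = 1_000  # spacing of the regularly saved checkpoints
DEDUP_TOKENS = 207e9      # approximate token count of the deduplicated Pile

tokens_per_step = SEQ_LEN * BATCH_SIZE        # tokens consumed per step
total_tokens = tokens_per_step * TOTAL_STEPS  # tokens seen over all of training
epochs = total_tokens / DEDUP_TOKENS          # passes over the deduplicated data

# Fractional step at which the first pass would be exhausted.
epoch_boundary_step = DEDUP_TOKENS / tokens_per_step

# Last regularly saved checkpoint that has only seen first-pass data.
last_single_pass_ckpt = int(epoch_boundary_step // CHECKPOINT_EVERY) * CHECKPOINT_EVERY

print(f"tokens/step: {tokens_per_step:,}")                    # 2,097,152
print(f"epochs: {epochs:.2f}")                                # 1.45
print(f"boundary step: {epoch_boundary_step:,.0f}")           # 98,705
print(f"last single-pass checkpoint: step{last_single_pass_ckpt}")  # step98000
```

Under these assumptions the first pass ends around step 98,705, so step98000 is the last saved checkpoint before any data repeats, and everything from step99000 onward has seen part of the second pass.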

Thanks a lot for your answer @jeffreygwang, this seems reasonable to me too!

> Between epochs, are the data reshuffled, or does the dataloader simply start from the beginning again in the same order?

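The answer seems to be that the dataloader does NOT simply start from the beginning again. Instead, the repeated data was concatenated at the document level, that is, before the "packing" process that cuts the token stream into fixed-length sequences. As a result, tokens from the start of the first epoch can appear at different positions within a sequence in the second pass, so the repeated sequences are not identical to the originals.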
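A toy illustration of why document-level concatenation shifts token positions. This is a simplified sketch, not the GPT-NeoX data pipeline: it assumes the second pass repeats the same documents in the same order, ignores EOD separators (or treats them as part of each document's length), and uses hypothetical helper names (`pack`, `second_pass_start`). Under those assumptions, the repeated data begins at global token offset `T`, where `T` is the token count of one full pass, i.e. in sequence `T // seq_len` at offset `T % seq_len`.

```python
from typing import List, Tuple


def pack(docs: List[List[int]], seq_len: int) -> List[List[int]]:
    """Concatenate documents into one token stream and cut it into
    fixed-length sequences, dropping any incomplete tail."""
    stream = [tok for doc in docs for tok in doc]
    n_full = len(stream) // seq_len
    return [stream[i * seq_len:(i + 1) * seq_len] for i in range(n_full)]


def second_pass_start(docs: List[List[int]], seq_len: int) -> Tuple[int, int]:
    """Return (sequence index, offset within that sequence) where the
    repeated documents begin, ASSUMING the second pass is appended at
    the document level in the same order."""
    first_pass_tokens = sum(len(doc) for doc in docs)
    return divmod(first_pass_tokens, seq_len)


# Toy corpus of three documents (9 tokens in total), packed into length-4 sequences.
docs = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
seq_len = 4

# Roughly 1.5 passes: every document once, then the first two again.
sequences = pack(docs + docs[:2], seq_len)
print(sequences)                         # [[1, 2, 3, 4], [5, 6, 7, 8], [9, 1, 2, 3]]
print(second_pass_start(docs, seq_len))  # (2, 1)
```

Token `1` sits at position 0 of sequence 0 in the first pass but at position 1 of sequence 2 in the second, which is why matching repeated data by comparing whole sequences would miss it; comparing at the document level (or by global token offset) is the more reliable approach under these assumptions.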