google / maxtext

A simple, performant and scalable Jax LLM!

TFDS Data Processing Pipeline

LeoXinhaoLee opened this issue · comments

Hi, I'm trying to understand some details in the TFDS data processing pipeline in your repo, and I'm confused about the following details:

In _tfds_data_processing.py:

(1) The truncate_to_max_allowable_length function truncates each sequence to be less than max_length. Does this mean we will simply discard and waste the text that exceeds max_length?

(2) The shuffle_buffer_size defaults to 1024, which seems very small compared to the total number of sequences in a modern dataset.

(3) The dataset is first shuffled and then repeated num_epochs times. But shouldn't we first repeat and then shuffle to guarantee different orders in different epochs?

In sequence_packing.py:

(4) The comment says pack_dataset.py does greedy sequence packing. What does that mean? Does it pack portions from different sequences that exceed max_length into a new sequence?

(5) map_fn is commented as "Internal function to flat_map over". Is this code working well externally?

Thank you very much for your time and help!

We would really appreciate it if we could have your timely response on these data processing protocols, as they are the cornerstone of our development with MaxText. Thank you so much for your help! @rwitten

@A9isha can you field this one?

@LeoXinhaoLee are you in contact with us internally? We can help you there.

@LeoXinhaoLee can you have your customer engineer get us in chat via google chat? Seems like you're a paying customer and there are probably some things to discuss

Thanks a lot for raising these @LeoXinhaoLee

(1) The truncate_to_max_allowable_length function truncates each sequence to be less than max_length. Does this mean we will simply discard and waste the text that exceeds max_length?
True, this is the current implementation.
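
For anyone following along, here is a rough sketch of what that truncation looks like in a tf.data pipeline. It mirrors the idea rather than the exact MaxText code, so the names and details below are illustrative:

```python
import tensorflow as tf

# Illustrative only: truncate every feature in an example to max_length tokens.
# Anything beyond max_length is simply dropped, matching the behavior
# described above.
def truncate_to_max_allowable_length(example, max_length):
  return {k: v[:max_length] for k, v in example.items()}

# ds is assumed to be a tf.data.Dataset of dicts of 1-D token tensors.
# ds = ds.map(lambda x: truncate_to_max_allowable_length(x, max_length=2048),
#             num_parallel_calls=tf.data.AUTOTUNE)
```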

(2) The shuffle_buffer_size defaults to 1024, which seems very small compared to the total number of sequences in a modern dataset.
Grain, a new feature in MaxText that provides a deterministic input pipeline, can help here with better shuffling in a distributed setting.
@aireenmei for visibility and adding more info
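
To make the concern in (2) concrete, here is a small standalone demonstration (not MaxText code) of how a buffered shuffle only mixes elements within a sliding window:

```python
import tensorflow as tf

# With a 1024-element buffer, the first output element is always drawn from
# the first 1024 elements of the dataset (indices 0..1023), so the
# large-scale order of a big dataset is only approximately randomized.
ds = tf.data.Dataset.range(1_000_000).shuffle(buffer_size=1024, seed=0)
first = next(iter(ds)).numpy()
print(first)  # always one of the first 1024 indices
```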

(3) The dataset is first shuffled and then repeated num_epochs times. But shouldn't we first repeat and then shuffle to guarantee different orders in different epochs?
We didn't really consider that use case because our training ran for a single epoch at the time. If your use case calls for different epochs to be shuffled differently, please feel free to do that.
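
A quick, hedged note on the ordering question: with tf.data's default reshuffle_each_iteration=True, a shuffle placed before repeat is still reshuffled on every pass, so the two orderings mainly differ in whether the shuffle buffer can mix elements across epoch boundaries. A toy comparison (not MaxText code):

```python
import tensorflow as tf

ds = tf.data.Dataset.range(5)

# shuffle -> repeat: each epoch is a fresh permutation of 0..4
# (reshuffle_each_iteration defaults to True).
a = ds.shuffle(5, seed=1).repeat(2)

# repeat -> shuffle: the buffer mixes elements across epoch boundaries, so an
# element may appear twice before another appears once.
b = ds.repeat(2).shuffle(5, seed=1)

print(list(a.as_numpy_iterator()))
print(list(b.as_numpy_iterator()))
```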

(4) The comment says pack_dataset.py does greedy sequence packing. What does that mean? Does it pack portions from different sequences that exceed max_length into a new sequence?
It packs multiple sequences together into a single sequence, stopping when the combined length would exceed max_length.
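
In case it helps, here is a plain-Python sketch of greedy packing. The real sequence_packing.py operates on tf.data datasets and also tracks segmentation/position information, so this only shows the core idea:

```python
# Illustrative greedy packing: keep appending sequences to the current pack
# until adding the next one would overflow max_length, then start a new pack.
def greedy_pack(sequences, max_length):
  packs, current = [], []
  for seq in sequences:
    seq = seq[:max_length]  # sequences are already truncated upstream
    if len(current) + len(seq) > max_length:
      packs.append(current)
      current = []
    current.extend(seq)
  if current:
    packs.append(current)
  return packs

# greedy_pack([[1, 2, 3], [4, 5], [6, 7, 8, 9]], max_length=6)
# -> [[1, 2, 3, 4, 5], [6, 7, 8, 9]]
```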

(5) map_fn is commented as "Internal function to flat_map over". Is this code working well externally?
Ah, sorry for the confusion. The "internal function" just means a local helper passed to map() here: https://github.com/google/maxtext/blob/main/MaxText/sequence_packing.py#L206
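
A minimal example of that pattern (not the actual MaxText code), where an "internal" map_fn is just a locally defined helper handed to flat_map:

```python
import tensorflow as tf

def pack_into_windows(dataset, window_size):
  def map_fn(window):  # local helper, nothing Google-internal about it
    return window.batch(window_size)
  return dataset.window(window_size).flat_map(map_fn)

ds = pack_into_windows(tf.data.Dataset.range(10), window_size=4)
print(list(ds.as_numpy_iterator()))  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```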

For (2), if you are looking for something closer to a global shuffle, you may be interested in the grain data processing pipeline, which is built on the ArrayRecord data format that supports random access (a minimal sketch of the idea follows the list below).

  • ArrayRecord, and the Beam module for converting data to ArrayRecord: https://github.com/google/array_record
  • Grain usage example: end_to_end/test_convergence_1b_params.sh (the "$DATASET_TYPE" == "c4-array_record" branch uses grain)
  • If your dataset lives in GCS, you need to mount the bucket to a local path with gcsfuse using setup_gcsfuse.sh
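
As promised above, a conceptual sketch of why random access enables a global shuffle. The reader object here is hypothetical (just something with len() and index-based reads); the real grain/array_record APIs differ in detail:

```python
import numpy as np

# Hypothetical random-access reader: permute indices over the whole dataset
# and read records in that order, i.e. a true global shuffle rather than a
# windowed one.
def globally_shuffled(reader, seed=0):
  order = np.random.default_rng(seed).permutation(len(reader))
  for idx in order:
    yield reader[idx]
```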

Closing this because the questions are addressed. Feel free to reopen or contact us directly!