mosaicml / llm-foundry

LLM training code for Databricks foundation models

Home Page: https://www.databricks.com/blog/introducing-dbrx-new-state-art-open-llm

How to support multi-threaded parallel data preprocessing?

YixinSong-e opened this issue

I want to pretrain an LLM on 2T tokens using llm-foundry, but the data preprocessing before training takes too long. Is there any way to accelerate it?

Agreed, this would be very useful.

Would it be possible to add sharding to convert_dataset_json.py? One could simply add extra parameters specifying the number of shards and the index of the current shard. The script could then be run on multiple machines, all targeting the same output directory. I checked the code, but I am not yet sure how to do this with MDSWriter.
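A minimal sketch of the shard-selection part of that idea, assuming two hypothetical flags, --num_shards and --shard_index (convert_dataset_json.py does not define them): each invocation computes the contiguous row range it owns, and the ranges of all shards cover the dataset exactly once.

```python
import argparse
from typing import List, Optional, Tuple


def shard_bounds(num_rows: int, num_shards: int, shard_index: int) -> Tuple[int, int]:
    """Return the [start, end) row range owned by one shard.

    Rows are split into contiguous, near-equal chunks; the first
    num_rows % num_shards shards get one extra row each.
    """
    if not 0 <= shard_index < num_shards:
        raise ValueError(f"shard_index must be in [0, {num_shards})")
    base, extra = divmod(num_rows, num_shards)
    start = shard_index * base + min(shard_index, extra)
    end = start + base + (1 if shard_index < extra else 0)
    return start, end


def parse_shard_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    # Hypothetical flags for illustration; not part of convert_dataset_json.py.
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_shards", type=int, default=1)
    parser.add_argument("--shard_index", type=int, default=0)
    return parser.parse_args(argv)
```

Each machine would then keep only rows [start, end), for example with Hugging Face Datasets' Dataset.select(range(start, end)) or Dataset.shard(num_shards=..., index=..., contiguous=True), and would write to its own sub-folder (say out_root/<shard_index>) rather than the shared root, since separate MDSWriters pointed at one folder would likely overwrite each other's index.json. The sub-folders can then be merged into one index.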

I think the example conversion script is perhaps not very efficient. One thing that helps a lot is to use the Hugging Face Datasets .map() with batching to tokenize the dataset. I'm not sure how writing to the MDS file can be parallelized, but it probably can be.
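To illustrate what batched mapping means: with Hugging Face Datasets it would look something like dataset.map(lambda b: tokenizer(b["text"]), batched=True, num_proc=8), where batched=True hands the function dict-of-lists slices instead of single rows and num_proc spreads those batches over worker processes. The standard-library sketch below mimics only the batching semantics so it runs without downloads; map_batched and toy_tokenize are hypothetical stand-ins, not the Datasets implementation.

```python
from typing import Callable, Dict, List

Batch = Dict[str, List]


def map_batched(columns: Batch, fn: Callable[[Batch], Batch], batch_size: int = 1000) -> Batch:
    """Apply fn to dict-of-lists batches and concatenate the returned columns.

    Mirrors the shape of Dataset.map(fn, batched=True): fn receives a slice of
    every column and returns new columns with one entry per input row.
    """
    num_rows = len(next(iter(columns.values()))) if columns else 0
    out: Batch = {}
    for start in range(0, num_rows, batch_size):
        batch = {k: v[start : start + batch_size] for k, v in columns.items()}
        for k, v in fn(batch).items():
            out.setdefault(k, []).extend(v)
    return out


def toy_tokenize(batch: Batch) -> Batch:
    # Hypothetical stand-in for tokenizer(batch["text"]): one "id" per word,
    # equal to the word length, so the example needs no model files.
    return {"input_ids": [[len(w) for w in t.split()] for t in batch["text"]]}


data = {"text": ["hello world", "a bb ccc", "tokenize me please"]}
print(map_batched(data, toy_tokenize, batch_size=2))
# {'input_ids': [[5, 5], [1, 2, 3], [8, 2, 6]]}
```

The gain with a real fast tokenizer comes from encoding a whole list per call instead of one string at a time.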

Also, there is a bug in tokenizers that might make it much slower than you would like; see huggingface/tokenizers#1413.

The text-to-MDS conversion script (https://github.com/mosaicml/llm-foundry/blob/main/scripts/data_prep/convert_text_to_mds.py) is parallelized. Is that what you are looking for (or at least a good starting point)?
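A minimal sketch of the general fan-out pattern for this kind of conversion, assuming the input is a list of text files: the files are split into groups, each group is converted by its own process into its own sub-folder, and the sub-folders are merged afterwards. group_files, convert_group and convert_parallel are hypothetical names, not the script's actual code; convert_group only records its inputs where a real worker would tokenize them and write MDS shards.

```python
import os
from concurrent.futures import ProcessPoolExecutor
from typing import List, Tuple


def group_files(files: List[str], num_groups: int) -> List[List[str]]:
    # Round-robin assignment keeps group sizes within one file of each other.
    groups = [files[i::num_groups] for i in range(num_groups)]
    return [g for g in groups if g]  # drop empty groups when files < num_groups


def convert_group(task: Tuple[List[str], str]) -> str:
    files, out_dir = task
    # Hypothetical placeholder: a real worker would tokenize `files` and write
    # them with its own MDSWriter into out_dir. Here it only lists them.
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, "files.txt"), "w") as f:
        f.write("\n".join(files))
    return out_dir


def convert_parallel(files: List[str], out_root: str, num_proc: int) -> List[str]:
    # One sub-folder per group, so parallel writers never share an index.json.
    groups = group_files(files, num_proc)
    tasks = [(g, os.path.join(out_root, f"{i:03d}")) for i, g in enumerate(groups)]
    with ProcessPoolExecutor(max_workers=num_proc) as pool:
        return list(pool.map(convert_group, tasks))
```

convert_parallel should be called from under an if __name__ == "__main__": guard, because ProcessPoolExecutor may re-import the main module in its worker processes.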

Thanks, I will look into it.

Isn't it enough to just run the script in parallel and merge the MDS shards with this method?

def merge_shard_groups(root: str) -> None:

Currently, I am trying it like this.

I have a large JSONL file. I used split -l to split it into as many files as there are processes. Then I call convert_dataset_json.py independently on each of these, obtaining one output folder per process; all of these output folders live under a common output_root_folder.
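The splitting step can be done with split -l, or with GNU coreutils' split -n l/N, which produces N files without breaking lines. A Python equivalent, sketched under the assumption that each line holds exactly one JSON object, streams the file twice (count, then copy) so it never has to fit in memory; split_jsonl and the part_NNNN.jsonl names are hypothetical.

```python
import os
from typing import List


def split_jsonl(path: str, out_dir: str, num_parts: int) -> List[str]:
    """Split a JSONL file into at most num_parts files without breaking lines."""
    with open(path, "rb") as f:
        total = sum(1 for _ in f)
    # Ceiling division, like choosing the -l value for `split -l`.
    lines_per_part = max(1, -(-total // num_parts))
    os.makedirs(out_dir, exist_ok=True)
    parts: List[str] = []
    out = None
    try:
        with open(path, "rb") as src:
            for i, line in enumerate(src):
                if i % lines_per_part == 0:  # time to start a new part file
                    if out is not None:
                        out.close()
                    parts.append(os.path.join(out_dir, f"part_{len(parts):04d}.jsonl"))
                    out = open(parts[-1], "wb")
                out.write(line)
    finally:
        if out is not None:
            out.close()
    return parts
```

Each returned part can then be converted on its own, with a separate output folder per part under the common root.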

Lastly, I hope it will be enough to just call the mentioned merge method on output_root_folder.

(Will update once the processing is finished.)

Yes @MFajcik, that should work!
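For intuition, merging amounts to concatenating the shard lists of every sub-folder's index.json and re-pointing each shard file at its sub-folder. The sketch below assumes the MDS index layout {"version": 2, "shards": [...]} with raw_data/zip_data basenames relative to each sub-folder; merge_index_files is a hypothetical helper, and in practice the library's own merge function (such as merge_shard_groups mentioned above) is the safer choice.

```python
import json
import os
from typing import List


def merge_index_files(root: str) -> int:
    """Write <root>/index.json listing the shards of every <root>/<sub>/index.json.

    Returns the number of shards in the merged index. Sub-folders and their
    shard files stay in place; only the recorded basenames are re-pointed.
    """
    shards: List[dict] = []
    for sub in sorted(os.listdir(root)):  # zero-padded names keep shard order
        index_path = os.path.join(root, sub, "index.json")
        if not os.path.isfile(index_path):
            continue  # skips plain files, including an existing root index.json
        with open(index_path) as f:
            index = json.load(f)
        for shard in index["shards"]:
            # Make shard file names relative to root instead of the sub-folder.
            for key in ("raw_data", "zip_data"):
                if shard.get(key):
                    shard[key]["basename"] = f"{sub}/{shard[key]['basename']}"
            shards.append(shard)
    with open(os.path.join(root, "index.json"), "w") as f:
        json.dump({"version": 2, "shards": shards}, f)
    return len(shards)
```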

It does work! Preprocessing was done in no time. Training is running right now. Thanks for the hint!

I changed ConcatTokensDataset.__iter__ to this:

from typing import Dict, Iterable, Union

import numpy as np


def __iter__(self) -> Iterable[Dict[str, Union[bytes, int]]]:
    # Method of ConcatTokensDataset. Assumes __init__ also sets
    # self.write_batch_size, e.g. self.write_batch_size = 10_000.
    buffer = []
    num_rows = self.hf_dataset.num_rows
    # Tokenize write_batch_size documents per call instead of one at a time.
    # range() with a step avoids an extra empty batch when num_rows divides evenly.
    for start in range(0, num_rows, self.write_batch_size):
        batch = self.hf_dataset[start : start + self.write_batch_size]
        encoded_batch = self.tokenizer(batch["text"], truncation=False, padding=False)
        for iids in encoded_batch["input_ids"]:
            buffer = buffer + self.bos_tokens + iids + self.eos_tokens
            # Emit fixed-length samples; keep the remainder only when wrapping.
            while len(buffer) >= self.max_length:
                concat_sample = buffer[: self.max_length]
                buffer = buffer[self.max_length :] if self.should_wrap else []
                yield {
                    # convert to bytes to store in MDS binary format
                    "tokens": np.asarray(concat_sample).tobytes(),
                    "num_tokens": len(concat_sample),
                }

Processing 7B tokens takes around 20 hours with the original code and 30 minutes with this change. It's not very robust, though, and doesn't scale well: the fast tokenizer hangs after a while on very long texts, and more than 16 threads don't seem to give any additional speedup.
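To check the batching-and-packing logic in isolation, here is a standard-library version of the same loop with a pluggable batched encoder. It is a sketch: pack_batched and toy_encode are hypothetical, and token ids are yielded as plain lists instead of being converted to bytes for MDS.

```python
from typing import Callable, Dict, Iterator, List, Sequence

EncodeBatch = Callable[[List[str]], List[List[int]]]


def pack_batched(
    texts: Sequence[str],
    encode_batch: EncodeBatch,
    max_length: int,
    batch_size: int = 10_000,
    bos: Sequence[int] = (),
    eos: Sequence[int] = (),
    wrap: bool = True,
) -> Iterator[Dict[str, object]]:
    """Encode texts batch by batch and pack them into max_length samples."""
    buffer: List[int] = []
    for start in range(0, len(texts), batch_size):
        # One encoder call per batch rather than per document.
        for ids in encode_batch(list(texts[start : start + batch_size])):
            buffer.extend(bos)
            buffer.extend(ids)
            buffer.extend(eos)
            while len(buffer) >= max_length:
                sample = buffer[:max_length]
                # With wrap=False the leftover tokens are dropped.
                buffer = buffer[max_length:] if wrap else []
                yield {"tokens": sample, "num_tokens": len(sample)}


def toy_encode(batch: List[str]) -> List[List[int]]:
    # Hypothetical encoder: one "token id" per word, equal to the word length.
    return [[len(w) for w in text.split()] for text in batch]


docs = ["a bb ccc", "dddd", "ee f"]
for s in pack_batched(docs, toy_encode, max_length=3, batch_size=2, eos=[0]):
    print(s["tokens"])
# [1, 2, 3]
# [0, 4, 0]
# [2, 1, 0]
```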

Thanks for your update! Did you modify any other files to enable multithreading?

Yes, sorry: I also removed os.environ["TOKENIZERS_PARALLELISM"] = "false" from ConcatTokensDataset.__init__.
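For reference, the switch lives in the environment of the process that runs the tokenizer. A minimal sketch, assuming the Hugging Face tokenizers library: it reads TOKENIZERS_PARALLELISM, and "false" turns off its internal thread pool, so the variable has to be set (or left unset) before the tokenizer is first used. If a process forks after the tokenizer has already used parallelism, tokenizers disables it and prints a warning. Capping the thread count through RAYON_NUM_THREADS is an assumption based on tokenizers using Rayon's global pool; it was not tried in this thread.

```python
import os

# Set before the tokenizer is created or first called. "false" (what the
# original __init__ set) disables the internal thread pool; "true" enables it.
os.environ["TOKENIZERS_PARALLELISM"] = "true"

# Assumption: Rayon's global pool reads RAYON_NUM_THREADS. Given the reported
# plateau above 16 threads, capping it there may avoid wasted cores.
# setdefault keeps any value already exported in the shell.
os.environ.setdefault("RAYON_NUM_THREADS", "16")
```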

It helps a lot. I can process 100B tokens within 7 hours with your code! :)
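As a rough sanity check, the wall-clock figures reported in this thread convert to throughput as follows. The runs may differ in hardware and data, so the comparison is only indicative.

```python
def tokens_per_second(tokens: float, hours: float) -> float:
    return tokens / (hours * 3600)


original = tokens_per_second(7e9, 20)   # 7B tokens in ~20 hours
batched = tokens_per_second(7e9, 0.5)   # 7B tokens in ~30 minutes
later = tokens_per_second(100e9, 7)     # 100B tokens in ~7 hours

print(f"speedup: {batched / original:.0f}x")
# speedup: 40x
print(f"{original:,.0f} / {batched:,.0f} / {later:,.0f} tokens/s")
# 97,222 / 3,888,889 / 3,968,254 tokens/s
```

Both batched runs land at roughly 4M tokens/s, consistent with the reported ceiling of about 16 useful threads.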