RunpeiDong / DreamLLM

[ICLR 2024 Spotlight] DreamLLM: Synergistic Multimodal Comprehension and Creation

Home Page: https://dreamllm.github.io/

Could you release the code for preparing the MMC4-core dataset to get correct tarfiles?

pkulwj1994 opened this issue.

Hi Runpei,

I greatly appreciate your wonderful work.

I notice that the MMC4-core dataset (i.e. the unified interleaved dataset) in your code takes tarfiles that contain raw images and text. However, in previous works such as OpenFlamingo, the webdataset tarfiles contain only JSON files, which store the images in base64 form.

Could you release the code for preparing the MMC4-core dataset in your omni framework? That would help the community re-implement the work and would certainly strengthen its impact!

Thank you.
Best wishes.
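If your copy of MMC4 follows the OpenFlamingo-style layout described above (images embedded as base64 inside the JSON), one way to bridge the two formats is to decode each image back into a file next to its JSON record, so that the directory looks like an extracted raw-image shard. This is a minimal sketch under assumptions: the per-image field name `image_base64` is assumed (check it against your data), and `unpack_base64_sample` is a hypothetical helper, not part of DreamLLM or OpenFlamingo.

```python
import base64
import copy
import json
import os


def unpack_base64_sample(sample: dict, key: str, out_dir: str) -> dict:
    """Decode the base64-embedded images of one MMC4-style record into files.

    Assumption: each entry of sample["image_info"] stores its image bytes
    under an "image_base64" field. Writes <key>.json and <key>-<idx>.jpg into
    out_dir and returns the rewritten record; the input dict is not modified.
    """
    os.makedirs(out_dir, exist_ok=True)
    record = copy.deepcopy(sample)
    for idx, image_info in enumerate(record["image_info"]):
        raw = base64.b64decode(image_info.pop("image_base64"))
        image_name = f"{key}-{idx}"
        # Bytes are written unchanged; PIL detects the real format from content.
        with open(os.path.join(out_dir, image_name + ".jpg"), "wb") as f:
            f.write(raw)
        # Point the record at the new file name (without the .jpg suffix).
        image_info["image_name"] = image_name
    with open(os.path.join(out_dir, key + ".json"), "w") as f:
        json.dump(record, f)
    return record
```

The `<key>.json` / `<key>-<idx>.jpg` naming is chosen so that the records would be picked up by a `*.json` glob and each `image_name` would resolve to `<dir>/<image_name>.jpg`, as in the maintainer's script.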

Hi @pkulwj1994,

Thank you for your interest! Here is a basic script for processing interleaved datasets like MMC4. It also works for other interleaved datasets once they have been converted to the MMC4 format, which is actually what we do in practice. Please feel free to reach out if you have any questions.

import os
import json
import argparse
import datetime
import copy
from PIL import Image
from webdataset import TarWriter
import glob
import refile
import re
import time
import multiprocessing
from tqdm import tqdm
from joblib import Parallel, delayed


DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_DREAM_TOKEN = "<dream>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
DEFAULT_DESICION_TOKEN = "<dream_or_not>" # NOTE dream or not to dream? it's not a question!
DEFAULT_DREAM_START_TOKEN = "<dream_start>" # NOTE make llm dream!
DEFAULT_DREAM_END_TOKEN = "<dream_end>" # NOTE make llm dream!


def save_to_tar(save_path, filename, stream, info):
    """Write one (possibly merged) sample to the webdataset tar stream.

    The JSON record is stored under the key `filename`; each image is stored as
    a separate JPEG sample keyed `{filename}-{idx}`, and the record's
    `image_name` fields are rewritten to point at those keys.
    """
    json_sample = {"__key__": filename, "json": info}
    image_dict_list = []

    for idx, image_info in enumerate(info["image_info"]):
        image_name = image_info["image_name"]
        # Images were extracted next to the JSON files; add the .jpg suffix if missing.
        if image_name.endswith("jpg"):
            image_path = save_path + image_name
        else:
            image_path = save_path + image_name + ".jpg"

        img = Image.open(image_path, "r").convert("RGB")

        img_sample = {"__key__": f"{filename}-{idx}", "jpg": img}
        image_info["image_name"] = f"{filename}-{idx}"
        image_dict_list.append(img_sample)

    # The JSON is written after the loop so that it already contains the renamed image keys.
    stream.write(json_sample)
    for img_sample in image_dict_list:
        stream.write(img_sample)


def main(tar_path, tokenizer, dataset_name, use_folder=False):
    
    begin_time = time.time()

    if use_folder:
        tar_folder = tar_path.split("/")[-2]
    else:
        tar_folder = None
    
    tar_name = tar_path.split("/")[-1].split(".")[0]
    
    if tar_folder is not None:
        save_path = f'/data/save_root/{dataset_name}/{tar_folder}/{tar_name}/'
        tar_save_path = f"/data/save_tar_root/{dataset_name}-tarfiles-with-tokens_imgpatch_400_tokenlen_6k/{tar_folder}/"
        done_file_path = f"/data/raw_data/{dataset_name}/{tar_folder}/"
    else:
        save_path = f'/data/save_root/{dataset_name}/{tar_name}/'
        tar_save_path = f"/data/save_tar_root/{dataset_name}-tarfiles-with-tokens_imgpatch_400_tokenlen_6k/"
        done_file_path = f"/data/raw_data/{dataset_name}/"
    done_file = os.path.join(done_file_path, f"{tar_name}.json")
    if os.path.exists(done_file):
        return

    os.makedirs(save_path, exist_ok=True)
    os.makedirs(tar_save_path, exist_ok=True)
    os.makedirs(done_file_path, exist_ok=True)  # the done-marker JSON is written here at the end

    s3_path = tar_path
    
    END_POINT_URL = "http://xxx" # Set your aws endpoint url here
    os.system(f'aws --endpoint-url {END_POINT_URL} s3 cp {s3_path} {save_path}') # s3 cp
    os.system(f'tar -xvf {save_path + tar_name + ".tar"} -C {save_path}')
    os.system(f'tar -xvf {s3_path} -C {save_path}')
    os.system(f"rm {save_path}/{tar_name}.tar")

    fname = tar_save_path + tar_name + '.tar'
    if os.path.exists(fname):
        os.system(f"rm {fname}")

    stream = TarWriter(fname)

    json_list = glob.glob(f"{save_path}/*.json")

    valid_cnt = 0
    total_cnt = len(json_list)
    merged_samples_cnt = 0
    cached_data = []
    cached_file_name = ""
    cached_token_len = 0
    for json_file in tqdm(json_list):
        with open(json_file, "r") as f:
            sample = json.load(f)
        if "image_info" not in sample.keys() or len(sample["image_info"]) == 0:
            continue

        filename = json_file.split("/")[-1].split(".")[0]
        info = sample
        # Replace markdown links "[text](url)" with just their text.
        url_pattern = r'\[([^\]]+)\]\((http[s]?://[^\)]+)\)'
        info["text_list"] = [re.sub(url_pattern, r'\1', text) for text in info.pop("text_list")]
        # Shift the texts right by one slot: a placeholder appended to slot i then sits
        # between text_list[i-1] and text_list[i], i.e. right before its matched text.
        # A trailing empty string is dropped so that the slot count does not grow.
        new_text_list = [copy.deepcopy(text) for text in info["text_list"]]
        if new_text_list[-1] == "":
            new_text_list = [""] + new_text_list[:-1]
        else:
            new_text_list = [""] + new_text_list

        # Each image becomes a <dream> placeholder followed by an <image> placeholder.
        for image_info in info["image_info"]:
            matched_text_index = image_info["matched_text_index"]
            new_text_list[matched_text_index] = new_text_list[matched_text_index] + DEFAULT_DREAM_TOKEN + DEFAULT_IMAGE_TOKEN
            
        all_text = " ".join(new_text_list).strip()
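        # Each <image> becomes 400 patch tokens wrapped in <im_start>/<im_end>, and each
        # <dream> becomes 64 patch tokens wrapped in <dream_start>/<dream_end>.
        replace_token = DEFAULT_IMAGE_PATCH_TOKEN * 400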
        replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN

        generation_replace_token = DEFAULT_IMAGE_PATCH_TOKEN * 64
        generation_replace_token = DEFAULT_DREAM_START_TOKEN + generation_replace_token + DEFAULT_DREAM_END_TOKEN

        all_text = all_text.replace(DEFAULT_IMAGE_TOKEN, replace_token)
        all_text = all_text.replace(DEFAULT_DREAM_TOKEN, generation_replace_token)

        all_text = DEFAULT_BOS_TOKEN + all_text + DEFAULT_EOS_TOKEN

        try:
            if tokenizer is not None:
                token = tokenizer.tokenize(all_text)
            else:
                token = None
        except Exception:  # skip documents the tokenizer cannot handle
            print(all_text)
            continue

        if token is None:
            # No tokenizer: write each document as its own sample, without token fields.
            save_to_tar(save_path, filename, stream, info)
            merged_samples_cnt += 1
        else:
            info["token"] = token
            if cached_token_len + len(token) > 6000 and len(cached_data) > 0:
                cached_json = cached_data[0]
                for data in cached_data[1:]:
                    cached_json["text_list"].extend(data["text_list"])
                    cached_json["image_info"].extend(data["image_info"])
                    cached_json["token"].extend(data["token"][1:])
                save_to_tar(save_path, cached_filename, stream, cached_json)
                merged_samples_cnt += 1

                cached_data = []
                cached_token_len = 0

            if len(token) < 6000:
                if len(cached_data) == 0:
                    cached_filename = filename
                cached_data.append(info)
                cached_token_len += len(token)

        valid_cnt += 1

    if len(cached_data) > 0:
        cached_json = cached_data[0]
        for data in cached_data[1:]:
            cached_json["text_list"].extend(data["text_list"])
            cached_json["image_info"].extend(data["image_info"])
            if "token" in cached_json:
                cached_json["token"].extend(data["token"][1:])
        save_to_tar(save_path, cached_filename, stream, cached_json)
        merged_samples_cnt += 1

    stream.close()
    
    json.dump({"url": fname, "nsamples": merged_samples_cnt}, open(done_file, 'w'), indent=2)
    # with open(done_file, "w") as f:
    #     f.write(fname + "\n")
    #     f.write(str(merged_samples_cnt))
    #     pass
    # os.system(f'aws --endpoint-url {END_POINT_URL} s3 cp {fname} s3://tar_root/{tar_folder}/{tar_name}.tar')
    
    print(f"[{merged_samples_cnt, valid_cnt}] (total [{total_cnt}]) complete to write samples to shard {fname}, and it takes {time.time() - begin_time}s")
    
    os.system(f'rm -rf {save_path}')


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--machine_id', type=int, default=0)
    parser.add_argument('--num_process', type=int, default=32)
    parser.add_argument('--num_machine', type=int, default=1)
    parser.add_argument('--dataset_name', type=str, default="mmc4", choices=["mmc4", "obelisc"])

    args = parser.parse_args()

    dataset2url = dict(
        mmc4="s3://vision-language-data/mmc4",
        obelisc="s3://vision-language-data/oblisc-tarfiles/0",
    )
    
    dataset_name = args.dataset_name
    total_machine = args.num_machine
    machine_id = args.machine_id
    tar_list = refile.smart_glob(f"{dataset2url[dataset_name]}/*.tar")
    print(f"total tar list num: {len(tar_list)}")
    tar_list = tar_list[machine_id::total_machine]

    # Optional: load a tokenizer here (e.g. from "data/tokenizers/tokenizer.model")
    # to store pre-tokenized texts; None writes raw texts only.
    tokenizer = None
    if dataset_name == "obelisc":
        use_folder = True
    else:
        use_folder = False
    Parallel(n_jobs=args.num_process)(
        delayed(main)(tar_file, tokenizer, dataset_name, use_folder) for tar_file in tar_list
    )
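The string that main() tokenizes for each document can be isolated into a small function, which makes the placement of the image placeholders easy to check. The sketch below mirrors the script's logic; `build_interleaved_text` and its `n_image_patches`/`n_dream_patches` parameters are hypothetical names, with defaults set to the 400 and 64 patch counts hard-coded in the script.

```python
import re

# Token strings copied from the script above.
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_DREAM_TOKEN = "<dream>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
DEFAULT_DREAM_START_TOKEN = "<dream_start>"
DEFAULT_DREAM_END_TOKEN = "<dream_end>"

URL_PATTERN = r'\[([^\]]+)\]\((http[s]?://[^\)]+)\)'


def build_interleaved_text(text_list, image_info, n_image_patches=400, n_dream_patches=64):
    """Rebuild the interleaved string for one MMC4-style document (hypothetical helper)."""
    # Keep only the visible text of markdown links.
    texts = [re.sub(URL_PATTERN, r'\1', t) for t in text_list]
    # Shift by one slot so that a placeholder appended to slot i lands
    # just before texts[i]; a trailing empty string is dropped.
    slots = [""] + (texts[:-1] if texts[-1] == "" else texts)
    for info in image_info:
        i = info["matched_text_index"]
        slots[i] += DEFAULT_DREAM_TOKEN + DEFAULT_IMAGE_TOKEN
    text = " ".join(slots).strip()
    image_span = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * n_image_patches + DEFAULT_IM_END_TOKEN
    dream_span = DEFAULT_DREAM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * n_dream_patches + DEFAULT_DREAM_END_TOKEN
    text = text.replace(DEFAULT_IMAGE_TOKEN, image_span).replace(DEFAULT_DREAM_TOKEN, dream_span)
    return DEFAULT_BOS_TOKEN + text + DEFAULT_EOS_TOKEN
```

For example, with one patch per slot, an image matched to text index 1 of `["A cat.", "A dog."]` ends up between the two sentences: `<s>A cat.<dream_start><im_patch><dream_end><im_start><im_patch><im_end> A dog.</s>`. With the default counts, every image contributes 400 + 64 = 464 `<im_patch>` tokens, which is why the token budget fills up quickly for image-heavy documents.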

Hi Runpei,

It looks good. However, it seems that you have set the tokenizer to None, which causes the script not to save any tar files at all.

Could you explain a bit, or (ideally) show the code for how you load the tokenizer? Alternatively, could you show how to write the tar files without a tokenizer?

Best,
Weijian.

Hi @pkulwj1994,

That is because it is not necessary. As you can see from the dataloader, we provide two versions of data processing: one for pre-tokenized data (if you are running with a fixed tokenizer that you will not change), and another in which the raw texts are fed in and tokenized online. So whether to load a tokenizer here depends on whether you are certain which tokenizer you want to use. Setting it to None is the way to write tar files without pre-tokenized tokens.
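The two modes differ mainly in how documents are grouped before writing. The sketch below mirrors the grouping rule of the script's tokenizer branch (greedy concatenation under a token budget, with oversized documents dropped) and, as an assumption, treats the non-tokenizer mode as writing each document as its own sample, since merging needs token counts. `pack_documents` is a hypothetical helper, and `str.split` stands in for a real tokenizer in the example.

```python
from typing import Callable, List, Optional


def pack_documents(texts: List[str], tokenize: Optional[Callable[[str], list]], max_tokens: int = 6000) -> List[List[int]]:
    """Return groups of document indices, one group per output sample.

    tokenize=None: every document becomes its own sample (assumed behavior).
    Otherwise: documents are concatenated greedily while the running token
    count stays within max_tokens; a document with max_tokens or more tokens
    is dropped, as in the script.
    """
    if tokenize is None:
        return [[i] for i in range(len(texts))]

    groups, current, current_len = [], [], 0
    for i, text in enumerate(texts):
        n = len(tokenize(text))
        # Flush the current group if adding this document would exceed the budget.
        if current and current_len + n > max_tokens:
            groups.append(current)
            current, current_len = [], 0
        if n < max_tokens:
            current.append(i)
            current_len += n
    if current:
        groups.append(current)
    return groups
```

For instance, with `str.split` and `max_tokens=5`, the documents `["a b c", "d e", "f g h i", "j"]` are packed into `[[0, 1], [2, 3]]`, whereas `tokenize=None` yields one group per document.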

But in the code you provided, setting the tokenizer to None seems to write nothing into the tar files.
Am I misunderstanding your code?

Best.

Oh yes, sorry for the confusion. That is a mistake I made when adapting the code to the non-tokenizer setting. You can ignore and remove all token-related information while still writing everything else.
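Following that advice, a non-tokenizer writer only needs the JSON record (minus any token field) plus the images. As a dependency-free illustration of the resulting tar layout, the sketch below writes the same member names that `save_to_tar` produces, using Python's standard `tarfile` module instead of webdataset. `write_sample` is a hypothetical helper, images are passed as already-encoded JPEG bytes rather than PIL images, and the `<key>.<ext>` naming is assumed to match what webdataset's `TarWriter` emits.

```python
import io
import json
import tarfile


def write_sample(tar: tarfile.TarFile, key: str, record: dict, images: list) -> None:
    """Write one record and its JPEG bytes using webdataset-style member names.

    `images` holds raw JPEG bytes in the order of record["image_info"].
    Any pre-computed "token" field is dropped (non-tokenizer setting).
    """
    record = {k: v for k, v in record.items() if k != "token"}
    # Rename each image to the per-image key used inside the tar.
    record["image_info"] = [dict(info, image_name=f"{key}-{idx}") for idx, info in enumerate(record["image_info"])]

    def add(name: str, data: bytes) -> None:
        member = tarfile.TarInfo(name)
        member.size = len(data)
        tar.addfile(member, io.BytesIO(data))

    add(f"{key}.json", json.dumps(record).encode("utf-8"))
    for idx, jpeg_bytes in enumerate(images):
        add(f"{key}-{idx}.jpg", jpeg_bytes)
```

Reading such a tar back with webdataset should yield the JSON and each image as separate samples keyed `doc0` and `doc0-0`, since webdataset groups members by the part of the name before the first dot.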

By the way, you can send me an email so that we can talk more through other channels such as WeChat, if you are interested.

Sounds good.

Really appreciate your work.

Best,
Weijian.

You are welcome. Always happy to help.