huggingface / jat

General multi-task deep RL Agent

Datasets loading slow

qgallouedec opened this issue

The current code is roughly equivalent to:

import datasets
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

dataset_size = 1_000
elmts_per_sequence = 56

d = {
    "images": torch.randint(0, 256, (dataset_size, elmts_per_sequence, 3, 56, 56), dtype=torch.uint8),
    "tokens1": torch.randint(0, 32_000, (dataset_size, elmts_per_sequence, 19)),
    "tokens2": torch.randint(0, 32_000, (dataset_size, elmts_per_sequence, 12)),
}

d = datasets.Dataset.from_dict(d)
d.set_format(type="torch")

dataloader = DataLoader(d, batch_size=32)


for batch in tqdm(dataloader):
    tqdm.write(" ".join(str(value.shape) + " " + str(value.dtype) for value in batch.values()))

It takes around 2.5 seconds to load a single batch.
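To put a number on the per-batch cost without relying on the tqdm display, a small timing helper can be used. This is only a sketch: `mean_seconds_per_batch` and its `max_batches` parameter are hypothetical names, not part of `datasets` or `torch`, and the measurement includes the iterator start-up cost.

```python
import time
from itertools import islice


def mean_seconds_per_batch(loader, max_batches=10):
    """Return the average wall-clock seconds needed to fetch one batch.

    `loader` can be any iterable (e.g. a DataLoader); at most `max_batches`
    batches are pulled from it. Hypothetical helper, for illustration only.
    """
    start = time.perf_counter()
    # Pull batches and count them; the batch contents are discarded.
    n_batches = sum(1 for _ in islice(iter(loader), max_batches))
    elapsed = time.perf_counter() - start
    if n_batches == 0:
        raise ValueError("loader yielded no batches")
    return elapsed / n_batches
```

Calling `mean_seconds_per_batch(dataloader)` on the dataloader built above gives a figure that can be compared directly across loading strategies.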

The py-spy profile looks like:

[image: py-spy profile]

Am I doing something wrong? If not, this seems to be a limitation of datasets, and we should decide whether this loading time is acceptable.
Or should we implement our own load/save method to avoid using the Arrow format?

Note that you get roughly a 1000x speedup if you use only torch (but you then need to handle caching manually):

import torch
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

dataset_size = 1_000
elmts_per_sequence = 56

class MyDataset(Dataset):
    def __init__(self):
        self.d = {
            "images": torch.randint(0, 256, (dataset_size, elmts_per_sequence, 3, 56, 56), dtype=torch.uint8),
            "tokens1": torch.randint(0, 32_000, (dataset_size, elmts_per_sequence, 19)),
            "tokens2": torch.randint(0, 32_000, (dataset_size, elmts_per_sequence, 12)),
        }

    def __getitem__(self, index):
        return {k: v[index] for k, v in self.d.items()}

    def __len__(self):
        return len(self.d["images"])


d = MyDataset()
dataloader = DataLoader(d, batch_size=32)

for batch in tqdm(dataloader):
    tqdm.write(" ".join(str(value.shape) + " " + str(value.dtype) for value in batch.values()))

386.41it/s
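The manual cache handling mentioned above could be sketched as follows, assuming the whole dataset fits in memory and a single `torch.save` file is acceptable. `CachedTensorDataset`, `cache_path`, and `build_fn` are hypothetical names introduced for illustration; this is not an existing `datasets` or `jat` API.

```python
import os

import torch
from torch.utils.data import Dataset


class CachedTensorDataset(Dataset):
    """In-memory dict-of-tensors dataset cached in one torch.save file.

    Hypothetical sketch: `build_fn` is any callable returning a dict that
    maps column names to tensors sharing the same first dimension.
    """

    def __init__(self, cache_path, build_fn):
        if os.path.exists(cache_path):
            # Cache hit: read the tensors back from disk.
            self.d = torch.load(cache_path)
        else:
            # Cache miss: build the tensors once and store them for next time.
            self.d = build_fn()
            torch.save(self.d, cache_path)

    def __getitem__(self, index):
        # Plain tensor indexing, same as MyDataset above.
        return {k: v[index] for k, v in self.d.items()}

    def __len__(self):
        return len(next(iter(self.d.values())))
```

The first construction pays the build cost and writes the file; later constructions with the same `cache_path` only pay for `torch.load`. Cache invalidation (for example when the source data changes) is left out of this sketch and would have to be handled separately.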