Lightning-Universe / lightning-flash

Your PyTorch AI Factory - Flash enables you to easily configure and run complex AI recipes for over 15 tasks across 7 data domains

Home Page: https://lightning-flash.readthedocs.io

Too much RAM usage by ImageClassificationData

ethanwharris opened this issue.

Discussed in #1442

Originally posted by Hravan September 1, 2022
I'm setting up training for this Kaggle competition dataset: https://www.kaggle.com/competitions/plant-pathology-2021-fgvc8
(I'm only using samples with a single label here to keep the problem simple.)

The problem is that ImageClassificationData uses too much RAM and the GPU is underutilized. For comparison, I wrote the same pipeline in plain PyTorch, which confirms that the problem lies somewhere within ImageClassificationData.

Code shared by both training versions:

import os

import pandas as pd
from skimage import io
from sklearn.preprocessing import OneHotEncoder
import torch
from torch.utils.data import Dataset
from torchvision import transforms as T


class PlantDataset(Dataset):
    def __init__(self, df, transform=None) -> None:
        super().__init__()
        self.img_paths = df["image"].tolist()
        self.transform = transform
        # One-hot encode the string labels into a dense (n_samples, n_classes) array.
        self.encoder = OneHotEncoder()
        self.labels = self.encoder.fit_transform(
            df["label"].values.reshape(-1, 1)
        ).toarray()

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        # Decode the image from disk as an HxWxC uint8 array.
        img = io.imread(self.img_paths[idx])
        if self.transform is not None:
            img = self.transform(img)
        label = self.labels[idx]
        # return {
        #    "input": img,
        #    "target": torch.tensor(label, dtype=torch.uint8),
        # }
        return img, torch.tensor(label, dtype=torch.float32)


def preprocess_df(csv_path, images_root):
    df = pd.read_csv(csv_path)
    # Keep only single-label samples (multi-label rows contain a space).
    # .copy() avoids pandas' SettingWithCopyWarning on the assignment below.
    df = df[~df["labels"].str.contains(" ")].copy()
    # os.path.join works whether or not images_root ends with a slash.
    df["image"] = df["image"].map(lambda name: os.path.join(images_root, name))
    df = df.rename(columns={"labels": "label"})
    return df


def split_df(df, train_pct):
    # Shuffle, then take the first train_pct fraction as the training split.
    df = df.sample(frac=1)
    n_train = int(train_pct * len(df))
    # drop=True avoids adding the old index as an extra "index" column.
    train_df = df.iloc[:n_train].reset_index(drop=True)
    val_df = df.iloc[n_train:].reset_index(drop=True)
    return train_df, val_df


def create_dataloader(df):
    train_compose = T.Compose(
        [
            T.ToPILImage(),
            T.Resize((224, 224)),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ]
    )
    dataloader = torch.utils.data.DataLoader(
        PlantDataset(df, transform=train_compose),
        batch_size=32,
        num_workers=8,
        prefetch_factor=8,
    )
    return dataloader
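For a rough sense of how batch_size, num_workers and prefetch_factor drive host memory, the sketch below estimates the bytes held by already-collated batches waiting in the DataLoader queue. It assumes float32 3x224x224 image tensors (what ToTensor and Normalize produce after the 224x224 resize) and that each worker keeps up to prefetch_factor batches loaded ahead. For the Flash settings it assumes the DataLoader default prefetch_factor=2, since the Flash configuration in this thread sets num_workers=8 but no prefetch_factor. The function name is illustrative, and the result is a lower bound, not a measurement.

```python
def prefetched_batch_bytes(batch_size, num_workers, prefetch_factor=2,
                           channels=3, height=224, width=224, bytes_per_elem=4):
    """Rough lower bound on bytes held by batches prefetched by DataLoader workers.

    Counts only the collated image tensors (float32 by default) and ignores
    labels, the decoded full-resolution images, the model and pinned buffers.
    """
    per_sample = channels * height * width * bytes_per_elem  # 602,112 bytes at defaults
    per_batch = batch_size * per_sample
    # Each worker keeps up to prefetch_factor batches loaded ahead of the training loop.
    batches_in_flight = num_workers * prefetch_factor
    return batches_in_flight * per_batch


GIB = 1024 ** 3
# Plain PyTorch loader above: batch_size=32, num_workers=8, prefetch_factor=8.
print(f"plain PyTorch, 32/8/8: {prefetched_batch_bytes(32, 8, 8) / GIB:.2f} GiB")
# Flash settings from the thread, assuming DataLoader's default prefetch_factor=2.
print(f"Flash, 32/8/2: {prefetched_batch_bytes(32, 8) / GIB:.2f} GiB")
print(f"Flash, 64/16/2: {prefetched_batch_bytes(64, 16) / GIB:.2f} GiB")
```

Under these assumptions the plain loader (1.15 GiB) already queues more image data than the Flash settings at batch_size=32 (0.29 GiB), and Flash at batch_size=64 with 16 workers only matches it (1.15 GiB). So if Flash runs out of RAM where the plain version does not, the prefetch queue alone probably does not explain it.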

Training in plain PyTorch:

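import argparse
from time import perf_counter

import torch
import torchvision
import tqdm

# preprocess_df, split_df and create_dataloader come from the shared code above.


def train(model, data_loader, n_epochs):
    model = model.cuda()
    optimizer = torch.optim.Adam(model.parameters())
    # The one-hot float labels are treated as class probabilities (PyTorch >= 1.10).
    loss_fn = torch.nn.CrossEntropyLoss()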

    for i in range(n_epochs):
        for images, labels in tqdm.tqdm(data_loader):
            images = images.cuda()
            preds = model(images)
            loss = loss_fn(preds, labels.cuda())
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        print(f"End of epoch {i}")


def main():
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("csv_path")
    arg_parser.add_argument("images_root")
    args = arg_parser.parse_args()

    model = torchvision.models.resnet18()
    model.fc = torch.nn.Linear(512, 6)

    df = preprocess_df(args.csv_path, args.images_root)
    train_df, val_df = split_df(df, 0.1)
    train_loader = create_dataloader(train_df)
    time0 = perf_counter()
    train(model, train_loader, 2)
    print(f"Time elapsed: {perf_counter() - time0}")


if __name__ == "__main__":
    main()

Training in Lightning Flash:

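import argparse
from time import perf_counter

import flash
import pytorch_lightning as pl
import torch
import torchvision
from flash.image import ImageClassificationData

# preprocess_df and split_df come from the shared code above.


class Resnet18(pl.LightningModule):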
    def __init__(self):
        super().__init__()
        self.model = torchvision.models.resnet18()
        self.model.fc = torch.nn.Linear(512, 6)
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def training_step(self, batch, batch_idx):
        x, y = batch["input"], batch["target"]
        y_hat = self.model(x)
        loss = self.loss_fn(y_hat, y)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters())


def main():
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("csv_path")
    arg_parser.add_argument("images_root")
    args = arg_parser.parse_args()

    model = Resnet18()
    df = preprocess_df(args.csv_path, args.images_root)
    train_df, val_df = split_df(df, 0.1)
    datamodule = ImageClassificationData.from_data_frame(
        "image",
        "label",
        train_data_frame=train_df,
        batch_size=32,
        transform_kwargs=dict(image_size=(224, 224)),
        num_workers=8,
        persistent_workers=True,
        pin_memory=False,
    )

    time0 = perf_counter()
    trainer = flash.Trainer(max_epochs=2, gpus=torch.cuda.device_count())
    trainer.fit(model, datamodule=datamodule)
    print(f"Time elapsed: {perf_counter() - time0}")


if __name__ == "__main__":
    main()
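The two scripts give CrossEntropyLoss different target formats. The plain PyTorch version passes one-hot float vectors (treated as class probabilities, which PyTorch supports since 1.10). The Flash training_step receives batch["target"], which for single-label data is presumably an integer class index. The pure-Python sketch below (no torch needed; the function names are illustrative) checks that both formats give the same mean loss, so the two runs should be optimizing the same objective.

```python
import math


def log_softmax(logits):
    # Subtract the max for numerical stability before exponentiating.
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum for x in logits]


def ce_from_index(batch_logits, targets):
    # Mean of -log p[target] over the batch (class-index targets).
    return -sum(log_softmax(row)[t] for row, t in zip(batch_logits, targets)) / len(targets)


def ce_from_probs(batch_logits, target_probs):
    # Mean of -sum_i q_i * log p_i over the batch (probability / one-hot targets).
    total = 0.0
    for row, probs in zip(batch_logits, target_probs):
        total -= sum(q * lp for q, lp in zip(probs, log_softmax(row)))
    return total / len(target_probs)


logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]]
indices = [0, 2]
one_hot = [[1.0 if c == t else 0.0 for c in range(3)] for t in indices]
print(ce_from_index(logits, indices), ce_from_probs(logits, one_hot))
```

With a one-hot target, every term except the true class is multiplied by zero, which is why the two functions agree.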

When I increase batch_size to 64 or num_workers to 16 in ImageClassificationData, I start running out of RAM, which does not happen with the plain PyTorch version. Any ideas what the problem might be? I tried profiling but couldn't reach a clear conclusion, other than my guess that the problem is in the BaseDataFetcher in the DataModule.

@ethanwharris, I can take a look if this is still open. It's interesting that there is such a bottleneck. Could you give me a bit more detail?

Maybe we can test this on a smaller dataset like CIFAR and see if the same thing happens.

Hey @Atharva-Phatak, thanks for the offer! Please feel free to take a look 😃 I think a great starting point would be to train a model in Flash (e.g. on CIFAR-10, as you suggested) alongside the equivalent model using only Lightning, and check whether the maximum batch size you can reach differs between them. If it does, that would confirm we have a leak.
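The comparison suggested above boils down to finding, for each setup, the largest batch size that trains without running out of memory. Below is a minimal sketch of that search. It assumes a hypothetical fits(batch_size) callback that builds the Flash or Lightning-only pipeline with that batch size, runs a few training steps, and returns False if memory runs out. It also assumes memory use grows monotonically with batch size. The search doubles the batch size until a failure, then binary-searches between the last success and the first failure.

```python
from typing import Callable


def find_max_batch_size(fits: Callable[[int], bool], start: int = 1, limit: int = 4096) -> int:
    """Return the largest batch size in [start, limit] for which fits(bs) is True.

    Assumes fits is monotone: if a batch size fits, every smaller one fits too.
    Returns 0 if even `start` does not fit.
    """
    if not fits(start):
        return 0
    low = high = start  # low: largest size known to fit
    # Phase 1: double until a size fails or the limit is reached.
    while high < limit:
        candidate = min(high * 2, limit)
        if fits(candidate):
            low = high = candidate
        else:
            high = candidate  # smallest size known to fail
            break
    else:
        return low  # reached the limit without a failure
    # Phase 2: binary search in (low, high), where low fits and high fails.
    while high - low > 1:
        mid = (low + high) // 2
        if fits(mid):
            low = mid
        else:
            high = mid
    return low


# Toy stand-in for a real trial: pretend anything up to 37 fits in memory.
print(find_max_batch_size(lambda bs: bs <= 37))
```

A real fits needs some care. A host-RAM OOM may surface as a RuntimeError when a DataLoader worker is killed, or the kernel may terminate the whole process. Running each trial in a subprocess is therefore safer.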

@ethanwharris Sorry, I was busy with college and working on a PR for bolts. I will look at it this week and let's see where we can go from here :)

@Atharva-Phatak that would be great if you can still have a look at it... 🐰