Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.

Home Page: https://lightning.ai

Adam optimizer is slower after loading model from checkpoint

radomirgr opened this issue

Bug description

When I resumed training my model from a checkpoint, I noticed low GPU utilization. I found that Adam performs a CUDA synchronization after the optimizer state has been restored from a checkpoint. This becomes a problem if you have many optimizers in your network.

The Adam implementation assumes that the "step" entry of the optimizer state is a CPU tensor. The assumption is made here, and that code is executed by Adam here.
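A minimal sketch, assuming PyTorch 2.x and Adam's default settings (capturable=False, fused not enabled), showing where a freshly created Adam state keeps its step counter: a zero-dimensional float tensor on the CPU, while exp_avg and exp_avg_sq are allocated on the parameter's device. Reading a scalar like step as a Python number (for example with .item()) is cheap for a CPU tensor, but for a CUDA tensor it requires a device-to-host copy and therefore a synchronization.

```python
import torch

# Tiny model and optimizer; everything lives on the CPU so the sketch runs anywhere.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Adam creates its per-parameter state lazily, on the first optimizer step.
model(torch.randn(8, 4)).sum().backward()
optimizer.step()

for param, state in optimizer.state.items():
    step = state["step"]
    # With default (non-capturable, non-fused) settings, "step" is a 0-dim CPU tensor,
    # whereas the moment estimates follow the parameter's device.
    print(step.device, step.dim(), step.item(), state["exp_avg"].device == param.device)
```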

The problem is that Lightning moves the entire optimizer state, including "step", to the GPU here.
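One option for moving optimizer state to the target device without triggering these syncs is to skip the "step" entries. The helper below, move_optimizer_state_except_step, is hypothetical (not a Lightning or PyTorch API) and assumes the state is a dict of dicts keyed by parameter, as it is for torch.optim.Adam. The meta device stands in for a GPU so the sketch runs on CPU-only machines; in a real run the target would be a CUDA device.

```python
import torch


def move_optimizer_state_except_step(optimizer: torch.optim.Optimizer, device) -> None:
    """Hypothetical helper: move tensor state to `device`, leaving "step" where it is."""
    for state in optimizer.state.values():
        # Iterate over a snapshot so values can be reassigned safely.
        for key, value in list(state.items()):
            # Non-capturable Adam reads "step" as a host-side scalar, so keep it on the CPU.
            if key == "step" or not isinstance(value, torch.Tensor):
                continue
            state[key] = value.to(device)


model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
model(torch.randn(8, 4)).sum().backward()
optimizer.step()

# "meta" stands in for a GPU so the example runs without CUDA.
move_optimizer_state_except_step(optimizer, torch.device("meta"))
for state in optimizer.state.values():
    print(state["step"].device, state["exp_avg"].device, state["exp_avg_sq"].device)
```

Skipping "step" unconditionally is an assumption: optimizers created with capturable=True or fused=True expect "step" on the parameter's device, so a real fix would need to check those settings.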

My current workaround is to move the step tensors back to the CPU inside the training step:

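    def training_step(
        self,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ):
        optimizer = self.optimizers()
        # Adam expects "step" to be a CPU tensor; move it back if restoring
        # the checkpoint placed it on the GPU (a no-op once it is on the CPU).
        for state in optimizer.state.values():
            if "step" in state and state["step"].device.type == "cuda":
                state["step"] = state["step"].cpu()
        # ... the rest of the training step follows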
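Before applying the workaround, it can help to check where the restored step tensors actually ended up. A minimal, framework-agnostic sketch; the helper name step_devices is hypothetical and it only assumes the optimizer keeps a dict of per-parameter state dicts, as torch.optim.Adam does:

```python
from collections import Counter

import torch


def step_devices(optimizer: torch.optim.Optimizer) -> Counter:
    """Hypothetical diagnostic: count the device types of all tensor-valued "step" entries."""
    counts = Counter()
    for state in optimizer.state.values():
        step = state.get("step")
        if isinstance(step, torch.Tensor):
            counts[step.device.type] += 1
    return counts


model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
model(torch.randn(8, 4)).sum().backward()
optimizer.step()

print(step_devices(optimizer))
```

Calling it on the optimizer right after resuming and seeing "cuda" in the counts would confirm the behavior described in this issue.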

What version are you seeing the problem on?

v2.2

How to reproduce the bug

import os
from typing import Any, Tuple

import lightning.pytorch as plight
import lightning.pytorch as pl
import torch
import torch.nn as nn
from lightning.pytorch.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader

num_features = 6875
num_responses = 7
batch_size = 32768


class CachedRandomTensorDataset(torch.utils.data.Dataset):
    """Very low overhead torch dataset for training for a given number of steps"""

    def __init__(self, batch_size: int, num_features: int, num_responses: int, length: int) -> None:
        self.x = torch.randn((batch_size, num_features))
        self.y = torch.randn((batch_size, num_responses))
        self.length = length

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.x.clone(), self.y.clone()

    def __len__(self) -> int:
        return self.length


dataset = CachedRandomTensorDataset(
    num_features=num_features,
    num_responses=num_responses,
    length=1013,
    batch_size=batch_size,
)

train_dataloader = DataLoader(dataset, batch_size=None, pin_memory=False, num_workers=0, shuffle=False)


class MLP(nn.Module):

    def __init__(
        self,
        in_dim,
        hidden_dim,
        out_dim,
    ):
        super().__init__()
        self.layers = len(hidden_dim)
        self.LinearClass = nn.Linear
        self.activation_fn = nn.ReLU()
        module_dict = {}
        for i in range(self.layers):
            layer_input_size = in_dim if i == 0 else hidden_dim[i - 1]
            module_dict[f"layer_{i}"] = nn.Linear(layer_input_size, hidden_dim[i])
        module_dict["last_linear"] = nn.Linear(hidden_dim[-1], out_dim)
        self.module_dict = nn.ModuleDict(module_dict)

    def forward(self, x):
        for i in range(self.layers):
            x = self.module_dict[f"layer_{i}"](x)
            x = self.activation_fn(x)
        yhat = self.module_dict["last_linear"](x)
        return yhat


class TestNetwork(pl.LightningModule):
    def __init__(
        self,
        model: nn.Module,
        num_it: int,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.automatic_optimization = False
        self.model = model
        self.mse = nn.MSELoss()
        self.num_it = num_it

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.01)
        return optimizer

    def training_step(
        self,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ):
        print("training_step")
        optimizer = self.optimizers()

        for _ in range(self.num_it):
            torch.cuda.nvtx.range_push("it step")
            x, y = batch
            yhat = self.model.forward(x)
            loss = self.mse(yhat, y)

            optimizer.zero_grad()
            self.manual_backward(loss)
            torch.cuda.nvtx.range_push("optimizer")
            optimizer.step()
            torch.cuda.nvtx.range_pop()

            torch.cuda.nvtx.range_pop()


train_model = TestNetwork(
    MLP(
        num_features,
        [2048, 1024, 512, 256],
        num_responses,
    ),
    200,
)

trainer_max_steps = 200
checkpoint_name = "debug3"
checkpoint_dir = "./model_checkpoint"
ckpt_path = f"{checkpoint_dir}/{checkpoint_name}-step={trainer_max_steps}.ckpt"

if os.path.isfile(ckpt_path):
    print("training from checkpoint")
    trainer_max_steps = trainer_max_steps + 1
else:
    print("training new model")
    ckpt_path = None


checkpoint_callback = ModelCheckpoint(
    dirpath=checkpoint_dir,
    save_top_k=10,
    monitor="step",
    mode="max",
    filename=checkpoint_name + "-{step:02d}",
    every_n_train_steps=100,
)


# TRAINER CREATION
trainer = plight.Trainer(
    accelerator="gpu",
    devices=1,
    num_nodes=1,
    max_steps=trainer_max_steps,
    max_epochs=1,
    log_every_n_steps=50,
    logger=[],
    enable_progress_bar=True,
    enable_checkpointing=True,
    enable_model_summary=True,
    num_sanity_val_steps=0,
    check_val_every_n_epoch=None,
    callbacks=[checkpoint_callback],
)

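# 1 == "warn": emit a warning for every CUDA operation that synchronizes
# with the host, which makes the extra syncs after resuming visible.
torch.cuda.set_sync_debug_mode(1)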

trainer.fit(
    train_model,
    train_dataloader,
    ckpt_path=ckpt_path,
)

Error messages and logs

Below are some nsys (NVIDIA Nsight Systems) traces:
[image: nsys trace]
[image: nsys trace]

Environment

Current environment
  • CUDA:
    • GPU:
      • NVIDIA A100-SXM4-80GB
    • available: True
    • version: 12.1
  • Lightning:
    • gpytorch: 1.11
    • lightning: 2.2.5
    • lightning-utilities: 0.11.2
    • pytorch-lightning: 2.2.5
    • torch: 2.3.1
    • torchinfo: 1.8.0
    • torchmetrics: 1.3.1
    • torchtyping: 0.1.4
    • torchvision: 0.18.0
    • torchviz: 0.0.2
  • System:
    • OS: Linux
    • architecture:
      • 64bit
      • ELF
    • processor: x86_64
    • python: 3.10.12
    • release: 5.15.0-91-generic
    • version: #101-Ubuntu SMP Tue Nov 14 13:30:08 UTC 2023


Hey @radomirgr
Thanks for the investigation.

> Adam implementation is assuming that step component of the state is a cpu tensor. It is assumed here which is executed in adam here

These links may have pointed to an earlier version of the code, but they no longer seem to show the place you meant. Could you point me to where in the PyTorch code this assumption is made?

I don't remember exactly why we needed the optimizer_to_device function.
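For comparison with what the optimizer_to_device function does, here is a minimal sketch of a plain PyTorch round trip. In recent PyTorch 2.x releases (worth verifying against the installed version), Optimizer.load_state_dict casts state tensors to each parameter's device but leaves "step" untouched unless the parameter group is capturable or fused. The meta device stands in for a GPU so the sketch runs without CUDA.

```python
import torch

# Source optimizer on the CPU, with state populated by one step.
src_model = torch.nn.Linear(4, 2)
src_opt = torch.optim.Adam(src_model.parameters(), lr=0.01)
src_model(torch.randn(8, 4)).sum().backward()
src_opt.step()
checkpoint = src_opt.state_dict()

# Destination parameters on "meta" (standing in for a GPU); no step is taken here.
dst_model = torch.nn.Linear(4, 2, device="meta")
dst_opt = torch.optim.Adam(dst_model.parameters(), lr=0.01)
dst_opt.load_state_dict(checkpoint)

for state in dst_opt.state.values():
    # exp_avg / exp_avg_sq follow the parameter; "step" stays on the device it was saved on.
    print(state["step"].device, state["exp_avg"].device)
```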