Autocast "cache_enabled=True" failing
thomassajot opened this issue · comments
Bug description
The autocast argument cache_enabled=True is actually not caching the layer weights when using a Trainer.
What version are you seeing the problem on?
v2.2
How to reproduce the bug
from pathlib import Path

import pytorch_lightning as pl
import torch
from pytorch_lightning.profilers import PyTorchProfiler

TRACE_DIR = Path("~/traces").expanduser()
AUTOCAST_TO = torch.float16
DEVICE = "cuda:1"


class Module(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(1000, 1000, bias=True)
        self.l2 = torch.nn.Linear(1000, 100, bias=True)

    def forward(self, x):
        return self.l2(self.l1(x))

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = torch.nn.functional.mse_loss(y_hat, y)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)


x = torch.randn(2000, 1000, device=DEVICE, dtype=torch.float32)
y = torch.randn(2000, 100, device=DEVICE, dtype=torch.float32)
dl = torch.utils.data.DataLoader(list(zip(x, y)), batch_size=32)

model = Module()

schedule = torch.profiler.schedule(wait=6, warmup=2, active=4, repeat=2)
profiler = PyTorchProfiler(
    schedule=schedule,
    dirpath=str(TRACE_DIR),
    filename="lightning_autocast",
    sort_by_key="cuda_time",
    profile_memory=True,
    with_stack=False,
    with_flops=False,
    with_modules=True,
    row_limit=100,
)

trainer = pl.Trainer(accelerator="cuda", precision=16, devices=[1], profiler=profiler, max_steps=40)
trainer.fit(model, dl)
Error messages and logs
The training script above produces a trace with 3 calls to aten::to before the first linear layer (one each for the input, the weight and the bias). The second linear layer has only 2 calls to aten::to, as its input is already in the right dtype.
What should be expected is one (or zero) calls to aten::to per layer, as the weights should already be cached in the right dtype.
More info
Looking at the code base, autocast is used with its default value of cache_enabled=True. Not sure why the cache would not be used.
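For what it's worth, PyTorch clears the autocast weight cache whenever the outermost autocast context exits, so cached casts can only be reused within a single context. Below is a minimal sketch to check this outside of Lightning. It assumes CPU autocast with bfloat16 (not the cuda/float16 setup above) so that it runs without a GPU, and the helper name count_casts is made up for illustration. It counts aten::to calls for two forward passes inside one context versus two separately entered contexts.

```python
import torch
from torch.profiler import ProfilerActivity, profile


def count_casts(fn):
    """Run fn under the CPU profiler and return the number of aten::to calls."""
    with profile(activities=[ProfilerActivity.CPU]) as prof:
        fn()
    return sum(evt.count for evt in prof.key_averages() if evt.key == "aten::to")


# Assumption: a small CPU layer stands in for the CUDA model in the report.
layer = torch.nn.Linear(64, 64)
x = torch.randn(8, 64)


def one_context():
    # Both forward passes share a single autocast context (and its cache).
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        layer(x)
        layer(x)


def two_contexts():
    # Each forward pass enters and exits its own autocast context,
    # so the cache is cleared between the two passes.
    for _ in range(2):
        with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
            layer(x)


casts_one_context = count_casts(one_context)
casts_two_contexts = count_casts(two_contexts)
print(f"aten::to calls, one context:  {casts_one_context}")
print(f"aten::to calls, two contexts: {casts_two_contexts}")
```

If the cache is working, I would expect the single-context count to be the smaller of the two, since the weight and bias only need to be cast once there. That would match the trace above if the Trainer enters a fresh autocast context for every forward pass.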
Reading further in the PyTorch docs:
autocast should wrap only the forward pass(es) of your network, including the loss computation(s). Backward passes under autocast are not recommended. Backward ops run in the same type that autocast used for corresponding forward ops.
So the current pytorch-lightning implementation is actually correct.
However, it would be great to find a way to disable autocast during the backward pass rather than re-initialising autocast at every forward pass.