autocast to float16/bfloat16 fails on transformer encoder
AmitMY opened this issue
Bug description
bf16-mixed precision in the Trainer yields an error.
What version are you seeing the problem on?
v2.3
How to reproduce the bug
My model includes this encoder:
self.encoder = nn.Sequential(
    nn.Flatten(start_dim=2),
    nn.Dropout(0.15),
    nn.Linear(math.prod(pose_dims), hidden_dim, bias=False),
    PositionalEncoding(d_model=hidden_dim),
    nn.TransformerEncoder(
        nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=nhead,
                                   dim_feedforward=dim_feedforward,
                                   batch_first=True),
        num_layers=num_layers
    )
)
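(PositionalEncoding is a custom module whose code is not shown in this issue. For context only, a typical sinusoidal implementation, assumed here rather than taken from the actual project, looks like the sketch below; note the float32 pe buffer it adds to its input.)

import math
import torch
from torch import nn

class PositionalEncoding(nn.Module):
    # Hypothetical, typical sinusoidal positional encoding (batch_first).
    # The pe buffer is float32, which matters under autocast: adding it to a
    # bfloat16 input promotes the sum back to float32.
    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.pe[:, : x.size(1)]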
Then run the Trainer with precision="bf16-mixed".
(Note: "bf16-true" runs without error, but yields a very poor learning curve.)
Error messages and logs
Traceback (most recent call last):
File "<frozen runpy>", line 198, in _run_module_as_main
File "<frozen runpy>", line 88, in _run_code
File "/Users/amitmoryossef/dev/sign-language-processing/vq/sign_vq/train.py", line 147, in <module>
main()
File "/Users/amitmoryossef/dev/sign-language-processing/vq/sign_vq/train.py", line 143, in main
trainer.fit(model, train_dataloaders=train_dataset, val_dataloaders=validation_dataset)
File "/opt/homebrew/anaconda3/envs/vq/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 543, in fit
call._call_and_handle_interrupt(
File "/opt/homebrew/anaconda3/envs/vq/lib/python3.11/site-packages/pytorch_lightning/trainer/call.py", line 44, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/anaconda3/envs/vq/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 579, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/opt/homebrew/anaconda3/envs/vq/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 986, in _run
results = self._run_stage()
^^^^^^^^^^^^^^^^^
File "/opt/homebrew/anaconda3/envs/vq/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 1028, in _run_stage
self._run_sanity_check()
File "/opt/homebrew/anaconda3/envs/vq/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 1057, in _run_sanity_check
val_loop.run()
File "/opt/homebrew/anaconda3/envs/vq/lib/python3.11/site-packages/pytorch_lightning/loops/utilities.py", line 182, in _decorator
return loop_run(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/anaconda3/envs/vq/lib/python3.11/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 135, in run
self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
File "/opt/homebrew/anaconda3/envs/vq/lib/python3.11/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 396, in _evaluation_step
output = call._call_strategy_hook(trainer, hook_name, *step_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/anaconda3/envs/vq/lib/python3.11/site-packages/pytorch_lightning/trainer/call.py", line 311, in _call_strategy_hook
output = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/anaconda3/envs/vq/lib/python3.11/site-packages/pytorch_lightning/strategies/strategy.py", line 411, in validation_step
return self.lightning_module.validation_step(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/amitmoryossef/dev/sign-language-processing/vq/sign_vq/model.py", line 234, in validation_step
loss, prediction = self.step(batch)
^^^^^^^^^^^^^^^^
File "/Users/amitmoryossef/dev/sign-language-processing/vq/sign_vq/model.py", line 215, in step
x_hat, indices = self(x)
^^^^^^^
File "/opt/homebrew/anaconda3/envs/vq/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/anaconda3/envs/vq/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/amitmoryossef/dev/sign-language-processing/vq/sign_vq/model.py", line 170, in forward
return self.model(batch)
^^^^^^^^^^^^^^^^^
File "/opt/homebrew/anaconda3/envs/vq/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/anaconda3/envs/vq/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/amitmoryossef/dev/sign-language-processing/vq/sign_vq/model.py", line 129, in forward
x = self.encoder(x)
^^^^^^^^^^^^^^^
File "/opt/homebrew/anaconda3/envs/vq/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/anaconda3/envs/vq/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/anaconda3/envs/vq/lib/python3.11/site-packages/torch/nn/modules/container.py", line 217, in forward
input = module(input)
^^^^^^^^^^^^^
File "/opt/homebrew/anaconda3/envs/vq/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/anaconda3/envs/vq/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/anaconda3/envs/vq/lib/python3.11/site-packages/torch/nn/modules/transformer.py", line 391, in forward
output = mod(output, src_mask=mask, is_causal=is_causal, src_key_padding_mask=src_key_padding_mask_for_layers)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/anaconda3/envs/vq/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/anaconda3/envs/vq/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/anaconda3/envs/vq/lib/python3.11/site-packages/torch/nn/modules/transformer.py", line 685, in forward
return torch._transformer_encoder_layer_fwd(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: mat1 and mat2 must have the same dtype, but got BFloat16 and Float
Environment
- CUDA:
- GPU: None
- available: False
- version: None
- Lightning:
- lightning-utilities: 0.11.2
- pytorch-lightning: 2.3.0
- torch: 2.2.2
- torchmetrics: 1.4.0.post0
- vector-quantize-pytorch: 1.14.24
- Packages:
- aiohttp: 3.9.5
- aiosignal: 1.3.1
- astroid: 3.2.2
- attrs: 23.2.0
- certifi: 2024.6.2
- charset-normalizer: 3.3.2
- click: 8.1.7
- datasets: 2.20.0
- decorator: 4.4.2
- dill: 0.3.8
- docker-pycreds: 0.4.0
- einops: 0.8.0
- einx: 0.3.0
- filelock: 3.15.1
- frozendict: 2.4.4
- frozenlist: 1.4.1
- fsspec: 2024.5.0
- gitdb: 4.0.11
- gitpython: 3.1.43
- huggingface-hub: 0.23.3
- idna: 3.7
- imageio: 2.34.1
- imageio-ffmpeg: 0.5.1
- iniconfig: 2.0.0
- isort: 5.13.2
- jinja2: 3.1.4
- lightning-utilities: 0.11.2
- markupsafe: 2.1.5
- mccabe: 0.7.0
- moviepy: 1.0.3
- mpmath: 1.3.0
- multidict: 6.0.5
- multiprocess: 0.70.16
- networkx: 3.3
- numpy: 1.26.4
- opencv-python: 4.10.0.82
- packaging: 24.1
- pandas: 2.2.2
- pillow: 10.3.0
- pip: 24.0
- platformdirs: 4.2.2
- pluggy: 1.5.0
- pose-format: 0.4.1
- proglog: 0.1.10
- protobuf: 5.27.1
- psutil: 5.9.8
- pyarrow: 16.1.0
- pyarrow-hotfix: 0.6
- pylint: 3.2.3
- pytest: 8.2.2
- python-dateutil: 2.9.0.post0
- pytorch-lightning: 2.3.0
- pytz: 2024.1
- pyyaml: 6.0.1
- requests: 2.32.3
- scipy: 1.13.1
- sentry-sdk: 2.5.1
- setproctitle: 1.3.3
- setuptools: 69.5.1
- sign-vq: 0.0.1
- six: 1.16.0
- smmap: 5.0.1
- sympy: 1.12.1
- tomlkit: 0.12.5
- torch: 2.2.2
- torchmetrics: 1.4.0.post0
- tqdm: 4.66.4
- typing-extensions: 4.12.2
- tzdata: 2024.1
- urllib3: 2.2.1
- vector-quantize-pytorch: 1.14.24
- wandb: 0.17.1
- wheel: 0.43.0
- xxhash: 3.4.1
- yarl: 1.9.4
- System:
- OS: Darwin
- architecture: 64bit
- processor: i386
- python: 3.11.9
- release: 23.5.0
- version: Darwin Kernel Version 23.5.0: Wed May 1 20:12:58 PDT 2024; root:xnu-10063.121.3~5/RELEASE_ARM64_T6000
More info
I tried to follow #15006 and feed the batch directly as bf16; that does not change the error.
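(For reference, one reading of "feed the batch directly as bf16", as a hypothetical sketch using the LightningModule on_before_batch_transfer hook and assuming the batch is a plain Tensor:)

import torch
import pytorch_lightning as pl

class BF16InputModule(pl.LightningModule):
    # Hypothetical sketch: cast each incoming batch to bfloat16 before it
    # reaches training/validation steps. Assumes the batch is a single Tensor.
    def on_before_batch_transfer(self, batch, dataloader_idx):
        return batch.to(torch.bfloat16)

Either way, the same RuntimeError is raised inside the transformer layer.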
@AmitMY The Trainer applies the PyTorch autocast context manager over the forward pass and converts the inputs. Take a look at the error traceback, in particular the line
output = mod(output, src_mask=mask, is_causal=is_causal, src_key_padding_mask=src_key_padding_mask_for_layers)
and work out from there which tensors (the output, or the weights of the TransformerEncoder) have mismatched dtypes. It's possible that the input tensor here is the output of the previous layer (e.g. PositionalEncoding) and that the dtype mismatch needs to be fixed there.
If there is reason to believe something is not done right in Lightning, please provide a reproducible example. Thanks!
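(For example, one way to do that dtype audit is to register forward hooks that print dtypes per submodule; a minimal sketch, not part of Lightning's API:)

import torch
from torch import nn

def report_dtypes(module: nn.Module) -> None:
    # Print input/output dtypes for every submodule on subsequent forward
    # passes, to spot where a float32 tensor sneaks back in under autocast.
    for name, child in module.named_modules():
        def hook(mod, inputs, output, name=name):
            in_dtypes = [t.dtype for t in inputs if isinstance(t, torch.Tensor)]
            out_dtype = output.dtype if isinstance(output, torch.Tensor) else type(output)
            print(f"{name or type(mod).__name__}: in={in_dtypes} out={out_dtype}")
        child.register_forward_hook(hook)

Calling report_dtypes(model) before trainer.fit(...) should show exactly where the dtype flips.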
The reason I believe it is a problem with pytorch-lightning is that using normal torch autocasting works fine:
def test_training_step_bfloat16_expected_loss_finite(self):
    batch = MaskedTensor(torch.full((4, 3, *self.pose_dim), fill_value=2, dtype=torch.float))
    model = self.model_setup()
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        loss = model.training_step(batch)
    self.assertNotEqual(0, float(loss))
    self.assertTrue(torch.isfinite(loss))
As for the input to the transformer: both under plain torch autocast and under Lightning, I see:
dtype in PositionalEncoding torch.bfloat16
dtype out PositionalEncoding torch.float32
If I remove that layer, it still crashes with the same error.
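That in/out pattern is consistent with ordinary type promotion, assuming PositionalEncoding adds a float32 buffer to its bfloat16 input (see the sketch earlier in the thread), though, as noted, the layer is not the root cause:

import torch

x = torch.randn(2, 10, 512, dtype=torch.bfloat16)  # e.g. autocast output of nn.Linear
pe = torch.zeros(1, 10, 512)                       # float32 buffer
print((x + pe).dtype)                              # torch.float32: bf16 + f32 promotes to f32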
Minimal repro:
import math

import pytorch_lightning as pl
import torch
from torch import nn, Tensor
from torch.utils.data import DataLoader, IterableDataset


class PoseFSQAutoEncoder(nn.Module):
    # pylint: disable=too-many-arguments
    def __init__(self,
                 pose_dims: tuple = (178, 3),
                 hidden_dim=512,
                 nhead=16,
                 dim_feedforward=2048,
                 num_layers=6):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Flatten(start_dim=2),
            nn.Linear(math.prod(pose_dims), hidden_dim, bias=False),
            nn.TransformerEncoder(
                nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=nhead,
                                           dim_feedforward=dim_feedforward,
                                           batch_first=True),
                num_layers=num_layers
            )
        )

    def forward(self, batch: Tensor):
        return self.encoder(batch)


class AutoEncoderLightningWrapper(pl.LightningModule):
    def __init__(self, model: PoseFSQAutoEncoder,
                 learning_rate: float = 3e-4,
                 warmup_steps: int = 10000):
        super().__init__()
        self.model = model
        self.learning_rate = learning_rate
        self.warmup_steps = warmup_steps

    def forward(self, batch):
        return self.model(batch)

    def configure_optimizers(self):
        # Optimizer taken from https://arxiv.org/pdf/2307.09288.pdf
        return torch.optim.AdamW(self.parameters(),
                                 lr=self.learning_rate,
                                 betas=(0.9, 0.95),
                                 eps=1e-5,
                                 weight_decay=0.1)

    def step(self, x: Tensor):
        x_hat = self(x)
        # fake loss, for repro (returned as a (loss, prediction) pair so the
        # training/validation steps below can unpack it)
        return 0, x_hat

    def training_step(self, batch, *args, **kwargs):
        loss, _ = self.step(batch)
        return loss

    def validation_step(self, batch, batch_idx, *args, **kwargs):
        loss, prediction = self.step(batch)
        return loss


class FakeDataset(IterableDataset):
    def __iter__(self):
        while True:
            yield torch.randn(size=(10, 178, 3))


auto_encoder = PoseFSQAutoEncoder()
model = AutoEncoderLightningWrapper(auto_encoder)

train_dataset = DataLoader(FakeDataset(),
                           batch_size=2,
                           num_workers=0)
validation_dataset = DataLoader(FakeDataset(),
                                batch_size=2,
                                shuffle=False,
                                num_workers=0)

precision = "bf16-mixed"
trainer = pl.Trainer(max_steps=100000,
                     val_check_interval=100_000 // 2,
                     precision=precision)

trainer.fit(model, train_dataloaders=train_dataset, val_dataloaders=validation_dataset)
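For what it's worth, a possible local workaround (an assumption, not a confirmed fix) is to keep the encoder out of autocast so its float32 weights match their inputs:

import torch
from torch import Tensor

class PoseFSQAutoEncoderFP32(PoseFSQAutoEncoder):
    # Hypothetical workaround sketch: run the encoder with autocast disabled so
    # the fused transformer kernel sees matching float32 dtypes throughout.
    # device_type="cpu" matches the environment above; use "cuda" on GPU.
    def forward(self, batch: Tensor):
        with torch.autocast(device_type="cpu", enabled=False):
            return self.encoder(batch.float())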