Cannot run the triton kernels
jmercat opened this issue · comments
Thanks for this repo, I'm pretty excited to test it out.
I did a drop-in replacement of the attention in one of my projects with lightning-attention and got the following:
RuntimeError: PassManager::run failed
Traceback (most recent call last):
File "/opt/ml/code/open_lm/main.py", line 873, in <module>
main(sys.argv[1:])
File "/opt/ml/code/open_lm/main.py", line 774, in main
success, global_step = train_one_epoch(
File "/opt/ml/code/open_lm/train.py", line 267, in train_one_epoch
backward(local_loss, scaler)
File "/opt/ml/code/open_lm/train.py", line 92, in backward
total_loss.backward()
File "/opt/conda/lib/python3.10/site-packages/torch/_tensor.py", line 492, in backward
torch.autograd.backward(
File "/opt/conda/lib/python3.10/site-packages/torch/autograd/__init__.py", line 251, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
File "/opt/conda/lib/python3.10/site-packages/torch/autograd/function.py", line 288, in apply
return user_fn(self, *args)
File "/lightning-attention/lightning_attn/ops/triton/lightning_attn2.py", line 462, in backward
_bwd_intra_kernel[grid](
File "<string>", line 63, in _bwd_intra_kernel
File "/opt/conda/lib/python3.10/site-packages/triton/compiler/compiler.py", line 476, in compile
next_module = compile_kernel(module)
File "/opt/conda/lib/python3.10/site-packages/triton/compiler/compiler.py", line 383, in <lambda>
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps), num_stages, arch))
File "/opt/conda/lib/python3.10/site-packages/triton/compiler/compiler.py", line 91, in optimize_ttgir
pm.run(mod)
So I tried simply running pytest tests/ops/test_lightning2.py
and got only failures (it is odd that there is an assert False
statement in there...).
The more worrisome result is that the errors are quite large:
tests/ops/test_lightning2.py FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF [100%]
tests/ops/test_lightning2.py::test_lightning2[dtype0-6-8-256-128-64] tensor(0.1543, device='cuda:0', dtype=torch.bfloat16,
grad_fn=<LinalgVectorNormBackward0>)
tensor(0.1650, device='cuda:0', dtype=torch.bfloat16)
tensor(0.1641, device='cuda:0', dtype=torch.bfloat16)
tensor(0.1641, device='cuda:0', dtype=torch.bfloat16)
FAILED
tests/ops/test_lightning2.py::test_lightning2[dtype0-6-8-512-128-64] tensor(0.2393, device='cuda:0', dtype=torch.bfloat16,
grad_fn=<LinalgVectorNormBackward0>)
tensor(0.2539, device='cuda:0', dtype=torch.bfloat16)
tensor(0.2520, device='cuda:0', dtype=torch.bfloat16)
tensor(0.2539, device='cuda:0', dtype=torch.bfloat16)
FAILED
tests/ops/test_lightning2.py::test_lightning2[dtype0-6-8-1024-128-64] tensor(0.3555, device='cuda:0', dtype=torch.bfloat16,
grad_fn=<LinalgVectorNormBackward0>)
tensor(0.3770, device='cuda:0', dtype=torch.bfloat16)
tensor(0.3750, device='cuda:0', dtype=torch.bfloat16)
tensor(0.3750, device='cuda:0', dtype=torch.bfloat16)
FAILED
tests/ops/test_lightning2.py::test_lightning2[dtype0-6-8-2048-128-64] tensor(0.5117, device='cuda:0', dtype=torch.bfloat16,
grad_fn=<LinalgVectorNormBackward0>)
tensor(0.5430, device='cuda:0', dtype=torch.bfloat16)
tensor(0.5430, device='cuda:0', dtype=torch.bfloat16)
tensor(0.5430, device='cuda:0', dtype=torch.bfloat16)
FAILED
tests/ops/test_lightning2.py::test_lightning2[dtype0-6-8-4096-128-64] tensor(0.7344, device='cuda:0', dtype=torch.bfloat16,
grad_fn=<LinalgVectorNormBackward0>)
tensor(0.7773, device='cuda:0', dtype=torch.bfloat16)
tensor(0.7734, device='cuda:0', dtype=torch.bfloat16)
tensor(0.7734, device='cuda:0', dtype=torch.bfloat16)
FAILED
tests/ops/test_lightning2.py::test_lightning2[dtype0-6-8-8192-128-64] tensor(1.0391, device='cuda:0', dtype=torch.bfloat16,
grad_fn=<LinalgVectorNormBackward0>)
tensor(1.1016, device='cuda:0', dtype=torch.bfloat16)
tensor(1.1016, device='cuda:0', dtype=torch.bfloat16)
tensor(1.1016, device='cuda:0', dtype=torch.bfloat16)
FAILED
tests/ops/test_lightning2.py::test_lightning2[dtype0-6-8-2048-32-64] tensor(0.2578, device='cuda:0', dtype=torch.bfloat16,
grad_fn=<LinalgVectorNormBackward0>)
tensor(0.2715, device='cuda:0', dtype=torch.bfloat16)
tensor(0.2715, device='cuda:0', dtype=torch.bfloat16)
tensor(0.2715, device='cuda:0', dtype=torch.bfloat16)
FAILED
tests/ops/test_lightning2.py::test_lightning2[dtype0-6-8-2048-64-64] tensor(0.3633, device='cuda:0', dtype=torch.bfloat16,
grad_fn=<LinalgVectorNormBackward0>)
tensor(0.3828, device='cuda:0', dtype=torch.bfloat16)
tensor(0.3848, device='cuda:0', dtype=torch.bfloat16)
tensor(0.3828, device='cuda:0', dtype=torch.bfloat16)
FAILED
tests/ops/test_lightning2.py::test_lightning2[dtype0-6-12-2048-128-64] tensor(0.5234, device='cuda:0', dtype=torch.bfloat16,
grad_fn=<LinalgVectorNormBackward0>)
tensor(0.5547, device='cuda:0', dtype=torch.bfloat16)
tensor(0.5547, device='cuda:0', dtype=torch.bfloat16)
tensor(0.5547, device='cuda:0', dtype=torch.bfloat16)
FAILED
tests/ops/test_lightning2.py::test_lightning2[dtype0-6-16-2048-128-64] tensor(0.6719, device='cuda:0', dtype=torch.bfloat16,
grad_fn=<LinalgVectorNormBackward0>)
tensor(0.7148, device='cuda:0', dtype=torch.bfloat16)
tensor(0.7148, device='cuda:0', dtype=torch.bfloat16)
tensor(0.7109, device='cuda:0', dtype=torch.bfloat16)
FAILED
Hi, thank you for providing the information.
The first issue is most likely related to the Triton version. The locally tested versions that work fine are as follows.
╰─± pip list | grep triton
triton 2.0.0
triton-nightly 2.1.0.dev20230728172942
You can install these versions with the following commands:
pip install triton==2.0.0
pip install triton-nightly==2.1.0.dev20230728172942 --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/
As for the second question, you can ignore it for now. Due to inherent numerical issues with Triton, some error cannot be avoided. However, we have trained models using this kernel and compared them against the baseline (the torch version), and there is almost no difference in loss, so you can use it with confidence.
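For context on why a nonzero relative error is expected at all, here is a minimal sketch (plain NumPy; float16 is used as a stand-in for bfloat16, which NumPy does not support natively — an assumption, not the repo's actual test code) of the kind of norm-based comparison the test output above suggests: even a numerically exact kernel, once its result is rounded to half precision, differs from a float32 reference by a nonzero norm.

```python
import numpy as np

def relative_error(approx: np.ndarray, ref: np.ndarray) -> float:
    """||approx - ref|| / ||ref||, the usual kernel-vs-reference metric."""
    return float(np.linalg.norm(approx - ref) / np.linalg.norm(ref))

rng = np.random.default_rng(0)
ref = rng.standard_normal((2048, 64)).astype(np.float32)

# Simulate a low-precision kernel: round the exact float32 result
# through float16 (stand-in for bfloat16) and back.
approx = ref.astype(np.float16).astype(np.float32)

err = relative_error(approx, ref)
print(f"relative error from precision alone: {err:.2e}")
```

Note that bfloat16 keeps only 8 mantissa bits versus float16's 10, so its precision-only error would be a few times larger still; the magnitudes in the test log also reflect error accumulated over long sequence reductions, not just a single rounding step.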
If you encounter any other issues, feel free to ask at any time.