KeyError in update_fn_kernel when use_triton=True
AnCoSONG opened this issue.
In my setup, the problem occurs as soon as I define the optimizer and start training:
opt = Lion(
    params,
    lr=lr,
    betas=(0.95, 0.98),
    use_triton=True
)
...
...
...
trainer.fit(model, datamodule)
There are two exceptions. The first is a KeyError in update_fn_kernel:
KeyError: ('2-.-0-.-0-d82511111ad128294e9d31a6ac684238-d6252949da17ceb5f3a278a70250af13-1af5134066c618146d2cd009138944a0-2d732a2488b7ed996facc3e641ee56bf-3498c340fd4b6ee7805fd54b882a04f5-e1f133f98d04093da2078dfc51c36b72-b26258bf01f839199e39d64851821f26-d7c06e3b46e708006c15224aac7a1378-f585402118c8a136948ce0a49cfe122c', (torch.float32, torch.float32, torch.float32, dtype('float64'), 'fp32', 'fp32', 'fp32', 'i32'), (128,), (True, True, True, (False,), (False,), (False,), (False,), (True, False)))
While handling the exception above, a second exception occurs in triton/runtime/jit.py:
path/python3.9/site-packages/triton/runtime/jit.py:190 in _type_of

  187             return f'*{ty}'
  188         if key is None:
  189             return '*i8'
❱ 190         assert isinstance(key, str)
  191         return key
  192
  193     def _make_signature(self, sig_key):
The key in the assertion above is float64, so isinstance(key, str) fails.
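To illustrate why the assertion fails, here is a simplified stand-in for the check at line 190. It is not Triton's actual code: the function name type_of_sketch and its body are hypothetical, reduced to the two branches visible in the traceback. A NumPy dtype object such as dtype('float64') is not a str, so the assertion raises, while a string key like 'fp32' passes through.

```python
import numpy as np


def type_of_sketch(key):
    """Hypothetical, simplified mimic of the branches shown in the traceback (not Triton's code)."""
    if key is None:
        return "*i8"
    # Mirrors `assert isinstance(key, str)` at jit.py:190.
    assert isinstance(key, str), f"non-string key: {key!r}"
    return key


# A string key passes through unchanged.
print(type_of_sketch("fp32"))

# The dtype('float64') entry from the KeyError is a NumPy dtype, not a str.
try:
    type_of_sketch(np.dtype("float64"))
except AssertionError as exc:
    print("AssertionError:", exc)
```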
Any pointers would be appreciated. Thanks!
@AnCoSONG oh hey Justin, which version of Triton are you on? What type of GPU do you have?
Here is my Triton version. I am using an RTX 3090.
pip list | grep triton
triton 2.0.0
@AnCoSONG can I see how you defined the lr being passed into Lion?
weight = 1.0
accumulate_grad_batches = 1
ngpu = 1
bs = 1
base_lr = 4.5e-6
lr = weight * accumulate_grad_batches * ngpu * bs * base_lr
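The arithmetic above yields 4.5e-6 as a plain Python float. One hypothesis, not confirmed anywhere in this thread, is that the value reaching the kernel is a NumPy float64 scalar instead. Assuming the fourth entry of the KeyError signature, dtype('float64'), corresponds to lr, a NumPy scalar would fit: it still passes isinstance(x, float), but unlike a Python float it carries a .dtype attribute. The sketch below only checks this with NumPy; it does not call Triton or Lion.

```python
import numpy as np

# Values from the thread.
weight = 1.0
accumulate_grad_batches = 1
ngpu = 1
bs = 1
base_lr = 4.5e-6
lr = weight * accumulate_grad_batches * ngpu * bs * base_lr
print(type(lr).__name__, lr)  # float 4.5e-06

# Hypothetical scenario: the same value arrives as a NumPy scalar.
np_lr = np.float64(lr)
# It is still a float subclass, but it also exposes a .dtype attribute.
print(isinstance(np_lr, float), hasattr(np_lr, "dtype"))  # True True

# float() returns a built-in float with no .dtype attribute.
clean_lr = float(np_lr)
print(type(clean_lr).__name__, hasattr(clean_lr, "dtype"))  # float False
```

If the hypothesis holds, passing lr=float(lr) into Lion(...) is one way to test it; this is a diagnostic suggestion, not a confirmed fix.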
Do you have a simple script for reproducing the error?
I wrote a simple training script, but the problem does not occur there, so I am still trying to find the real cause. I will post my findings in this issue later.
heisenbug