lucidrains / lion-pytorch

🦁 Lion, new optimizer discovered by Google Brain using genetic algorithms that is purportedly better than Adam(w), in Pytorch

KeyError in update_fn_kernel when use_triton=True

AnCoSONG opened this issue · comments

In my setup, simply defining the optimizer as below and starting training is enough to trigger the problem.

opt = Lion(
    params,
    lr=lr,
    betas=(0.95, 0.98),
    use_triton=True
)
...
...
...
trainer.fit(model, datamodule)

There are two exceptions. The first is a KeyError in update_fn_kernel:

KeyError: 
('2-.-0-.-0-d82511111ad128294e9d31a6ac684238-d6252949da17ceb5f3a278a70250af13-1af5134066c618146d2cd009138944a0-2d732a2488b7ed996facc3e641ee56bf-3498c340fd4b6ee7805fd54b882a04f5-e1f133f98d04093da2078dfc51c36b72-b26258bf01f839199e39d64851821f26-d7c06e3b46e708006c15224aac7a1378-f585402118c8a136948ce0a49cfe122c', (torch.float32, torch.float32, torch.float32, dtype('float64'), 'fp32', 'fp32', 'fp32', 'i32'), (128,), (True, True, True, (False,), (False,), (False,), (False,), (True, False)))

While handling the exception above, a second exception is raised in triton/runtime/jit.py:

path/python3.9/site-packages/triton/runtime/jit.py:190 in _type_of

  187             return f'*{ty}'
  188         if key is None:
  189             return '*i8'
❱ 190         assert isinstance(key, str)
  191         return key
  192
  193     def _make_signature(self, sig_key):

The key in the frame above is float64, which matches the dtype('float64') entry in the KeyError signature.
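The fourth entry of the KeyError signature prints as dtype('float64'), which is how a NumPy dtype is displayed, so one possible reading is that a scalar argument (perhaps lr) reaches the kernel as a NumPy scalar rather than a built-in Python float. That is an assumption, not something confirmed in this thread. Under that assumption, a minimal sketch of a workaround is to convert the value before building the optimizer. The helper name as_python_float is hypothetical and is not part of lion-pytorch or triton:

```python
def as_python_float(value):
    """Convert a scalar-like value to a built-in Python float.

    NumPy scalars and single-element torch tensors both expose .item(),
    which returns a plain Python number. Built-in ints and floats have
    no .item() and go straight through float().
    """
    item = getattr(value, "item", None)
    if callable(item):
        value = item()
    return float(value)


# Hypothetical usage with the optimizer call from this issue:
# opt = Lion(params, lr=as_python_float(lr), betas=(0.95, 0.98), use_triton=True)
```

If the error persists after this conversion, the float64 value is coming from somewhere other than the argument that was cast.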

Please give me some instruction... Thanks!

@AnCoSONG oh hey Justin, which version of triton are you on? what type of GPU do you have?

Here is my triton version (I am using a 3090):

pip list | grep triton
triton                   2.0.0

@AnCoSONG can i see how you defined your lr being passed into Lion?

weight=1.0
accumulate_grad_batches=1
ngpu=1
bs=1
base_lr=4.5e-6

lr=weight*accumulate_grad_batches * ngpu * bs * base_lr
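As a quick check, the snippet below repeats the computation above with the literal values and inspects the resulting type. With these built-in numbers the result is a plain Python float with no dtype attribute. In the real project the factors may come from a config object or another library, which could change the type; that is an assumption this sketch does not test.

```python
weight = 1.0
accumulate_grad_batches = 1
ngpu = 1
bs = 1
base_lr = 4.5e-6

lr = weight * accumulate_grad_batches * ngpu * bs * base_lr

# A built-in float has no .dtype attribute; NumPy scalars and tensors do.
print(type(lr).__name__, getattr(lr, "dtype", None))  # float None
```

Running the same print statement on the actual lr just before the Lion(...) call would show whether a non-builtin scalar type is being passed.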

do you have a simple script for reproducing the error?

I wrote a simple training script, but the problem does not occur there, so I am still trying to find the real cause. I will update this issue with what I find.

heisenbug