pytorch-labs / gpt-fast

Simple and efficient pytorch-native transformer text generation in <1000 LOC of python.

RuntimeError: cutlassF: no kernel found to launch!

goodboyyes2009 opened this issue · comments

root@md:/home/projects/gpt-fast# CUDA_VISIBLE_DEVICES=0 python3 generate.py --compile --checkpoint_path /models/huggingface_models/meta-Llama-2-7b-hf/model_int8.pth --max_new_tokens 100
Loading model ...
Using int8 weight-only quantization!
/opt/conda/lib/python3.10/site-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
return self.fget.__get__(instance, owner)()
Time to load model: 2.33 seconds
Traceback (most recent call last):
File "/home/projects/gpt-fast/generate.py", line 407, in
main(
File "/home/projects/gpt-fast/generate.py", line 346, in main
y, metrics = generate(
File "/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/home/projects/gpt-fast/generate.py", line 167, in generate
next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs)
File "/home/projects/gpt-fast/generate.py", line 52, in prefill
logits = model(x, input_pos)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/projects/gpt-fast/model.py", line 118, in forward
x = layer(x, input_pos, freqs_cis, mask)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/projects/gpt-fast/model.py", line 137, in forward
h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/projects/gpt-fast/model.py", line 186, in forward
y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
RuntimeError: cutlassF: no kernel found to launch!

GPU: NVIDIA V100

conda list:

_libgcc_mutex 0.1 main
_openmp_mutex 5.1 1_gnu
asttokens 2.0.5 pyhd3eb1b0_0
astunparse 1.6.3 pypi_0 pypi
attrs 23.1.0 pypi_0 pypi
backcall 0.2.0 pyhd3eb1b0_0
beautifulsoup4 4.12.2 py310h06a4308_0
blas 1.0 mkl
boltons 23.0.0 py310h06a4308_0
brotlipy 0.7.0 py310h7f8727e_1002
bzip2 1.0.8 h7b6447c_0
c-ares 1.19.0 h5eee18b_0
ca-certificates 2023.08.22 h06a4308_0
certifi 2023.7.22 py310h06a4308_0
cffi 1.15.1 py310h5eee18b_3
chardet 4.0.0 py310h06a4308_1003
charset-normalizer 2.0.4 pyhd3eb1b0_0
click 8.0.4 py310h06a4308_0
cmake 3.26.4 h96355d8_0
conda 23.9.0 py310h06a4308_0
conda-build 3.27.0 py310h06a4308_0
conda-content-trust 0.2.0 py310h06a4308_0
conda-index 0.3.0 py310h06a4308_0
conda-libmamba-solver 23.7.0 py310h06a4308_0
conda-package-handling 2.2.0 py310h06a4308_0
conda-package-streaming 0.9.0 py310h06a4308_0
cryptography 41.0.3 py310hdda0065_0
cuda-cudart 11.8.89 0 nvidia
cuda-cupti 11.8.87 0 nvidia
cuda-libraries 11.8.0 0 nvidia
cuda-nvrtc 11.8.89 0 nvidia
cuda-nvtx 11.8.86 0 nvidia
cuda-runtime 11.8.0 0 nvidia
decorator 5.1.1 pyhd3eb1b0_0
dnspython 2.4.2 pypi_0 pypi
exceptiongroup 1.0.4 py310h06a4308_0
executing 0.8.3 pyhd3eb1b0_0
expat 2.5.0 h6a678d5_0
expecttest 0.1.6 pypi_0 pypi
ffmpeg 4.3 hf484d3e_0 pytorch
filelock 3.9.0 py310h06a4308_0
fmt 9.1.0 hdb19cb5_0
freetype 2.12.1 h4a9f257_0
fsspec 2023.9.2 pypi_0 pypi
giflib 5.2.1 h5eee18b_3
gmp 6.2.1 h295c915_3
gmpy2 2.1.2 py310heeb90bb_0
gnutls 3.6.15 he1e5248_0
hypothesis 6.87.1 pypi_0 pypi
icu 58.2 he6710b0_3
idna 3.4 py310h06a4308_0
intel-openmp 2023.1.0 hdb19cb5_46305
ipython 8.15.0 py310h06a4308_0
jedi 0.18.1 py310h06a4308_1
jinja2 3.1.2 py310h06a4308_0
jpeg 9e h5eee18b_1
jsonpatch 1.32 pyhd3eb1b0_0
jsonpointer 2.1 pyhd3eb1b0_0
krb5 1.20.1 h143b758_1
lame 3.100 h7b6447c_0
lcms2 2.12 h3be6417_0
ld_impl_linux-64 2.38 h1181459_1
lerc 3.0 h295c915_0
libarchive 3.6.2 h6ac8c49_2
libcublas 11.11.3.6 0 nvidia
libcufft 10.9.0.58 0 nvidia
libcufile 1.7.2.10 0 nvidia
libcurand 10.3.3.141 0 nvidia
libcurl 8.1.1 h251f7ec_1
libcusolver 11.4.1.48 0 nvidia
libcusparse 11.7.5.86 0 nvidia
libdeflate 1.17 h5eee18b_1
libedit 3.1.20221030 h5eee18b_0
libev 4.33 h7f8727e_1
libffi 3.4.4 h6a678d5_0
libgcc-ng 11.2.0 h1234567_1
libgomp 11.2.0 h1234567_1
libiconv 1.16 h7f8727e_2
libidn2 2.3.4 h5eee18b_0
libjpeg-turbo 2.0.0 h9bf148f_0 pytorch
liblief 0.12.3 h6a678d5_0
libmamba 1.4.1 h2dafd23_1
libmambapy 1.4.1 py310h2dafd23_1
libnghttp2 1.52.0 h2d74bed_1
libnpp 11.8.0.86 0 nvidia
libnvjpeg 11.9.0.86 0 nvidia
libpng 1.6.39 h5eee18b_0
libsolv 0.7.22 he621ea3_0
libssh2 1.10.0 hdbd6064_2
libstdcxx-ng 11.2.0 h1234567_1
libtasn1 4.19.0 h5eee18b_0
libtiff 4.5.1 h6a678d5_0
libunistring 0.9.10 h27cfd23_0
libuuid 1.41.5 h5eee18b_0
libuv 1.44.2 h5eee18b_0
libwebp 1.3.2 h11a3e52_0
libwebp-base 1.3.2 h5eee18b_0
libxml2 2.10.3 hcbfbd50_0
llvm-openmp 14.0.6 h9e868ea_0
lz4-c 1.9.4 h6a678d5_0
markupsafe 2.1.1 py310h7f8727e_0
matplotlib-inline 0.1.6 py310h06a4308_0
mkl 2023.1.0 h213fc3f_46343
mkl-service 2.4.0 py310h5eee18b_1
mkl_fft 1.3.8 py310h5eee18b_0
mkl_random 1.2.4 py310hdb19cb5_0
more-itertools 8.12.0 pyhd3eb1b0_0
mpc 1.1.0 h10f8cd9_1
mpfr 4.0.2 hb69a4c5_1
mpmath 1.3.0 py310h06a4308_0
ncurses 6.4 h6a678d5_0
nettle 3.7.3 hbbd107a_1
networkx 3.1 py310h06a4308_0
numpy 1.26.0 py310h5f9d8c6_0
numpy-base 1.26.0 py310hb5e798b_0
openh264 2.1.1 h4ff587b_0
openssl 3.0.11 h7f8727e_2
packaging 23.1 py310h06a4308_0
parso 0.8.3 pyhd3eb1b0_0
patch 2.7.6 h7b6447c_1001
patchelf 0.17.2 h6a678d5_0
pcre2 10.37 he7ceb23_1
pexpect 4.8.0 pyhd3eb1b0_3
pickleshare 0.7.5 pyhd3eb1b0_1003
pillow 9.4.0 py310h6a678d5_1
pip 23.2.1 py310h06a4308_0
pkginfo 1.9.6 py310h06a4308_0
pluggy 1.0.0 py310h06a4308_1
prompt-toolkit 3.0.36 py310h06a4308_0
psutil 5.9.0 py310h5eee18b_0
ptyprocess 0.7.0 pyhd3eb1b0_2
pure_eval 0.2.2 pyhd3eb1b0_0
py-lief 0.12.3 py310h6a678d5_0
pybind11-abi 4 hd3eb1b0_1
pycosat 0.6.4 py310h5eee18b_0
pycparser 2.21 pyhd3eb1b0_0
pygments 2.15.1 py310h06a4308_1
pyopenssl 23.2.0 py310h06a4308_0
pysocks 1.7.1 py310h06a4308_0
python 3.10.13 h955ad1f_0
python-etcd 0.4.5 pypi_0 pypi
python-libarchive-c 2.9 pyhd3eb1b0_1
pytorch 2.1.0 py3.10_cuda11.8_cudnn8.7.0_0 pytorch
pytorch-cuda 11.8 h7e8668a_5 pytorch
pytorch-mutex 1.0 cuda pytorch
pytz 2023.3.post1 py310h06a4308_0
pyyaml 6.0 py310h5eee18b_1
readline 8.2 h5eee18b_0
reproc 14.2.4 h295c915_1
reproc-cpp 14.2.4 h295c915_1
requests 2.31.0 py310h06a4308_0
rhash 1.4.3 hdbd6064_0
ruamel.yaml 0.17.21 py310h5eee18b_0
ruamel.yaml.clib 0.2.6 py310h5eee18b_1
sentencepiece 0.1.99 pypi_0 pypi
setuptools 68.0.0 py310h06a4308_0
six 1.16.0 pyhd3eb1b0_1
sortedcontainers 2.4.0 pypi_0 pypi
soupsieve 2.5 py310h06a4308_0
sqlite 3.41.2 h5eee18b_0
stack_data 0.2.0 pyhd3eb1b0_0
sympy 1.12 pypi_0 pypi
tbb 2021.8.0 hdb19cb5_0
tk 8.6.12 h1ccaba5_0
tomli 2.0.1 py310h06a4308_0
toolz 0.12.0 py310h06a4308_0
torchaudio 2.1.0 py310_cu118 pytorch
torchelastic 0.2.2 pypi_0 pypi
torchtriton 2.1.0 py310 pytorch
torchvision 0.16.0 py310_cu118 pytorch
tqdm 4.65.0 py310h2f386ee_0
traitlets 5.7.1 py310h06a4308_0
truststore 0.8.0 py310h06a4308_0
types-dataclasses 0.6.6 pypi_0 pypi
typing-extensions 4.8.0 pypi_0 pypi
typing_extensions 4.7.1 py310h06a4308_0
tzdata 2023c h04d1e81_0
urllib3 1.26.16 py310h06a4308_0
wcwidth 0.2.5 pyhd3eb1b0_0
wheel 0.41.2 py310h06a4308_0
xz 5.4.2 h5eee18b_0
yaml 0.2.5 h7b6447c_0
yaml-cpp 0.7.0 h295c915_1
zlib 1.2.13 h5eee18b_0
zstandard 0.19.0 py310h5eee18b_0
zstd 1.5.5 hc292b87_0

I have the same error

same error

My conda environment is as below:

GPU: RTX 5000
CUDA: 12.3

Name Version Build Channel
_libgcc_mutex 0.1 main
_openmp_mutex 5.1 1_gnu
blas 1.0 mkl
brotlipy 0.7.0 py311h9bf148f_1002 pytorch-nightly
bzip2 1.0.8 h7b6447c_0
ca-certificates 2023.08.22 h06a4308_0
certifi 2023.11.17 py311h06a4308_0
cffi 1.15.1 py311h9bf148f_3 pytorch-nightly
charset-normalizer 2.0.4 pyhd3eb1b0_0
cryptography 38.0.4 py311h46ebde7_0 pytorch-nightly
cuda-cudart 12.1.105 0 nvidia
cuda-cupti 12.1.105 0 nvidia
cuda-libraries 12.1.0 0 nvidia
cuda-nvrtc 12.1.105 0 nvidia
cuda-nvtx 12.1.105 0 nvidia
cuda-opencl 12.3.101 0 nvidia
cuda-runtime 12.1.0 0 nvidia
ffmpeg 4.2.2 h20bf706_0
filelock 3.9.0 py311_0 pytorch-nightly
freetype 2.12.1 h4a9f257_0
fsspec 2023.12.2 pypi_0 pypi
giflib 5.2.1 h5eee18b_3
gmp 6.2.1 h295c915_3
gmpy2 2.1.2 py311hc9b5ff0_0
gnutls 3.6.15 he1e5248_0
huggingface-hub 0.19.4 pypi_0 pypi
idna 3.4 py311h06a4308_0
intel-openmp 2021.4.0 h06a4308_3561
jinja2 3.1.2 py311h06a4308_0
jpeg 9e h5eee18b_1
lame 3.100 h7b6447c_0
lcms2 2.12 h3be6417_0
ld_impl_linux-64 2.38 h1181459_1
lerc 3.0 h295c915_0
libcublas 12.1.0.26 0 nvidia
libcufft 11.0.2.4 0 nvidia
libcufile 1.8.1.2 0 nvidia
libcurand 10.3.4.101 0 nvidia
libcusolver 11.4.4.55 0 nvidia
libcusparse 12.0.2.55 0 nvidia
libdeflate 1.17 h5eee18b_1
libffi 3.4.4 h6a678d5_0
libgcc-ng 11.2.0 h1234567_1
libgomp 11.2.0 h1234567_1
libidn2 2.3.4 h5eee18b_0
libjpeg-turbo 2.0.0 h9bf148f_0 pytorch-nightly
libnpp 12.0.2.50 0 nvidia
libnvjitlink 12.1.105 0 nvidia
libnvjpeg 12.1.1.14 0 nvidia
libopus 1.3.1 h7b6447c_0
libpng 1.6.39 h5eee18b_0
libstdcxx-ng 11.2.0 h1234567_1
libtasn1 4.19.0 h5eee18b_0
libtiff 4.5.1 h6a678d5_0
libunistring 0.9.10 h27cfd23_0
libuuid 1.41.5 h5eee18b_0
libvpx 1.7.0 h439df22_0
libwebp 1.2.4 h11a3e52_1
libwebp-base 1.2.4 h5eee18b_1
llvm-openmp 14.0.6 h9e868ea_0
lz4-c 1.9.4 h6a678d5_0
markupsafe 2.1.1 py311h5eee18b_0
mkl 2021.4.0 h06a4308_640
mkl-service 2.4.0 py311h9bf148f_0 pytorch-nightly
mkl_fft 1.3.1 py311hc796f24_0 pytorch-nightly
mkl_random 1.2.2 py311hbba84a0_0 pytorch-nightly
mpc 1.1.0 h10f8cd9_1
mpfr 4.0.2 hb69a4c5_1
mpmath 1.2.1 py311_0 pytorch-nightly
ncurses 6.4 h6a678d5_0
nettle 3.7.3 hbbd107a_1
networkx 3.1 py311h06a4308_0
numpy 1.24.3 py311hc206e33_0
numpy-base 1.24.3 py311hfd5febd_0
openh264 2.1.1 h4ff587b_0
openssl 3.0.12 h7f8727e_0
packaging 23.2 pypi_0 pypi
pillow 9.3.0 py311h3fd9d12_2 pytorch-nightly
pip 23.3.1 py311h06a4308_0
pycparser 2.21 pyhd3eb1b0_0
pyopenssl 23.2.0 py311h06a4308_0
pysocks 1.7.1 py311_0 pytorch-nightly
python 3.11.5 h955ad1f_0
pytorch 2.3.0.dev20231214 py3.11_cuda12.1_cudnn8.9.2_0 pytorch-nightly
pytorch-cuda 12.1 ha16c6d3_5 pytorch-nightly
pytorch-mutex 1.0 cuda pytorch-nightly
pyyaml 6.0.1 py311h5eee18b_0
readline 8.2 h5eee18b_0
requests 2.28.1 py311_0 pytorch-nightly
sentencepiece 0.1.99 pypi_0 pypi
setuptools 68.2.2 py311h06a4308_0
six 1.16.0 pyhd3eb1b0_1
sqlite 3.41.2 h5eee18b_0
sympy 1.12 py311h06a4308_0
tk 8.6.12 h1ccaba5_0
torchaudio 2.2.0.dev20231214 py311_cu121 pytorch-nightly
torchtriton 2.1.0+bcad9dabe1 py311 pytorch-nightly
torchvision 0.18.0.dev20231214 py311_cu121 pytorch-nightly
tqdm 4.66.1 pypi_0 pypi
typing_extensions 4.7.1 py311h06a4308_0
tzdata 2023c h04d1e81_0
urllib3 1.26.14 py311_0 pytorch-nightly
wheel 0.41.2 py311h06a4308_0
x264 1!157.20191217 h7b6447c_0
xz 5.4.5 h5eee18b_0
yaml 0.2.5 h7b6447c_0
zlib 1.2.13 h5eee18b_0
zstd 1.5.5 hc292b87_0

Can you try using the patch release, or nightly?

@drisspg #46 (comment): @merveermann says they're using the nightly, I believe.

So this error is being thrown on nightly for these devices: V100, RTX 5000.
Are there any others?

Also, is it possible to give example inputs to SDPA that are causing this error to be thrown?
Is this only happening when the model is being compiled?

My hunch is that compile is doing some memory planning optimizations that cause the alignment check here: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/transformers/cuda/attention.cu#L1023-L1027
to fail for all possible kernels.
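For reference, a minimal sketch of the failing call in isolation (the shapes are hypothetical, loosely modeled on a Llama-2-7B prefill; the actual inputs from generate.py may differ):

import torch
import torch.nn.functional as F

# Hypothetical repro on a pre-sm80 GPU (e.g. V100): bf16 q/k/v with a bool mask,
# shaped (batch, n_heads, seq_len, head_dim) as in model.py.
q = torch.randn(1, 32, 6, 128, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 32, 6, 128, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, 32, 6, 128, device="cuda", dtype=torch.bfloat16)
mask = torch.ones(1, 1, 6, 6, device="cuda", dtype=torch.bool)

# On sm < 80 this is expected to fail with "cutlassF: no kernel found to launch!",
# matching the traceback above.
y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)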

It seems your GPU does not support bf16; changing all torch.bfloat16 to torch.float32 may work.

@drisspg I tested on a V100. Both eager and compiled run into the same error.

I think the issue is that mem_eff_attention doesn't support bf16 on sm < 80: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassF.h#L286

I tested with float16 and it works. Shall we default gpt-fast to float16 for V100 and under?
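A minimal sketch of that idea (the helper name pick_default_dtype is hypothetical, not part of gpt-fast):

import torch

def pick_default_dtype(device: str = "cuda") -> torch.dtype:
    # Flash / memory-efficient attention only ship bf16 kernels for sm80+
    # (Ampere and newer); fall back to fp16 on older parts such as V100 (sm70).
    major, _minor = torch.cuda.get_device_capability(device)
    return torch.bfloat16 if major >= 8 else torch.float16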

Ohh @yifuwang thank you, that is a great catch. I will put up a PR right now to fix this in PyTorch.

Thank you all. After changing all torch.bfloat16 to torch.float32, running the unquantized model works well,
but running with int8 seems wrong:

root@md:/home/projects/gpt-fast# CUDA_VISIBLE_DEVICES=0 python3 generate.py --compile --checkpoint_path /models/huggingface_models/meta-Llama-2-7b-hf/model_int8.pth --max_new_tokens 100
Loading model ...
Using int8 weight-only quantization!
/opt/conda/lib/python3.10/site-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  return self.fget.__get__(instance, owner)()
Time to load model: 2.52 seconds
[2023-12-19 00:54:26,247] [0/0] torch._dynamo.output_graph: [WARNING] nn.Module state_dict and backward hooks are not yet supported by torch.compile, but were detected in your model and will be silently ignored. See https://pytorch.org/docs/master/compile/nn-module.html for more information and limitations.
Compilation time: 101.21 seconds
Hello, my name is ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇ 
Time for inference 1: 4.87 sec total, 20.53 tokens/sec
Bandwidth achieved: 141.08 GB/s
Hello, my name is ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇ 
Time for inference 2: 4.87 sec total, 20.55 tokens/sec
Bandwidth achieved: 141.25 GB/s
Hello, my name is ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇ 
Time for inference 3: 4.87 sec total, 20.55 tokens/sec
Bandwidth achieved: 141.24 GB/s
Hello, my name is ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇ 
Time for inference 4: 4.87 sec total, 20.55 tokens/sec
Bandwidth achieved: 141.22 GB/s
Hello, my name is ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇ 
Time for inference 5: 4.87 sec total, 20.55 tokens/sec
Bandwidth achieved: 141.24 GB/s
==========
Average tokens/sec: 20.55
Memory used: 8.01 GB

@goodboyyes2009 Did you re-run quantize.py after changing torch.bfloat16 to torch.float32?

@VendaCino oh, sorry, I did re-run quantize.py, but I changed all torch.bfloat16 to torch.float16.

The '⁇ ⁇ ⁇' output is because the tensor values are NaN.
Debugging, I found that the kv_cache in that attention layer is NaN.
This issue does not happen when every dtype is torch.float32 rather than torch.float16,
and it does not happen when I use TinyLlama rather than Vicuna-7B.

Hope this information can help to trace the problem.
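(For anyone reproducing this kind of hunt, here is a minimal sketch, not the exact debugging code used above, that registers forward hooks to flag any module emitting a non-finite output:)

import torch

def install_nan_hooks(model: torch.nn.Module) -> None:
    # Print the name of any module whose output contains inf or NaN.
    def make_hook(name: str):
        def hook(module, args, output):
            if isinstance(output, torch.Tensor) and not torch.isfinite(output).all():
                print(f"non-finite output in: {name}")
        return hook
    for name, module in model.named_modules():
        module.register_forward_hook(make_hook(name))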

Update:
Deeper debugging found that it is because x.max() = inf.
I think some layer output is too large for float16 to represent.

Time to load model: 1.97 seconds
tensor(7.0664, device='cuda:0', dtype=torch.float16)
tensor(18.3906, device='cuda:0', dtype=torch.float16)
tensor(inf, device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16)

It depends on the model's weights, which is why my test with TinyLlama works well.
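The float16 range explains the jump to inf: fp16 tops out at 65504, so any intermediate activation above that becomes inf, and downstream arithmetic on inf (e.g. inf - inf) produces the NaNs seen above. A quick check:

import torch

print(torch.finfo(torch.float16).max)    # 65504.0
x = torch.tensor(70000.0).to(torch.float16)
print(x)                                 # tensor(inf, dtype=torch.float16)
print(x - x)                             # tensor(nan, dtype=torch.float16)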

When I use model.pth (the unquantized checkpoint):

Time to load model: 10.10 seconds
tensor(7.0625, device='cuda:0', dtype=torch.float16)
tensor(18.3438, device='cuda:0', dtype=torch.float16)
tensor(1532., device='cuda:0', dtype=torch.float16)

So I guess something is wrong in WeightOnlyInt8Linear:

class WeightOnlyInt8Linear(torch.nn.Module):
    ...

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales  # <- precision is lost here

Change it to:

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Do the matmul and scaling in fp32, then cast back to the input dtype.
        return (F.linear(input.to(dtype=torch.float32), self.weight.to(dtype=torch.float32)) * self.scales).to(dtype=input.dtype)

Everything looks good:

Time to load model: 1.66 seconds
tensor(7.0664, device='cuda:0', dtype=torch.float16)
tensor(18.3906, device='cuda:0', dtype=torch.float16)
tensor(1535., device='cuda:0', dtype=torch.float16)
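A small synthetic sketch (random weights, not the actual checkpoint) of why accumulating in fp32 before casting back avoids the overflow:

import torch
import torch.nn.functional as F

# fp16 matmul may require a GPU on older PyTorch builds.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(0)
x = (torch.randn(1, 4096, device=device) * 8).to(torch.float16)
w8 = torch.randint(-128, 128, (4096, 4096), device=device, dtype=torch.int8)
scales = torch.rand(4096, device=device) * 0.01

# Original path: dequantize to fp16 and accumulate in fp16 (can overflow to inf).
y_fp16 = F.linear(x, w8.to(torch.float16)) * scales.to(torch.float16)

# Fixed path: matmul and per-channel scaling in fp32, then cast back to fp16.
y_fp32 = (F.linear(x.to(torch.float32), w8.to(torch.float32)) * scales).to(torch.float16)

print(y_fp16.abs().max(), y_fp32.abs().max())  # the fp16 path typically shows inf here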

OK. Thank you very much! @VendaCino