OpenNMT / CTranslate2

Fast inference engine for Transformer models

Home Page: https://opennmt.net/CTranslate2

Feature Request: Implement Static Cache and Quantization Techniques in CTranslate2

Jiltseb opened this issue

@minhthuc2502 @alexlnkp

Description
What type of cache is currently implemented in CTranslate2? Is it static or dynamic? Could we achieve a speed-up if the cache implementation is changed for the decoder in encoder-decoder models?

Also, it would be great to implement recent popular quantization techniques such as [HQQ](https://github.com/mobiusml/hqq) in the CTranslate2 format.

Motivation
Given that a static cache (see this PR) can significantly speed up processing in PyTorch encoder-decoder models via torch compilation, can we enable this in CTranslate2? This enhancement can improve decoding speed for projects utilizing CTranslate2 models, such as Faster Whisper.
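
For illustration only (this is not CTranslate2's or PyTorch's actual implementation; the names and shapes below are made up), the difference between a dynamically grown and a statically pre-allocated decoder KV cache looks roughly like this:

import torch

num_heads, head_dim, max_len = 8, 64, 448   # hypothetical decoder dimensions

# Dynamic cache: the key/value tensors grow at every step, so shapes keep
# changing and buffers are reallocated as the sequence gets longer.
def dynamic_append(k_cache, v_cache, k_new, v_new):
    k_cache = torch.cat([k_cache, k_new], dim=1)   # (heads, cur_len + 1, dim)
    v_cache = torch.cat([v_cache, v_new], dim=1)
    return k_cache, v_cache

# Static cache: buffers are pre-allocated once at the maximum length, and each
# decoding step only writes into a slice, so shapes stay constant (which is
# what makes torch.compile-style optimizations effective).
k_static = torch.zeros(num_heads, max_len, head_dim)
v_static = torch.zeros(num_heads, max_len, head_dim)

def static_write(step, k_new, v_new):
    k_static[:, step : step + 1] = k_new
    v_static[:, step : step + 1] = v_new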

References
Speed-up achieved for PyTorch-based Whisper: Blog Post

Benefits
Implementing static caching and recent quantization techniques in CTranslate2 could lead to significant performance improvements in model decoding speeds and efficiency.

Thank you for considering this feature request!

Hello, thank you for the information. I will take a look at how HQQ works. As for the cache: in CTranslate2, the cache is reallocated depending on the sequence length. I'm not sure a static cache (as I understand it, a cache pre-allocated up front at the maximum size) would give much of a speed-up in CTranslate2; some benchmarking is needed to confirm it, and some changes to the design would be required.

@minhthuc2502 Thanks. This blog post should help with HQQ: https://mobiusml.github.io/hqq_blog/

Hello, I tried to implement HQQ for 4-bit quantization. It works, but I think it has to be combined with a low-bit matmul kernel to speed up inference. Do you have any reference for this?

As I understand it, to use int4mm you have to convert the HQQ quantized weights to the format accepted by int4mm (w_q, scale, ...). Do you think this will reduce the performance? Similarly, when using the int4mm kernel, it looks like we dequantize the weight and do the GEMM, and the performance is not as expected. I see you get very good performance with torch.compile, but in CTranslate2 we don't have that kind of solution. I will investigate other solutions and the kernels that go with them.

The conversion is only done once. Different int4 kernels require different input formats, which is why we do it via patching: that way we can support many backends, not just the torchao int4 kernel.
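
As a rough sketch of that patching idea (purely illustrative; none of these names come from the hqq library's API, and `prepare_model_for_backend` is not hqq's `prepare_for_inference`), each backend registers a one-time conversion that repacks the HQQ weights/scales/zeros into the layout its kernel expects:

# Illustrative backend registry; `layer` is assumed to expose HQQ-style
# attributes (W_q, scale, zero).
BACKEND_PREPARE = {}

def register_backend(name):
    def wrapper(fn):
        BACKEND_PREPARE[name] = fn
        return fn
    return wrapper

@register_backend("torchao_int4")
def prepare_torchao_int4(layer):
    # One-time repack of layer.W_q / layer.scale / layer.zero into the
    # layout expected by the torchao int4 kernel (details omitted here).
    ...

def prepare_model_for_backend(layers, backend):
    for layer in layers:
        BACKEND_PREPARE[backend](layer)   # conversion happens once, up front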

The int4mm kernel should be faster than fp16 matmul with or without torch.compile. However, it is a gemv kernel optimized for decoding only, so the prefill phase with this kernel is actually slower. That's why we don't use it in the encoder and only use it in the decoding phase. Unfortunately, there's no way so far to efficiently dequantize() and do an fp16 matmul in the prefill phase, which should be faster. But decoding one token at a time with int4mm is definitely faster.
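
To make the prefill/decode distinction concrete, here is a hedged sketch of the kind of dispatch being described, where the int4 gemv kernel is used only for single-token decoding and the prefill falls back to a dequantize-plus-matmul path (which, as noted above, is not yet efficient in practice). The method names are placeholders, not real kernel names:

import torch

def quantized_linear_forward(x, layer):
    # Prefill: many tokens at once -> dequantize and use a regular bf16/fp16
    # matmul, because the int4 gemv kernel is tuned for single-token decoding.
    if x.shape[-2] > 1:
        return x @ layer.dequantize().T   # layer.dequantize() is a placeholder
    # Decode: one token at a time -> the int4 gemv kernel is the fast path.
    return layer.int4_gemv(x)             # layer.int4_gemv() is a placeholder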

There are other options by the way, like https://github.com/microsoft/BitBLAS/ , they also support A16W2 matmul which should be even faster for larger models.

@minhthuc2502 Did you figure out how to speed up the HQQ implementation in ctranslate2? This will be a useful add-on for large E-D models.

I tried to implement HQQ (only the quantization), but I have not done the benchmark yet. The thing that prevented me from making it work correctly was patching the HQQ format. As I understand it, I would patch the HQQ format at conversion time and then run inference with the new format. BTW, I'll try to implement it.

I tried HQQ quantization in this PR. When I use int4mm, the quality is not good and it becomes worse with long prompts. I tested only with the decoder. Do you have any idea whether I made a mistake in the conversion?

I tried using HQQ quantization first and then converting directly to the torch quant format to get the scales, zeros, and weights. Then, before inference, I convert the weight to int4pack and use int4mm for the matmul.

@minhthuc2502 you have to quantize with hqq. The link you shared is just doing RTN quantization, which will give bad quality especially at lower bits.
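
For context, "RTN" here means plain round-to-nearest group-wise quantization, roughly like the sketch below (illustrative only, not code from hqq or CTranslate2); HQQ starts from a similar initial estimate and then refines the zero-point/scale by minimizing the weight reconstruction error.

import torch

def rtn_quantize(W, nbits=4, group_size=64):
    # Plain round-to-nearest asymmetric quantization per group of `group_size`
    # weights; no error minimization, which is why quality drops at low bits.
    q_max = 2 ** nbits - 1
    Wg = W.reshape(-1, group_size)
    w_min = Wg.min(dim=1, keepdim=True).values
    w_max = Wg.max(dim=1, keepdim=True).values
    scale = (w_max - w_min).clamp(min=1e-8) / q_max
    zero = -w_min / scale
    W_q = torch.clamp(torch.round(Wg / scale + zero), 0, q_max)
    return W_q, scale, zero

def rtn_dequantize(W_q, scale, zero, shape):
    return ((W_q - zero) * scale).reshape(shape)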

I followed this to add HQQ: https://github.com/mobiusml/hqq/blob/master/hqq/core/quantize.py. Did I miss something?

@minhthuc2502 Yes, it's not performing the optimization step: https://github.com/mobiusml/hqq/blob/master/hqq/core/quantize.py#L115-L122, which is the actual HQQ algorithm: https://github.com/mobiusml/hqq/blob/master/hqq/core/optimize.py#L194-L243
Basically, you get an initial estimate of the quantized weights/scale/zero (which is what you did), but then you fine-tune them via alternate minimization.
Any reason why you re-implemented things in numpy? You could simply call HQQLinear from the hqq lib; it will do everything for you. Doing this on the CPU will be very slow, so it's better to do it on the GPU. Example below:

import torch
from hqq.core.quantize import HQQLinear, BaseQuantizeConfig
from hqq.backends.torchao import patch_hqq_to_aoint4

# 4-bit, group size 64, axis=1; scales and zeros are kept unquantized
quant_config = BaseQuantizeConfig(nbits=4, group_size=64, quant_zero=False, quant_scale=False, axis=1)

# Quantize the layer with HQQ, then patch it to use the torchao int4 kernel
hqq_ao_layer = patch_hqq_to_aoint4(HQQLinear(your_linear_layer, quant_config=quant_config, compute_dtype=torch.bfloat16, device='cuda'), None)

out = hqq_ao_layer.forward(x)
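
A quick follow-up sanity check (sketch only; `ref_layer` is assumed to be an unquantized bf16 copy of `your_linear_layer` kept aside before quantization, and `x` is the same input as above) is to compare the quantized output against the reference:

out_ref = ref_layer(x)   # ref_layer: hypothetical bf16 copy of the original layer
rel_err = (out - out_ref).abs().mean() / out_ref.abs().mean()
print(f"relative error vs. bf16 reference: {rel_err.item():.4f}")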

Thank you for your suggestion. I now use HQQLinear directly and the quality is better thanks to the optimization; see the PR. When I tested with a long prompt (trying to summarize a long text), it did not work well; there are some repetitions or worse. What do you think about it? Is this normal due to the lower-bit quantization?

Glad to hear it worked better! It should work fine with 4-bit and a group size of 64 as suggested in the code above. Which model did you try the summarization with? Do you have a code snippet I can run to investigate?

I tested with Llama2. You can do a simple generation with Llama2 and use a prompt like this:
"Summarize this paragraph: Roger Federer (born 8 August 1981) is a Swiss former professional tennis player. Federer was ranked world No. 1 in singles by the Association of Tennis Professionals (ATP) for 310 weeks, including a record 237 consecutive weeks, and finished as the year-end No. 1 five times. He won 103 singles titles on the ATP Tour, the second most of all time, including 20 major men's singles titles (among which a record eight men's singles Wimbledon titles, and an Open Era joint-record five men's singles US Open titles) and six year-end championships.".

The result I got:
Roger Federer was a professional tennis player who spent a total of 40 weeks, including a record 20 consecutive times, and finished as the year-end No. 1 player in the ATP rankings. He won 103 singles titles on the ATP Tour, including the 20 major men's titles (ATP rankings rankings (among all players (40 ATP-ranked players (23 ATP-ranked players (103 players (103 times (23 times (10 times (10 times Roger Federer (10 times (10 times (10 times (10 times Roger Federer ATP Tour Roger Federer (1981 (1990 (190 times (190 times Roger Federer Federer Federer (190 times Roger Federer Federer Federer (190 times Roger Federer Federer (190 times Roger Federer world No. 1999999 ATP-ranked players (10319 ATP-ranked players (10319 ATP-ranked players (1901 ATP-ranked players (103 ATP-ranked ATP-rank players (100 ATP- (103- (103- players- players ATP- players- players ATP- (103- (1901- (103- (103- players- (103- (103- (103- (103- (103- (103- (103- players- (103- (103- (103- (103- (103- (103- (103- (103- (103- (103- (103- (103- (103- (103- (103- (103- (103- (103- (103- (103- (103- (103- (103- (103- (103- (103- (103- (103- (103- (103- (103- (103-

I applied int4mm for prefill step + generation step

I tried with Llama2-7B and it's working fine:

import torch, os
cache_path     = '.'
model_id       = "meta-llama/Llama-2-7b-chat-hf"
compute_dtype  = torch.bfloat16 #int4 kernel only works with bfloat16
device         = 'cuda:0'

##########################################################################################################################################################
from hqq.engine.hf import HQQModelForCausalLM, AutoTokenizer
from hqq.core.quantize import *

#Load
tokenizer    = AutoTokenizer.from_pretrained(model_id, cache_dir=cache_path)

#No quantize
# from transformers import AutoModelForCausalLM
# model = AutoModelForCausalLM.from_pretrained(model_id, cache_dir=cache_path, torch_dtype=compute_dtype, attn_implementation="sdpa", device_map = device)

#Quantize
model        = HQQModelForCausalLM.from_pretrained(model_id, cache_dir=cache_path, torch_dtype=compute_dtype, attn_implementation="sdpa")
quant_config = BaseQuantizeConfig(nbits=4, group_size=64, quant_scale=False, quant_zero=False, axis=1)
model.quantize_model(quant_config=quant_config, compute_dtype=compute_dtype, device=device)

#Set default backends, to compare with int4mm
if(quant_config['weight_quant_params']['axis']==0):
    HQQLinear.set_backend(HQQBackend.ATEN)
else:
    HQQLinear.set_backend(HQQBackend.PYTORCH)

##########################################################################################################################################################
from hqq.utils.patching import prepare_for_inference
prepare_for_inference(model, backend="torchao_int4")

#Import custom HF generator
from hqq.utils.generation_hf import HFGenerator

#Generate
#gen = HFGenerator(model, tokenizer, max_new_tokens=1000, do_sample=True, compile="partial").warmup() 
gen = HFGenerator(model, tokenizer, max_new_tokens=1000, do_sample=False, compile=None)

prompt = "Summarize this paragraph: Roger Federer (born 8 August 1981) is a Swiss former professional tennis player. Federer was ranked world No. 1 in singles by the Association of Tennis Professionals (ATP) for 310 weeks, including a record 237 consecutive weeks, and finished as the year-end No. 1 five times. He won 103 singles titles on the ATP Tour, the second most of all time, including 20 major men's singles titles (among which a record eight men's singles Wimbledon titles, and an Open Era joint-record five men's singles US Open titles) and six year-end championships."
out = gen.generate(prompt, print_tokens=True)

Outputs:

FP16:
Roger Federer is a former professional tennis player from Switzerland, born on August 8, 1981. He held the number one ranking in singles by the ATP for 310 weeks, including a record 237 consecutive weeks, and finished as the year-end number one five times. Federer won 103 singles titles on the ATP Tour, including 20 major men's singles titles and six year-end championships.
HQQ 4-bit (axis=1) - int4mm:
Roger Federer is a former professional tennis player from Switzerland, born on August 8, 1981. He was ranked as the world's number one singles player by the ATP for 310 weeks, including a record 237 consecutive weeks, and finished as the year-end number one five times. During his career, Federer won 103 singles titles on the ATP Tour, including 20 major men's singles titles, and six year-end championships.

It's weird. It seems like I used exactly the same parameters for HQQ quantization + int4mm.