spcl / QuaRot

Code for QuaRot, an end-to-end 4-bit inference scheme for large language models.

Home Page: https://arxiv.org/abs/2404.00456


Multi-GPU inference

hensiesp32 opened this issue

Hi, thanks for your wonderful work. I was wondering: if I want to run inference on a model across multiple GPUs, what should I do? I have tried loading the model with the device_map parameter, as in the code below:

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    config=config,
    torch_dtype=torch.float16,
    device_map='auto',
    trust_remote_code=True,
)
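
(An illustrative aside, not part of the original report: with device_map='auto', transformers records where each submodule landed in model.hf_device_map, which is a quick way to confirm the checkpoint was actually sharded.)

# Prints a mapping from submodule names to devices, e.g. {'model.layers.0': 0, ...}
print(model.hf_device_map)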

However, trouble shows up when this runs through the code in https://github.com/spcl/QuaRot/blob/main/fake_quant/monkeypatch.py:

wrapper = monkeypatch.add_wrapper_after_function_call_in_method(
    module, "forward", function_name,
    functools.partial(QKRotationWrapper, *args, **kwargs))
setattr(module, attr_name, wrapper)

because the module's type has been changed by functools.partial, so it is no longer a regular submodule. Are there any methods to solve this problem? Or how do you use QuaRot when the model is too large (such as Llama-2-70B) to run on a single GPU?
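
To make the failure mode concrete, here is a minimal, self-contained sketch of the type clash (the class and attribute names are illustrative, not QuaRot's actual ones). A functools.partial object is callable but is not an nn.Module, so the machinery behind device_map='auto' never sees it when walking submodules:

import functools
import torch.nn as nn

class Block(nn.Module):
    def forward(self, x):
        return x

block = Block()

# Stand-in for the wrapper built via functools.partial(QKRotationWrapper, ...):
# the resulting object is a functools.partial, not an nn.Module.
fake_wrapper = functools.partial(lambda x, scale: x * scale, scale=1.0)
setattr(block, "qk_wrapper", fake_wrapper)

print(isinstance(block.qk_wrapper, nn.Module))      # False
print([name for name, _ in block.named_modules()])  # [''] -- the wrapper is
                                                    # invisible to device hooks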

@hensiesp32

Thank you so much for your issue.

Yes, this is the hack we applied during the submission. I would suggest using this branch for multi-GPU inference.
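
For readers hitting the same wall, one ordering that sidesteps the type clash is patch-then-dispatch: monkeypatch the model while it still lives on the CPU, and only afterwards shard it with accelerate. This is a sketch under assumptions, not the repo's tested recipe: apply_quarot is a hypothetical stand-in for the fake_quant rotation and quantization steps, and the checkpoint name is only an example.

import torch
from transformers import AutoModelForCausalLM
from accelerate import dispatch_model, infer_auto_device_map

model_name = "meta-llama/Llama-2-70b-hf"  # assumption: any checkpoint too big for one GPU

# Load on CPU first so the monkeypatching sees plain, unhooked modules.
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)

apply_quarot(model)  # hypothetical placeholder for QuaRot's fake_quant pipeline

# Shard only after patching; keeping each decoder layer on one device avoids
# splitting a patched module across GPUs.
device_map = infer_auto_device_map(model, no_split_module_classes=["LlamaDecoderLayer"])
model = dispatch_model(model, device_map=device_map)

The essential point is the ordering: accelerate's hooks end up wrapping the already-patched forward methods, rather than QuaRot trying to rewrite modules that accelerate has already replaced.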