multi GPU inference
hensiesp32 opened this issue · comments
Hi, thanks for your wonderful work. I was wondering: if I want to run inference on a model with multiple GPUs, what should I do? I have tried the code below, loading the model with the `device_map` parameter:

```python
model = AutoModelForCausalLM.from_pretrained(model_name, config=config,
                                             torch_dtype=torch.float16,
                                             device_map='auto',
                                             trust_remote_code=True)
```
but I run into trouble when running the code in https://github.com/spcl/QuaRot/blob/main/fake_quant/monkeypatch.py:

```python
wrapper = monkeypatch.add_wrapper_after_function_call_in_method(
    module, "forward", function_name,
    functools.partial(QKRotationWrapper, *args, **kwargs))
setattr(module, attr_name, wrapper)
```

because the module's type has changed after being wrapped via `functools.partial`. Is there any way to solve this problem? Or how do you use QuaRot when the model is too large (such as Llama-2-70B) to run on a single GPU?
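For anyone hitting the same error: the conflict can be reproduced without torch. Below is a minimal sketch (the `Attention` and `QKWrapper` classes are hypothetical stand-ins, not QuaRot's actual code) showing that once a submodule is replaced by a wrapper built through `functools.partial`, any code that dispatches on the original module's type, as `device_map`-style hooks do, no longer recognizes it:

```python
import functools

class Attention:
    """Stand-in for a transformer attention submodule."""
    def forward(self, x):
        return x * 2

class QKWrapper:
    """Stand-in for a wrapper like QKRotationWrapper: it wraps an
    existing module and adds extra processing around forward()."""
    def __init__(self, inner, scale):
        self.inner = inner
        self.scale = scale

    def forward(self, x):
        return self.inner.forward(x) * self.scale

attn = Attention()
# The monkeypatch builds the wrapper via functools.partial and then
# setattr-replaces the submodule, so the attribute's type changes:
make_wrapper = functools.partial(QKWrapper, scale=10)
wrapped = make_wrapper(attn)

print(isinstance(wrapped, Attention))  # False: type-based checks now fail
print(wrapped.forward(3))              # 60: the wrapped forward still works
```

So the wrapped model still computes correctly; the failure only appears when multi-GPU dispatch inspects module types that the monkeypatch has silently changed.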
Thank you so much for your issue. Yes, this is the hack we applied during the submission. I would suggest using this branch for multi-GPU inference.