marella / ctransformers

Python bindings for Transformer models implemented in C/C++ using the GGML library.

Cannot generate text on GPU

congson1293 opened this issue · comments

I load the model to GPU like this:

llm = AutoModelForCausalLM.from_pretrained("LLM-model", 
                                            model_file="vinallama-7b-chat_q5_0.gguf",
                                            config=config, torch_dtype=torch.float16, hf=True,
                                            gpu_layers = 100, device_map='cuda')

and generate code like this:

generated_ids = llm.generate(**model_inputs,
                              max_new_tokens=4096,
                              # early_stopping=True,
                              repetition_penalty=1.1,
                              # no_repeat_ngram_size=2,
                              temperature=0.6,
                              do_sample=True,
                              # top_k=5,
                              top_p=0.9,
                              eos_token_id=tokenizer.eos_token_id,
                              pad_token_id=tokenizer.pad_token_id,
                              use_cache=True)

When I ran this code, the model took about 6.6 GB of GPU memory, but calling the generate method raised this exception: `You are calling .generate() with the input_ids being on a device type different than your model's device. input_ids is on cuda, whereas the model is on cpu.`
Does anyone know how to fix this?

What device is model_inputs on?

I use T4 on Google Colab.

Double-check your code against the gold-standard examples at: https://github.com/marella/ctransformers?tab=readme-ov-file#classmethod-automodelforcausallmfrom_pretrained

Do the same for the generate method of your llm class; it is documented in the same README linked above.

It looks like you are passing a lot of unnecessary arguments mimicked from other libraries, especially in the from_pretrained call.
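As a hedged sketch of what that advice means in practice: the load and generation calls below keep only arguments that ctransformers itself understands, dropping the transformers-style ones (torch_dtype, device_map, eos_token_id, pad_token_id, use_cache). The model name and file are copied from the question; whether gpu_layers=100 fits on a T4 is untested.

```python
# Sketch of loading and running a GGUF model with ctransformers' own API.
# Guard the import so the sketch stays readable even without the library.
try:
    from ctransformers import AutoModelForCausalLM
except ImportError:
    AutoModelForCausalLM = None

def build_kwargs():
    # Arguments kept from the original call, restricted to what
    # ctransformers itself understands. gpu_layers controls GPU offloading;
    # no device_map or torch_dtype is needed.
    load = {"model_file": "vinallama-7b-chat_q5_0.gguf", "gpu_layers": 100}
    generate = {"max_new_tokens": 4096, "temperature": 0.6, "top_p": 0.9,
                "repetition_penalty": 1.1}
    return load, generate

if AutoModelForCausalLM is not None:
    load_kwargs, generate_kwargs = build_kwargs()
    llm = AutoModelForCausalLM.from_pretrained("LLM-model", **load_kwargs)
    # ctransformers generates directly from a string prompt; no tokenizer
    # round-trip or manual device placement is required.
    print(llm("Hello", **generate_kwargs))
```

If you need the Hugging Face generate() interface, the README's hf=True example wraps the model and builds the tokenizer with `AutoTokenizer.from_pretrained(model)` instead of a separate transformers tokenizer.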

Hope that this helps.