Lightning-AI / lit-llama

Implementation of the LLaMA language model based on nanoGPT. Supports flash attention, Int8 and GPTQ 4bit quantization, LoRA and LLaMA-Adapter fine-tuning, pre-training. Apache 2.0-licensed.


How can I run inference with different prompts in a Jupyter Notebook, loading the model and tokenizer only once?

Vinter8848 opened this issue · comments

commented
How can I run inference with different prompts in a Jupyter Notebook, loading the model and tokenizer only once?
commented

Here's the problem I ran into:

def out(prompt1, in_put1, tokenizer1, model1):
    # Build the instruction-style prompt and encode it
    sample = {"instruction": prompt1, "input": in_put1}
    prompt = generate_prompt(sample)
    encoded = tokenizer1.encode(prompt, bos=True, eos=False, device=model1.device)

    # Generate from the model that was passed in (not the global `model`)
    output = generate(
        model1,
        idx=encoded,
        max_new_tokens=50,
        temperature=0.8,
        top_k=100,
        eos_id=tokenizer1.eos_id,
    )

    # Decode and keep only the text after the response marker
    output = tokenizer1.decode(output)
    output = output.split("### Response:")[1].strip()
    return output
precision = "bf16-true" #if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else "32-true"
fabric = L.Fabric(devices=1, precision=precision)

print("Loading model ...", file=sys.stderr)
t0 = time.time()

with lazy_load(pretrained_path) as pretrained_checkpoint, lazy_load(lora_path) as lora_checkpoint:
    name = llama_model_lookup(pretrained_checkpoint)
    
    with fabric.init_module(empty_init=True), lora(r=lora_r, alpha=lora_alpha, dropout=lora_dropout, enabled=True):
        model = LLaMA.from_name(name)

        # 1. Load the pretrained weights
        model.load_state_dict(pretrained_checkpoint, strict=False)
        # 2. Load the fine-tuned LoRA weights
        model.load_state_dict(lora_checkpoint, strict=False)

print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)

model.eval()
model = fabric.setup(model)
tokenizer = Tokenizer(tokenizer_path)
prompt0 = "What is the SI unit of the physical quantity m/Q?"
in_put0 = "(A)Meter per second(B)Pascal per second(C)Kilogram per coulomb(D)Newton per meter(E)Joule per second"
print(out(prompt0,in_put0,tokenizer,model))

C

prompt1 = "What is the SI unit of the physical quantity m/Q?"
in_put1 = "(A)Meter per second(B)Pascal per second(D)Newton per meter(E)Joule per second"
print(out(prompt1,in_put1,tokenizer,model))

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ in :1 │
│ │
│ ❱ 1 print(out(prompt1,in_put1,tokenizer,model)) │
│ 2 │
│ │
│ in out:7 │
│ │
│ 4 │ encoded = tokenizer1.encode(prompt, bos=True, eos=False, device=model1.device) │
│ 5 │ │
│ 6 │ │
│ ❱ 7 │ output = generate( │
│ 8 │ │ │ model, │
│ 9 │ │ │ idx=encoded, │
│ 10 │ │ │ max_new_tokens=50, │
│ │
│ /opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py:115 in decorate_context │
│ │
│ 112 │ @functools.wraps(func) │
│ 113 │ def decorate_context(*args, **kwargs): │
│ 114 │ │ with ctx_factory(): │
│ ❱ 115 │ │ │ return func(*args, **kwargs) │
│ 116 │ │
│ 117 │ return decorate_context │
│ 118 │
│ │
│ /kaggle/working/../input/litllama/litllama/generate.py:65 in generate │
│ │
│ 62 │ │ x = idx.index_select(0, input_pos).view(1, -1) │
│ 63 │ │ │
│ 64 │ │ # forward │
│ ❱ 65 │ │ logits = model(x, max_seq_length, input_pos) │
│ 66 │ │ logits = logits[0, -1] / temperature │
│ 67 │ │ │
│ 68 │ │ # optionally crop the logits to only the top k options │
│ │
│ /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501 in _call_impl │
│ │
│ 1498 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks │
│ 1499 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks │
│ 1500 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │
│ ❱ 1501 │ │ │ return forward_call(*args, **kwargs) │
│ 1502 │ │ # Do not call functions when jit is used │
│ 1503 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │
│ 1504 │ │ backward_pre_hooks = [] │
│ │
│ /opt/conda/lib/python3.10/site-packages/lightning/fabric/wrappers.py:117 in forward │
│ │
│ 114 │ │ args, kwargs = self._precision.convert_input((args, kwargs)) │
│ 115 │ │ │
│ 116 │ │ with self._precision.forward_context(): │
│ ❱ 117 │ │ │ output = self._forward_module(*args, **kwargs) │
│ 118 │ │ │
│ 119 │ │ output = self._precision.convert_output(output) │
│ 120 │ │ return output │
│ │
│ /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501 in _call_impl │
│ │
│ 1498 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks │
│ 1499 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks │
│ 1500 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │
│ ❱ 1501 │ │ │ return forward_call(*args, **kwargs) │
│ 1502 │ │ # Do not call functions when jit is used │
│ 1503 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │
│ 1504 │ │ backward_pre_hooks = [] │
│ │
│ /kaggle/working/../input/litllama/litllama/litllama/model.py:114 in forward │
│ │
│ 111 │ │ │ │ │ for _ in range(self.config.n_layer) │
│ 112 │ │ │ │ ] │
│ 113 │ │ │ for i, block in enumerate(self.transformer.h): │
│ ❱ 114 │ │ │ │ x, self.kv_caches[i] = block(x, rope, mask, max_seq_length, input_pos, s │
│ 115 │ │ │
│ 116 │ │ x = self.transformer.ln_f(x) │
│ 117 │
│ │
│ /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501 in _call_impl │
│ │
│ 1498 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks │
│ 1499 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks │
│ 1500 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │
│ ❱ 1501 │ │ │ return forward_call(*args, **kwargs) │
│ 1502 │ │ # Do not call functions when jit is used │
│ 1503 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │
│ 1504 │ │ backward_pre_hooks = [] │
│ │
│ /kaggle/working/../input/litllama/litllama/litllama/model.py:163 in forward │
│ │
│ 160 │ │ input_pos: Optional[torch.Tensor] = None, │
│ 161 │ │ kv_cache: Optional[KVCache] = None, │
│ 162 │ ) -> Tuple[torch.Tensor, Optional[KVCache]]: │
│ ❱ 163 │ │ h, new_kv_cache = self.attn(self.rms_1(x), rope, mask, max_seq_length, input_pos │
│ 164 │ │ x = x + h │
│ 165 │ │ x = x + self.mlp(self.rms_2(x)) │
│ 166 │ │ return x, new_kv_cache │
│ │
│ /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501 in _call_impl │
│ │
│ 1498 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks │
│ 1499 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks │
│ 1500 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │
│ ❱ 1501 │ │ │ return forward_call(*args, **kwargs) │
│ 1502 │ │ # Do not call functions when jit is used │
│ 1503 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │
│ 1504 │ │ backward_pre_hooks = [] │
│ │
│ /kaggle/working/../input/litllama/litllama/litllama/model.py:228 in forward │
│ │
│ 225 │ │ # y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) │
│ 226 │ │ │
│ 227 │ │ # efficient attention using Flash Attention CUDA kernels │
│ ❱ 228 │ │ y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) │
│ 229 │ │ │
│ 230 │ │ y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs │
│ 231 │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: The size of tensor a (155) must match the size of tensor b (145) at non-singleton dimension 3

commented

The size mismatch happens because the KV cache (and attention mask) built for the first prompt's sequence length is reused for the next prompt. Calling the reset_cache() method defined on the model in lit_llama/model.py before each new generation solves the problem:

model.reset_cache()
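
For reference, a minimal sketch of the load-once, prompt-many-times loop, assuming the out() helper and the loading code above are already defined in the notebook, and that the Fabric wrapper forwards reset_cache() to the underlying LLaMA module from lit_llama/model.py:

# Sketch: the model and tokenizer are loaded once (as above) and reused for several prompts.
prompts = [
    ("What is the SI unit of the physical quantity m/Q?",
     "(A)Meter per second(B)Pascal per second(C)Kilogram per coulomb(D)Newton per meter(E)Joule per second"),
    ("What is the SI unit of the physical quantity m/Q?",
     "(A)Meter per second(B)Pascal per second(D)Newton per meter(E)Joule per second"),
]

for instruction, options in prompts:
    # Clear the KV cache that was sized for the previous prompt before generating again
    model.reset_cache()
    print(out(instruction, options, tokenizer, model))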