Lightning-AI / lit-llama

Implementation of the LLaMA language model based on nanoGPT. Supports flash attention, Int8 and GPTQ 4bit quantization, LoRA and LLaMA-Adapter fine-tuning, pre-training. Apache 2.0-licensed.

reset_cache() Decreases the Generation Quality of Consecutive Inferences

HenryPengZou opened this issue

When generating for multiple consecutive inputs with a LoRA fine-tuned LLaMA, I noticed that calling 'reset_cache' after each generation degrades the quality of the generation for the next input. If I instead reload the model after each generation, the quality stays good, but reloading takes a lot of time. Could you help explain why 'reset_cache' decreases the generation quality on the subsequent inputs?

Code: I modified 'generate/lora.py' to enable consecutive generation on multiple inputs. Basically, I just added a for loop and a call to model.reset_cache():

  # Support generation for multiple consecutive inputs.
  # prompt and input are lists of equal length; sys and time come from the
  # existing imports at the top of generate/lora.py.
  outputs = []
  num_samples = len(input)
  for i in range(num_samples):
      sample = {"instruction": prompt[i], "input": input[i]}
      # use a separate name so the prompt list is not overwritten inside the loop
      full_prompt = generate_prompt(sample)
      encoded = tokenizer.encode(full_prompt, bos=True, eos=False, device=model.device)

      t0 = time.perf_counter()
      output = generate(
          model,
          idx=encoded,
          max_new_tokens=max_new_tokens,
          temperature=temperature,
          top_k=top_k,
          eos_id=tokenizer.eos_id,
      )
      t = time.perf_counter() - t0

      # clear the KV caches before the next input (the call in question)
      model.reset_cache()
      output = tokenizer.decode(output)
      output = output.split("### Response:")[1].strip()
      print(output)
      print(f"Time for inference: {t:.02f} sec total, {max_new_tokens / t:.02f} tokens/sec", file=sys.stderr)
      outputs.append(output)
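
For comparison, the slow-but-correct variant I mentioned looks roughly like this: reloading the checkpoint before every input instead of calling reset_cache(). This is only a minimal sketch; load_model() is a hypothetical helper standing in for the checkpoint/LoRA loading code already in generate/lora.py, not a real lit-llama function.

  # Hypothetical reload-per-input loop for comparison.
  # load_model() is a placeholder for the existing checkpoint + LoRA
  # loading logic in generate/lora.py; it is not part of the script.
  outputs = []
  for i in range(len(input)):
      model = load_model()  # reload weights instead of calling reset_cache()
      sample = {"instruction": prompt[i], "input": input[i]}
      full_prompt = generate_prompt(sample)
      encoded = tokenizer.encode(full_prompt, bos=True, eos=False, device=model.device)

      output = generate(
          model,
          idx=encoded,
          max_new_tokens=max_new_tokens,
          temperature=temperature,
          top_k=top_k,
          eos_id=tokenizer.eos_id,
      )
      output = tokenizer.decode(output)
      outputs.append(output.split("### Response:")[1].strip())

With this version the outputs stay good across inputs, but reloading the weights for every sample dominates the runtime, which is why I would prefer to keep reusing the same model object and only reset the caches.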