reset_cache() Decreases the Generation Quality of Consecutive Inferences
HenryPengZou opened this issue
When running generation on multiple consecutive inputs with a LoRA fine-tuned LLaMA, I noticed that calling reset_cache() after each generation degrades the quality of the generation on the next input. If I instead reload the model after each generation, the quality stays good, but reloading takes a lot of time. Could you help explain why reset_cache() decreases the quality of generation on the following inputs?
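A minimal way I would isolate the effect (this sketch reuses model, tokenizer, generate, and generate_prompt exactly as they are set up in generate/lora.py; the sample text is just a placeholder):

# Repro sketch: generate the same prompt twice, once on the fresh model and once
# after reset_cache(). top_k=1 makes decoding effectively greedy, so both
# completions should be identical if reset_cache() fully restores the state.
sample = {"instruction": "Summarize the input.", "input": "placeholder text"}
encoded = tokenizer.encode(generate_prompt(sample), bos=True, eos=False, device=model.device)

first = tokenizer.decode(generate(model, idx=encoded, max_new_tokens=64, top_k=1))
model.reset_cache()
second = tokenizer.decode(generate(model, idx=encoded, max_new_tokens=64, top_k=1))

print(first == second)  # False here would confirm reset_cache() changes the outcome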
Code: I modified generate/lora.py to support consecutive generation on multiple inputs. Essentially, I just added a for loop and a call to model.reset_cache():
# support multiple inferences: loop over all inputs instead of generating once
outputs = []
num_samples = len(input)
for i in range(num_samples):
    sample = {"instruction": prompt[i], "input": input[i]}
    # use a separate name so the original `prompt` list is not overwritten
    full_prompt = generate_prompt(sample)
    encoded = tokenizer.encode(full_prompt, bos=True, eos=False, device=model.device)
    t0 = time.perf_counter()
    output = generate(
        model,
        idx=encoded,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_k=top_k,
        eos_id=tokenizer.eos_id,
    )
    t = time.perf_counter() - t0
    # clear the KV cache before moving on to the next input
    model.reset_cache()
    output = tokenizer.decode(output)
    output = output.split("### Response:")[1].strip()
    print(output)
    print(f"Time for inference: {t:.02f} sec total, {max_new_tokens / t:.02f} tokens/sec", file=sys.stderr)
    outputs.append(output)
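For comparison, the variant that keeps quality good but is slow reloads the model inside the loop. load_lora_model() below is a hypothetical stand-in for the checkpoint-loading code that already exists at the top of generate/lora.py, not a real API:

# Reloading per input preserves quality but is slow. `load_lora_model` is a
# hypothetical placeholder for the existing setup logic in generate/lora.py
# (checkpoint loading, LoRA configuration, moving the model to the device, etc.).
outputs = []
for i in range(len(input)):
    model = load_lora_model()  # full reload every iteration: correct but expensive
    sample = {"instruction": prompt[i], "input": input[i]}
    encoded = tokenizer.encode(generate_prompt(sample), bos=True, eos=False, device=model.device)
    output = generate(model, idx=encoded, max_new_tokens=max_new_tokens,
                      temperature=temperature, top_k=top_k, eos_id=tokenizer.eos_id)
    outputs.append(tokenizer.decode(output).split("### Response:")[1].strip())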