pytorch-labs / gpt-fast

Simple and efficient pytorch-native transformer text generation in <1000 LOC of python.

Understanding why TorchInductor cannot speed-up huggingface transformer inference

learning-chip opened this issue

Problem

torch.compile() shows an impressive ~2x speed-up for this code repo, but when applied to huggingface transformers there is barely any speed-up. I want to understand why, and then figure out how TorchInductor can also benefit HF models (related issue #9).

Comparing HF's model.generate() vs gpt-fast under the same setting (same prompt, output length, sampling, data type, ...), I found that (on RTX 4090):

  • In eager mode without compile(), HF generate() (39.4 token/s) is faster than gpt-fast (28 token/s)
  • In compiled mode, HF generate() has almost no speed-up (still 39.4 token/s); gpt-fast gets much faster (68.5 token/s)

The blog mentions statically allocating KV cache, but isn't this also implemented in the HF llama model?

Benchmark code

GPT-fast

cd gpt-fast
export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
./scripts/prepare.sh $MODEL_REPO

python generate.py --prompt "Q: What is the largest animal?\nA:"  --max_new_tokens 134 --num_samples 1 --checkpoint_path checkpoints/$MODEL_REPO/model.pth
python generate.py --compile --prompt "Q: What is the largest animal?\nA:" --max_new_tokens 134 --num_samples 1 --checkpoint_path checkpoints/$MODEL_REPO/model.pth

--max_new_tokens 134 is to match HF's output length, as this gpt-fast repo continues to generate text even when hitting the end token </s>.
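
An alternative to hard-coding 134 is to count only the tokens up to and including the first end-of-sequence token, so that both tools are compared on the same effective length. A minimal sketch, assuming the generated ids are available as a plain Python list and that the end token id is 2 (the id of </s> in the Llama 2 tokenizer); count_until_eos is a hypothetical helper, not part of gpt-fast or transformers:

```python
def count_until_eos(token_ids, eos_id=2):
    """Return how many tokens were generated up to and including the first EOS."""
    for i, tok in enumerate(token_ids):
        if tok == eos_id:
            return i + 1
    return len(token_ids)  # EOS never produced: every token counts


# Made-up ids for illustration; generation kept going after the EOS id (2).
generated = [450, 10150, 13019, 2, 450, 10150]
print(count_until_eos(generated))  # 4
```

Dividing this count by the elapsed time would give a throughput figure that does not depend on how far a tool runs past the end token, although the time spent on the extra tokens would still need to be excluded.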

HuggingFace

Run the script below with:

python ./hf_generate.py --compile --do_sample

hf_generate.py:
import time
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import set_seed


def print_separater():
    print("=" * 20, "\n")

def get_model_and_tokenizer(model_path, device, dtype):
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=dtype,
        device_map=device
    )
    model.tokenizer = tokenizer
    return model, tokenizer

def benchmark_throughput(model, model_inputs, args):
    device = model.device
    set_seed(args.seed)

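    # model.device is a torch.device (e.g. cuda:0), so compare its type, not the object.
    if device.type == "cuda":
        torch.cuda.synchronize()
    t0 = time.time()
    greedy_output = model.generate(
        **model_inputs,
        max_new_tokens=args.max_new_tokens,
        do_sample=args.do_sample,
        top_k=args.top_k,
        temperature=args.temperature,
    )
    # Wait for all queued GPU work to finish before stopping the timer.
    if device.type == "cuda":
        torch.cuda.synchronize()
    t1 = time.time()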

    time_elapsed = t1 - t0
    num_tokens = greedy_output.numel() - model_inputs['input_ids'].numel()

    print("Output:\n" + 100 * '-')
    print(model.tokenizer.decode(greedy_output[0], skip_special_tokens=False))

    print("Generated Tokens:", num_tokens)
    print("Time Elasped (s):", time_elapsed)
    throughput = num_tokens / time_elapsed

    return throughput

def main(args):
    print("torch and transformer version:", torch.__version__, transformers.__version__)
    print(torch.__config__.parallel_info())
    print(f"device: {args.device}, dtype: {args.dtype}")
    print(f"model: {args.model_path}")
    print_separater()

    model, tokenizer = get_model_and_tokenizer(args.model_path, args.device, args.dtype)
    model_inputs = tokenizer(args.prompt, return_tensors='pt').to(args.device)

    warm_up_tokens = 20
    set_seed(args.seed)
    warm_up_output = model.generate(**model_inputs, max_new_tokens=warm_up_tokens)

    throughput = benchmark_throughput(model, model_inputs, args)
    print("throughput eager (token/s):", throughput)

    if args.compile:
        t0 = time.time()
        compiled_model = torch.compile(
            model,
            backend=args.dynamo_backend,
            mode=args.dynamo_mode,
            dynamic=None,
            fullgraph=True,
            disable=False
            )
        t1 = time.time()
        print("Compile time (s):", t1 - t0)

        set_seed(args.seed)
        warm_up_output_compiled = compiled_model.generate(
            **model_inputs, max_new_tokens=warm_up_tokens)
        print("Warm-up result agree:", torch.equal(warm_up_output, warm_up_output_compiled))
        print_separater()

        throughput_compiled = benchmark_throughput(compiled_model, model_inputs, args)
        print("throughput compiled (token/s):", throughput_compiled)

        print_separater()
        print("compile speed-up:", throughput_compiled / throughput)

if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(
        description='Benchmark HF generate() throughput with and without torch.compile.')

    parser.add_argument('--device', type=str,
                        default="cuda")
    parser.add_argument('--dtype', type=lambda name: getattr(torch, name),
                        default=torch.float16, help='torch dtype name, e.g. float16 or bfloat16.')
    parser.add_argument('--model_path', type=str,
                        default="meta-llama/Llama-2-7b-chat-hf", help='HF model name or path.')
    parser.add_argument('--prompt', type=str,
                        default="Q: What is the largest animal?\nA:", help='Input prompt.')
    parser.add_argument('--max_new_tokens', type=int,
                        default=256, help='Maximum number of new tokens.')
    parser.add_argument('--do_sample', action='store_true',
                        help='Whether to use sampling. Default is greedy search.')
    parser.add_argument('--top_k', type=int,
                        default=200, help='Top-k for sampling.')
    parser.add_argument('--temperature', type=float,
                        default=0.8, help='Temperature for sampling.')
    parser.add_argument('--compile', action='store_true',
                        help='Whether to compile the model.')
    parser.add_argument('--dynamo_backend', type=str,
                        default="inductor", help='torch._dynamo.list_backends()')
    parser.add_argument('--dynamo_mode', type=str,
                        default="default", help='["default", "reduce-overhead", "max-autotune"]')
    parser.add_argument('--seed', type=int, default=42, help='Random seed.')

    args = parser.parse_args()
    main(args)

The default sampling settings are the same as in this repo's generate.py.

Output results

gpt-fast:

Loading model ...
Time to load model: 6.07 seconds
Q: What is the largest animal?\nA: The largest animal on Earth is the blue whale. On average, an adult blue whale can grow up to 82 feet (25 meters) in length and weigh around 150-170 tons (136,000-152,000 kilograms). However, the largest blue whale ever recorded was a female that was found in 1947 off the coast of Iceland, which measured around 108 feet (33 meters) in length and weighed an estimated 210 tons (182,000 kilograms).
Time for inference 1: 4.78 sec total, 28.02 tokens/sec
Bandwidth achieved: 377.67 GB/s
==========
Average tokens/sec: 28.02
Memory used: 13.59 GB

In eager mode, the output text is identical to Hugging Face's, even though the random seed settings differ from those in the HF script.

Time to load model: 6.26 seconds
Compilation time: 26.94 seconds
Q: What is the largest animal?\nA: The largest animal on Earth is the blue whale. It can grow up to 33 meters (108 feet) in length and weigh up to 180 metric tons (200 tons).t is important to note that the size of a blue whale can vary greatly depending on its age, sex, and other factors. Adult blue whales typically range in length from 18 to 25 meters (59 to 82 feet), with an average length of around 19 meters (62 feet).

Other large animals include:

1. Fin Whale: The fin whale
Time for inference 1: 1.95 sec total, 68.56 tokens/sec
Bandwidth achieved: 923.91 GB/s
==========
Average tokens/sec: 68.56
Memory used: 13.85 GB

With Inductor, the output text is different (I am not sure whether this is due to the random seed or to floating-point differences), although it is still sensible.
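
The gpt-fast numbers above can be cross-checked with a little arithmetic. The sketch below assumes that the "Bandwidth achieved" figure is model size times tokens per second (each decoded token reads every weight once). Under that assumption, both runs should imply the same model size:

```python
# Figures copied from the gpt-fast logs above.
runs = {
    "eager": {"tokens_per_s": 28.02, "bandwidth_gb_s": 377.67},
    "compiled": {"tokens_per_s": 68.56, "bandwidth_gb_s": 923.91},
}


def implied_model_size_gb(tokens_per_s, bandwidth_gb_s):
    # Assumption: bandwidth = model_size * tokens_per_s, so invert it.
    return bandwidth_gb_s / tokens_per_s


sizes = {name: implied_model_size_gb(**run) for name, run in runs.items()}
speedup = runs["compiled"]["tokens_per_s"] / runs["eager"]["tokens_per_s"]

for name, size in sizes.items():
    print(f"{name}: implied model size {size:.2f} GB")
print(f"gpt-fast compile speed-up: {speedup:.2f}x")
```

Both runs imply about 13.48 GB of weights, which is consistent with a roughly 7B-parameter model stored at 2 bytes per parameter. The higher bandwidth of the compiled run therefore reflects faster decoding, not more data moved per token.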

Huggingface:

Output:
----------------------------------------------------------------------------------------------------
<s> Q: What is the largest animal?
A: The largest animal on Earth is the blue whale. On average, an adult blue whale can grow up to 82 feet (25 meters) in length and weigh around 150-170 tons (136,000-152,000 kilograms). However, the largest blue whale ever recorded was a female that was found in 1947 off the coast of Iceland, which measured around 108 feet (33 meters) in length and weighed an estimated 210 tons (182,000 kilograms).</s>
Generated Tokens: 134
Time Elasped (s): 3.39901065826416
throughput eager (token/s): 39.42323619203725
Compile time (s): 0.0032820701599121094
Warm-up result agree: True
==================== 

Output:
----------------------------------------------------------------------------------------------------
<s> Q: What is the largest animal?
A: The largest animal on Earth is the blue whale. On average, an adult blue whale can grow up to 82 feet (25 meters) in length and weigh around 150-170 tons (136,000-152,000 kilograms). However, the largest blue whale ever recorded was a female that was found in 1947 off the coast of Iceland, which measured around 108 feet (33 meters) in length and weighed an estimated 210 tons (182,000 kilograms).</s>
Generated Tokens: 134
Time Elasped (s): 3.404815673828125
throughput compiled (token/s): 39.356021834021995
==================== 

compile speed-up: 0.9982950573187892

Environment

  • torch-2.3.0.dev20231217+cu121
  • transformers-4.36.1
  • tokenizers-0.15.0
  • accelerate-0.25.0

Torch installed by

pip install --upgrade --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121

which grabs https://download.pytorch.org/whl/nightly/cu121/torch-2.3.0.dev20231217%2Bcu121-cp310-cp310-linux_x86_64.whl

Similar results with torch 2.1.2+cu121 #46 (comment)

I think HF llama does not have a static KV cache, since its cache grows dynamically during generation. Here is the relevant code: https://github.com/huggingface/transformers/blob/38611086d293ea4a5809bcd7fadd8081d55cb74e/src/transformers/models/llama/modeling_llama.py#L1014C37-L1014C37
However, I have the same question: why does compile hardly accelerate the HF model? Is it because the model's input size differs at each generation step, which results in frequent recompilation?
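
Whether changing input shapes trigger recompilation can be checked directly. The sketch below is an illustration on a toy function, not the HF model: it passes a custom backend to torch.compile that only counts how often Dynamo hands it a graph. counting_backend and step are hypothetical names, and dynamic=False is set deliberately to force shape specialization. With the default dynamic=None, Dynamo normally switches to dynamic shapes after the first shape change instead of recompiling every time.

```python
import torch
import torch._dynamo

compile_calls = 0


def counting_backend(gm, example_inputs):
    """Custom Dynamo backend: count invocations, then run the captured graph as-is."""
    global compile_calls
    compile_calls += 1
    return gm.forward


def step(x):
    return x * 2 + 1


torch._dynamo.reset()
compiled_step = torch.compile(step, backend=counting_backend, dynamic=False)
calls_after_wrap = compile_calls  # torch.compile is lazy: nothing is compiled yet

# Growing inputs, like a KV cache that gets one entry longer per generated token.
for length in (4, 5, 6):
    compiled_step(torch.ones(length))
calls_growing = compile_calls

# Fixed-size inputs, like a preallocated cache: the existing graph is reused.
for _ in range(3):
    compiled_step(torch.ones(6))
calls_fixed = compile_calls - calls_growing

print(calls_after_wrap, calls_growing, calls_fixed)
```

The count taken right after wrapping also explains the "Compile time (s): 0.003" line in the HF log: torch.compile() only wraps the model, and the actual compilation is deferred to the first call.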

Yes! Static KV cache is not supported but coming soon!
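
The difference between the two cache styles can be sketched in a few lines. This is a simplified illustration with made-up sizes, not the transformers or gpt-fast implementation: the dynamic variant grows the cache with torch.cat, so its shape changes on every step, while the static variant preallocates max_seq_len slots and writes each new entry in place.

```python
import torch

num_heads, head_dim, max_seq_len = 2, 4, 8  # toy sizes, chosen for illustration


def dynamic_append(cache, new_k):
    # Cache shape is (num_heads, seq_len, head_dim); seq_len grows by one each step.
    return torch.cat([cache, new_k], dim=1)


def static_write(cache, new_k, pos):
    # Cache shape stays (num_heads, max_seq_len, head_dim); only slot `pos` changes.
    cache[:, pos : pos + 1, :] = new_k
    return cache


dynamic_cache = torch.zeros(num_heads, 0, head_dim)
static_cache = torch.zeros(num_heads, max_seq_len, head_dim)
dynamic_shapes, static_shapes = set(), set()

for pos in range(3):
    new_k = torch.full((num_heads, 1, head_dim), float(pos + 1))
    dynamic_cache = dynamic_append(dynamic_cache, new_k)
    static_cache = static_write(static_cache, new_k, pos)
    dynamic_shapes.add(tuple(dynamic_cache.shape))
    static_shapes.add(tuple(static_cache.shape))

# Number of distinct cache shapes seen by each variant over three decoding steps.
print(len(dynamic_shapes), len(static_shapes))
```

A static cache also needs an attention mask or position index so that the unused slots are ignored; that part is omitted here.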

@learning-chip @ArthurZucker
Hi both, I am comparing HF with gpt-fast as well and cannot get the same pass@1 score. When using greedy decoding, I cannot get exactly the same predictions from both APIs. I have submitted an issue (#94). Could you provide some pointers? I am stuck. Thanks, Yao Fehlis (yao.fehlis@amd.com)

Closing, since the core issue in huggingface was a dynamic KV cache, which has since been made static.