microsoft / onnxruntime

ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator

Home Page: https://onnxruntime.ai

LLama2 convert_to_onnx.py does not fuse GQA

turneram opened this issue · comments

Describe the issue

Using onnxruntime.transformers.models.llama.convert_to_onnx with --use_gqa does not produce a model with GroupQueryAttention nodes.

To reproduce

python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp16 --precision fp16 --execution_provider cuda --use_gqa
The only fusions reported are SimplifiedLayerNormalization and SkipSimplifiedLayerNormalization, and the final model contains attention only in the un-fused form pictured below:
(screenshot: un-fused attention subgraph from the exported model)

Urgency

No response

Platform

Linux

OS Version

22.04

ONNX Runtime Installation

Built from Source

ONNX Runtime Version or Commit ID

030a961

ONNX Runtime API

Python

Architecture

X64

Execution Provider

CUDA

Execution Provider Library Version

CUDA 12.1.1