Llama2 convert_to_onnx.py does not fuse GQA
turneram commented
Describe the issue
Using onnxruntime.transformers.models.llama.convert_to_onnx with --use_gqa does not produce a model with GroupQueryAttention nodes.
To reproduce
python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp16 --precision fp16 --execution_provider cuda --use_gqa
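The missing fusion can be confirmed by counting op types in the exported graph with the onnx package. A minimal sketch, assuming the export landed in the llama2-7b-fp16 directory passed to --output (the exact .onnx filename is a placeholder; adjust it to whatever convert_to_onnx actually wrote):

```python
from collections import Counter

import onnx

# Placeholder path: point this at the .onnx file convert_to_onnx produced
# under the --output directory (llama2-7b-fp16 in the repro above).
model = onnx.load("llama2-7b-fp16/model.onnx", load_external_data=False)

op_counts = Counter(node.op_type for node in model.graph.node)
print("GroupQueryAttention nodes:", op_counts.get("GroupQueryAttention", 0))
print("MultiHeadAttention nodes:", op_counts.get("MultiHeadAttention", 0))
print(op_counts.most_common(15))
```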
The only fusions reported are SimplifiedLayerNormalization and SkipSimplifiedLayerNormalization, and the final model only contains attention in the un-fused format pictured below:

[screenshot: un-fused attention subgraph]
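For comparison, here is a sketch of the node one would expect after fusion: a com.microsoft.GroupQueryAttention contrib node in place of the decomposed attention subgraph. The input names and attributes below follow my reading of the contrib-op schema and are assumptions, not output from the converter:

```python
from onnx import helper

# Hypothetical fused node, for illustration only; the real node would be
# emitted by convert_to_onnx if --use_gqa were honored.
gqa_node = helper.make_node(
    "GroupQueryAttention",
    inputs=[
        "query", "key", "value",           # projected Q/K/V
        "past_key", "past_value",          # KV cache
        "seqlens_k", "total_sequence_length",
    ],
    outputs=["attn_output", "present_key", "present_value"],
    domain="com.microsoft",
    num_heads=32,      # Llama-2-7b attention heads
    kv_num_heads=32,   # 7b has no KV grouping; the 70b variant would use 8
)
```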
Urgency
No response
Platform
Linux
OS Version
22.04
ONNX Runtime Installation
Built from Source
ONNX Runtime Version or Commit ID
ONNX Runtime API
Python
Architecture
X64
Execution Provider
CUDA
Execution Provider Library Version
CUDA 12.1.1