gfvvz / FlagGems

FlagGems is an operator library for large language models implemented in Triton Language.

Chinese Version

Introduction

FlagGems is a high-performance general operator library implemented in OpenAI Triton. It aims to provide a suite of kernel functions to accelerate LLM training and inference.

By registering with PyTorch's ATen backend, FlagGems enables a seamless transition: users can switch to the Triton operator library without modifying their model code, continuing to use the ATen interface as usual while gaining a significant performance improvement. The Triton language is readable and easy to use while delivering performance comparable to CUDA, so developers can contribute to FlagGems with minimal learning investment.
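One way to perform this kind of registration from Python is torch.library. The sketch below is a simplified illustration of the general mechanism, not FlagGems' actual code; my_triton_mm is a hypothetical placeholder for a real Triton kernel.

    import torch

    # Simplified sketch: reroute aten::mm on the CUDA dispatch key to a custom
    # Python function. FlagGems performs registrations of this kind, but with
    # real Triton kernels behind them.
    lib = torch.library.Library("aten", "IMPL")

    def my_triton_mm(a, b):
        # Hypothetical placeholder: a real implementation would launch a Triton
        # kernel and return the matrix product of a and b.
        return torch.empty((a.shape[0], b.shape[1]), dtype=a.dtype, device=a.device)

    lib.impl("mm", my_triton_mm, "CUDA")  # torch.mm on CUDA now dispatches here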

Changelog

v1.0

  • support BLAS operators: addmm, bmm, mm
  • support pointwise operators: abs, add, div, dropout, exp, gelu, mul, pow, reciprocal, relu, rsqrt, silu, sub, triu
  • support reduction operators: cumsum, layernorm, mean, softmax

v2.0

  • support BLAS operators: mv, outer
  • support pointwise operators: bitwise_and, bitwise_not, bitwise_or, cos, clamp, eq, ge, gt, isinf, isnan, le, lt, ne, neg, or, sin, tanh, sigmoid
  • support reduction operators: all, any, amax, argmax, max, min, prod, sum, var_mean, vector_norm, cross_entropy_loss, group_norm, log_softmax, rms_norm
  • support fused operators: skip_rms_norm, skip_layer_norm, gelu_and_mul, silu_and_mul, apply_rotary_position_embedding

Quick Start

Requirements

  1. Triton >= 2.2.0
  2. PyTorch >= 2.1.2
  3. Transformers >= 4.40.2

Installation

git clone https://github.com/FlagOpen/FlagGems.git
cd FlagGems
pip install .

Usage

Import

  1. Enable permanently

    import flag_gems
    flag_gems.enable()
  2. Enable temporarily

    import flag_gems
    with flag_gems.use_gems():
        pass
  3. Example

    import torch
    import flag_gems
    
    M, N, K = 1024, 1024, 1024
    A = torch.randn((M, K), dtype=torch.float16, device="cuda")
    B = torch.randn((K, N), dtype=torch.float16, device="cuda")
    with flag_gems.use_gems():
        C = torch.mm(A, B)
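
    As an optional sanity check (illustrative, not part of the original example), the Triton-backed result can be compared with the default ATen backend computed outside the use_gems() context:

    # Reference result from the default ATen backend (outside use_gems()).
    C_ref = torch.mm(A, B)
    # float16 matmul, so allow a small numerical tolerance.
    print(torch.allclose(C, C_ref, atol=1e-3, rtol=1e-3))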

Execute

  1. Test Operator Accuracy

    • Run reference on cuda
      cd tests
      pytest test_xx_ops.py
    • Run reference on cpu
      cd tests
      pytest test_xx_ops.py --device cpu
  2. Test Model Accuracy

    cd examples
    pytest model_xx_test.py
  3. Test Operator Performance

    • Test CUDA performance
      cd benchmark
      pytest test_xx_perf.py -s
    • Test end-to-end performance
      cd benchmark
      pytest test_xx_perf.py -s --mode cpu
  4. Run tests with logging information

    pytest program.py --log-cli-level debug

    This is not recommended for performance testing.

Supported Operators

Operators will be implemented according to OperatorList.md.

Supported Models

  • Bert-base-uncased
  • Llama-2-7b

Supported Platforms

Platform      float16   float32   bfloat16
Nvidia A100   ✓         ✓         ✓

Contributions

If you are interested in contributing to the FlagGems project, please refer to CONTRIBUTING.md. Any contributions would be highly appreciated.

Contact us

If you have any questions about our project, please submit an issue, or contact us at flaggems@baai.ac.cn.

License

The FlagGems project is licensed under the Apache License 2.0.
