mfuntowicz / llama2.rs

llama2.rs

This is a one-file Rust implementation of Llama2.

  • Support for 4-bit GPT-Q Quantization
  • SIMD support for fast CPU inference
  • Support for Grouped Query Attention (needed for big Llamas)
  • Memory mapping, loads 70B instantly.
  • Static size checks, no pointers
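To illustrate the first item in the list, here is a minimal sketch of how 4-bit group-quantized weights can be unpacked. It assumes eight 4-bit values packed lowest-nibble-first into one `u32`, with a single scale and zero point per group. The helper names `unpack_word` and `dequantize_group` are hypothetical, and the layout actually produced by export.py and auto-gptq may differ.

```rust
// Mirrors the BITS and GROUPSIZE constants from the Configuration section.
const BITS: usize = 4;
const GROUPSIZE: usize = 128;
// Number of quantized values stored in one 32-bit word (8 for 4-bit).
const PER_WORD: usize = 32 / BITS;
// Bit mask selecting a single quantized value.
const MASK: u32 = (1 << BITS) - 1;

/// Hypothetical helper: split one packed word into its 4-bit values,
/// assuming the lowest nibble holds the first value.
pub fn unpack_word(word: u32) -> [u8; PER_WORD] {
    let mut out = [0u8; PER_WORD];
    for (i, slot) in out.iter_mut().enumerate() {
        *slot = ((word >> (i * BITS)) & MASK) as u8;
    }
    out
}

/// Hypothetical helper: dequantize one group as w = scale * (q - zero).
/// The array length is fixed by the constants, so a wrong-sized group
/// is rejected at compile time.
pub fn dequantize_group(
    packed: &[u32; GROUPSIZE / PER_WORD],
    scale: f32,
    zero: u8,
) -> Vec<f32> {
    packed
        .iter()
        .flat_map(|&w| unpack_word(w))
        .map(|q| scale * (q as f32 - zero as f32))
        .collect()
}
```

With this packing assumption, `unpack_word(0x76543210)` yields the values 0 through 7 in order.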

Can run 70B Llama2 at up to 1 tok/s (Intel i9 with AVX512).

To build (requires +nightly to use SIMD, get with rustup):

> rustup toolchain install nightly # to get nightly
> cargo +nightly build --release

If you get a build error, you may need to change .cargo/config to match your chipset; by default it looks for AVX512.
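If you are unsure what your chipset supports, a small standalone check like the following may help before editing .cargo/config. This is only a sketch: the function name `detected_simd_features` is hypothetical and not part of this repository, and it only probes AVX2 and AVX512F using the standard library's runtime detection macro.

```rust
/// Hypothetical helper: list which of the probed x86 SIMD feature sets
/// the current CPU reports. Returns an empty list on non-x86 targets.
pub fn detected_simd_features() -> Vec<&'static str> {
    #[allow(unused_mut)]
    let mut found = Vec::new();
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx2") {
            found.push("avx2");
        }
        if std::is_x86_feature_detected!("avx512f") {
            found.push("avx512f");
        }
    }
    found
}
```

If `avx512f` is missing from the result, the default AVX512 setting is probably the part of .cargo/config that needs changing.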

To get the model (loads 70B quantized):

pip install torch transformers auto-gptq
python export.py llama2-70b-q.bin

To run:

> target/release/llama2_rs llama2-70b-q.bin 0.0 11 "The only thing"
The only thing that I can think of is that the
achieved tok/s: 0.89155835

Honestly, not so bad for running on my GPU machine, significantly faster than llama.c.

Here's a run of 13B quantized:

> target/release/llama2_rs llama2_7b.bin 0.0 11 "One thing is that"
One thing is that the 100% of the people who are in the 1%
achieved tok/s: 4.027984

Here's a run of 7B non-quantized (this is less optimized):

> target/release/llama2_rs llama2_7b.bin 0.0 11 "The only thing"
The only thing that is certain in life is change.
achieved tok/s: 1.0298662

Configuration

To make the model as fast as possible, its dimensions are fixed at compile time, so you need to compile a new version to adapt to other Llama versions. Currently this is done through a group of constants at the top of the file. The model will fail if these disagree with the binary model that is being loaded.

// Configuration for Llama 70B. Others in config.txt
const DIM: usize = 8192;
const HIDDEN_DIM: usize = 28672;
const ATTN_GROUPS: usize = 8;
const N_LAYERS: usize = 80;
const N_HEADS: usize = 64;
const SEQ_LEN: usize = 2048;
const VOCAB_SIZE: usize = 32000;

// Grouped Query Attention
const KV_DIM: usize = DIM / ATTN_GROUPS;
const N_KV_HEADS: usize = N_HEADS / ATTN_GROUPS;
const HEAD_SIZE: usize = DIM / N_HEADS;

// Turn on GPT-Q Quantization.
type TWeights = QTransformerWeights;
const BITS: usize = 4;
const GROUPSIZE: usize = 128;
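As a sanity check on the grouped query attention constants above, the derived sizes can be verified at compile time. The sketch below copies the 70B values and adds a hypothetical helper, `kv_head_for`, that maps a query head to the key/value head it shares. It assumes consecutive query heads share one KV head, as in common GQA implementations; the indexing used in this repository may differ.

```rust
// Values copied from the 70B configuration above.
const DIM: usize = 8192;
const ATTN_GROUPS: usize = 8;
const N_HEADS: usize = 64;

// Derived sizes, computed the same way as in the configuration.
const KV_DIM: usize = DIM / ATTN_GROUPS; // 1024
const N_KV_HEADS: usize = N_HEADS / ATTN_GROUPS; // 8
const HEAD_SIZE: usize = DIM / N_HEADS; // 128

// Compile-time checks: the build fails if the dimensions do not divide
// evenly or the derived sizes are inconsistent with each other.
const _: () = assert!(DIM % N_HEADS == 0);
const _: () = assert!(N_HEADS % ATTN_GROUPS == 0);
const _: () = assert!(KV_DIM == N_KV_HEADS * HEAD_SIZE);

/// Hypothetical helper: index of the KV head used by a given query head,
/// assuming each run of ATTN_GROUPS consecutive query heads shares one.
pub fn kv_head_for(query_head: usize) -> usize {
    assert!(query_head < N_HEADS, "query head out of range");
    query_head / ATTN_GROUPS
}
```

With the 70B values, query heads 0 through 7 map to KV head 0, and query head 63 maps to KV head 7.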

See Also

Originally a Rust port of Karpathy's llama2.c, this project now has a bunch more features to make it scale to 70B.

How does it work?

It started as a port of the original code, with extra type information added to make it easier to extend.

There are two dependencies:

  • memmap2 for memory mapping
  • rayon for parallel computation

SIMD support is enabled with +nightly.

Why?

Mostly this was an exercise in learning some Rust. Was curious how you port over things like memory mapping, parallel processing, and some of the mathematical tricks.

This is my first Rust project, so if you are an expert I would love a code review!
