austinsilveria / fstattention


Fused Sparse Tree Attention

Memory bandwidth efficient sparse tree attention

  • (precompute) chunk the tree into query blocks
  • (precompute) compute unique ancestors, attention mask, and leaves for each block (see the sketch after this list)
  • (runtime) only load keys and values for the query block's unique ancestors and leaves
  • (runtime) go fast
  • A100 Colab Benchmark
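
A minimal sketch of the first two precompute steps above (ancestor sets and the attention mask; per-block leaf bookkeeping is omitted), assuming the tree is stored as a parent-pointer array with nodes indexed in sequence order. The function names, block layout, and mask format here are illustrative, not the repo's actual kernel inputs:

```python
import numpy as np

def ancestors_of(node, parent):
    """Walk parent pointers up to the root; parent[root] == -1."""
    out, p = [], parent[node]
    while p != -1:
        out.append(p)
        p = parent[p]
    return out

def precompute_block_inputs(parent, block_size):
    """For each query block: the unique ancestor KV indices it must load,
    plus a boolean mask of which (query, key) pairs may attend."""
    n = len(parent)
    blocks = []
    for start in range(0, n, block_size):
        qs = list(range(start, min(start + block_size, n)))
        per_query = [set(ancestors_of(q, parent)) for q in qs]
        # Unique ancestors outside the block -- the only extra KVs the kernel
        # loads besides the block's own keys.
        unique = sorted(set().union(*per_query) - set(qs))
        keys = unique + qs
        # A query may attend to a key iff the key is one of its ancestors or itself.
        mask = np.zeros((len(qs), len(keys)), dtype=bool)
        for i, q in enumerate(qs):
            allowed = per_query[i] | {q}
            for j, k in enumerate(keys):
                mask[i, j] = k in allowed
        blocks.append({"queries": qs, "kv_indices": keys, "mask": mask})
    return blocks

# Example: a root with two children, each with two children (7 nodes).
parent = [-1, 0, 0, 1, 1, 2, 2]
for block in precompute_block_inputs(parent, block_size=4):
    print(block["queries"], block["kv_indices"])
```

At runtime the kernel would then gather K/V rows only at each block's kv_indices, so memory traffic scales with the block's shared lineage rather than the full tree.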

go forth, search the tree of possible futures


Notes on precomputation:

  • The precomputation can probably be made fast enough to run at inference time with a bit more work: for a dynamic tree structure (i.e. one that depends on the model's output), these kernel inputs only need to be computed once per tree, and they are then reused by every attention layer in the model
  • Static tree structures are still useful: Medusa uses a size-256 static, left-weighted tree, populated via Cartesian products of its multiple top-k output heads, to accelerate batch-size-1 inference by ~3x (see the sketch below)
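
For intuition, a hypothetical sketch of how such a static tree could be populated from Cartesian products of per-head top-k candidates, reusing the parent-pointer layout from the sketch above; this illustrates the idea and is not Medusa's actual construction (which keeps only a fixed subset of high-probability prefixes):

```python
from itertools import product

def tree_from_paths(paths):
    """Merge candidate paths into a parent-pointer tree, sharing common prefixes."""
    parent = [-1]                 # node 0 is the root (the current token)
    node_of_prefix = {(): 0}
    for path in paths:
        for depth in range(1, len(path) + 1):
            prefix = path[:depth]
            if prefix not in node_of_prefix:
                parent.append(node_of_prefix[path[:depth - 1]])
                node_of_prefix[prefix] = len(parent) - 1
    return parent

# Two heads with top-2 candidates each -> 4 candidate paths, 7 tree nodes.
paths = list(product(["a1", "a2"], ["b1", "b2"]))
print(tree_from_paths(paths))  # [-1, 0, 1, 1, 0, 4, 4]
```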

Todo:

  • Organize blocks based on DFS ordering to minimize the number of blocks that need to load the same ancestor KVs (i.e. maximize the shared lineage of each block); see the sketch below
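
A hypothetical sketch of that reordering, again over a parent-pointer tree: laying the nodes out depth-first keeps each query block on a shared lineage, so neighboring blocks reload fewer distinct ancestor KVs.

```python
def dfs_order(parent):
    """Return node indices of a parent-pointer tree in depth-first order."""
    children = {i: [] for i in range(len(parent))}
    root = 0
    for node, p in enumerate(parent):
        if p == -1:
            root = node
        else:
            children[p].append(node)
    order, stack = [], [root]
    while stack:
        node = stack.pop()
        order.append(node)
        stack.extend(reversed(children[node]))  # visit children left-to-right
    return order

# Example: a 7-node binary tree stored breadth-first; DFS groups each branch.
print(dfs_order([-1, 0, 0, 1, 1, 2, 2]))  # [0, 1, 3, 4, 2, 5, 6]
```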

Credits:
