Implementation of memory-efficient multi-head attention, as proposed in the paper [Self-attention Does Not Need O(n²) Memory](https://arxiv.org/abs/2112.05682). In addition, the module takes care of masking, causal masking, and cross attention.
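The core idea of the paper, sketched below for a single attention head: instead of materializing the full `n × n` attention matrix, keys and values are processed in buckets while a running max and running softmax denominator are carried along, yielding exact attention in sub-quadratic memory. This is a minimal illustrative sketch, not this library's internals; `chunked_attention` and its argument names are hypothetical.

```python
import torch

def chunked_attention(q, k, v, k_bucket_size = 1024):
    # q, k, v: (batch, seq, dim_head), a single attention head (hypothetical helper)
    scale = q.shape[-1] ** -0.5
    q = q * scale

    acc     = torch.zeros_like(q)                               # running weighted sum of values
    row_sum = q.new_zeros(*q.shape[:-1], 1)                     # running softmax denominator
    row_max = q.new_full((*q.shape[:-1], 1), float('-inf'))     # running max, for numerical stability

    for kc, vc in zip(k.split(k_bucket_size, dim = 1), v.split(k_bucket_size, dim = 1)):
        scores = q @ kc.transpose(-1, -2)                       # (batch, seq, bucket), never (seq, seq)

        new_max    = torch.maximum(row_max, scores.amax(dim = -1, keepdim = True))
        exp_scores = (scores - new_max).exp()
        correction = (row_max - new_max).exp()                  # rescale previously accumulated buckets

        acc     = acc * correction + exp_scores @ vc
        row_sum = row_sum * correction + exp_scores.sum(dim = -1, keepdim = True)
        row_max = new_max

    return acc / row_sum

# sanity check against naive attention
q, k, v = (torch.randn(1, 4096, 64) for _ in range(3))
naive = ((q * 64 ** -0.5) @ k.transpose(-1, -2)).softmax(dim = -1) @ v
assert torch.allclose(chunked_attention(q, k, v, k_bucket_size = 512), naive, atol = 1e-5)
```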
## Install

```bash
$ pip install memory-efficient-attention-pytorch
```
## Usage

For an autoregressive language model
```python
import torch
from memory_efficient_attention_pytorch import Attention

attn = Attention(
    dim = 512,
    dim_head = 64,              # dimension per head
    heads = 8,                  # number of attention heads
    causal = True,              # autoregressive or not
    memory_efficient = True,    # whether to use memory efficient attention (can be turned off to test against normal attention)
    q_bucket_size = 1024,       # bucket size along queries dimension
    k_bucket_size = 2048        # bucket size along key / values dimension
).cuda()

x = torch.randn(1, 16384, 512).cuda()
out = attn(x) # (1, 16384, 512)
```
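The bucket sizes trade peak memory for parallelism: smaller `q_bucket_size` / `k_bucket_size` shrink the attention chunks held in memory at once, at the cost of more sequential iterations, so the defaults above are just a starting point worth tuning for your sequence length and hardware.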
Cross attention
```python
import torch
from memory_efficient_attention_pytorch import Attention

cross_attn = Attention(
    dim = 512,
    dim_head = 64,
    heads = 8,
    memory_efficient = True,
    q_bucket_size = 1024,
    k_bucket_size = 2048
).cuda()

x = torch.randn(1, 16384, 512).cuda()
context = torch.randn(1, 16384, 512).cuda()     # keys / values are derived from the context
mask = torch.ones(1, 16384).bool().cuda()       # boolean mask over context positions

out = cross_attn(x, context = context, mask = mask) # (1, 16384, 512)
```
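In this example the mask flags which context positions may be attended to (`True` = attend), so padded positions in `context` can be excluded. Note that `causal` is left off here, as cross attention is typically non-causal.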
## Todo

- add enwik8 example with 8192 context length
- offer a version of memory-efficient attention without the numerical stability calculations, for use in conjunction with the cosine-sim attention from the SwinV2 paper
- benchmark and see how much torch.jit helps
## Citations

```bibtex
@misc{rabe2021selfattention,
    title         = {Self-attention Does Not Need $O(n^2)$ Memory},
    author        = {Markus N. Rabe and Charles Staats},
    year          = {2021},
    eprint        = {2112.05682},
    archivePrefix = {arXiv},
    primaryClass  = {cs.LG}
}
```