lucidrains / infini-transformer-pytorch

Implementation of Infini-Transformer in PyTorch

Infini-Transformer - PyTorch

Implementation of Infini-Transformer in PyTorch. The paper's authors use a linear attention scheme to compress the keys and values of past segments into a fixed-size memory, and report state-of-the-art results on several long-context benchmarks.
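As a rough illustration of that compressive memory, here is a NumPy sketch that follows the equations in the Infini-attention paper rather than this library's code. Each head keeps a memory matrix M of shape (d_k, d_v) and a normalization vector z. A segment is written with M <- M + sigma(K)^T V and z <- z + sum_t sigma(K_t), and read back with sigma(Q) M / (sigma(Q) z), where sigma is ELU + 1. The function names and the small `eps` in the denominator are my own assumptions (`eps` guards against 0 / 0 on an empty memory). The paper then mixes the memory readout with ordinary local attention through a learned sigmoid gate, sketched here as `gated_output`.

```python
import numpy as np

def elu_plus_one(x):
    # sigma(x) = ELU(x) + 1, keeps keys and queries strictly positive
    return np.where(x > 0, x + 1.0, np.exp(x))

def retrieve(q, mem, z, eps=1e-6):
    # A_mem = sigma(Q) M / (sigma(Q) z); eps is an assumption to avoid 0 / 0
    sq = elu_plus_one(q)                          # (n, d_k)
    return (sq @ mem) / (sq @ z + eps)[:, None]   # (n, d_v)

def update_linear(k, v, mem, z):
    # M_s = M_{s-1} + sigma(K)^T V ;  z_s = z_{s-1} + sum_t sigma(K_t)
    sk = elu_plus_one(k)                          # (n, d_k)
    return mem + sk.T @ v, z + sk.sum(axis=0)

def gated_output(a_mem, a_local, beta):
    # A = sigmoid(beta) * A_mem + (1 - sigmoid(beta)) * A_local
    g = 1.0 / (1.0 + np.exp(-beta))
    return g * a_mem + (1.0 - g) * a_local

rng = np.random.default_rng(0)
d_k = d_v = 4
mem, z = np.zeros((d_k, d_v)), np.zeros(d_k)
for _ in range(10):                               # ten segments of 8 tokens each
    k = rng.standard_normal((8, d_k))
    v = rng.standard_normal((8, d_v))
    mem, z = update_linear(k, v, mem, z)

q = rng.standard_normal((3, d_k))
out = retrieve(q, mem, z)
print(mem.shape, out.shape)  # (4, 4) (3, 4)
```

However many segments are written, the memory stays (d_k, d_v), so its cost is independent of context length; it is also why a larger head dimension gives the memory more capacity.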

Although it is unlikely to beat Ring Attention, I think it is worth exploring, as the two techniques are orthogonal.

Yannic Kilcher's explanation

Install

$ pip install infini-transformer-pytorch

Usage

import torch
from infini_transformer_pytorch import InfiniTransformer

transformer = InfiniTransformer(
    num_tokens = 256,
    dim = 512,
    depth = 8,
    dim_head = 128,  # a high head dimension may be part of the reason the paper got good results (each head's kv memory is dim_head x dim_head, so it has high capacity)
    heads = 8,
    use_mem_delta_rule = True
)

x = torch.randint(0, 256, (1, 1024))

logits1, _, mem1 = transformer(x, return_new_memories = False)
logits2, _, mem2 = transformer(x, past_memories = mem1, return_new_memories = False)
logits3, _, mem3 = transformer(x, past_memories = mem2, return_new_memories = True)
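The usage example above sets `use_mem_delta_rule = True`. In the paper's delta-rule variant, the memory first reads what it already returns for the incoming keys and then writes only the residual: M <- M + sigma(K)^T (V - sigma(K) M / (sigma(K) z)), with z updated as before. The NumPy sketch below follows the paper's equations, not this library's internals; the function names and `eps` are assumptions. It writes the same key-value pair twice. The plain linear update doubles the stored association, while the delta rule leaves it almost unchanged because the second write finds the value already stored.

```python
import numpy as np

def elu_plus_one(x):
    # sigma(x) = ELU(x) + 1
    return np.where(x > 0, x + 1.0, np.exp(x))

def update_linear(k, v, mem, z):
    # plain linear-attention write: M += sigma(K)^T V
    sk = elu_plus_one(k)
    return mem + sk.T @ v, z + sk.sum(axis=0)

def update_delta(k, v, mem, z, eps=1e-6):
    # delta rule: read what memory already returns for these keys,
    # then write only the residual V - retrieved (eps is an assumption)
    sk = elu_plus_one(k)
    retrieved = (sk @ mem) / (sk @ z + eps)[:, None]
    return mem + sk.T @ (v - retrieved), z + sk.sum(axis=0)

rng = np.random.default_rng(0)
k = rng.standard_normal((1, 4))
v = rng.standard_normal((1, 4))
empty_mem, empty_z = np.zeros((4, 4)), np.zeros(4)

lin1 = update_linear(k, v, empty_mem, empty_z)
lin2 = update_linear(k, v, *lin1)
delta1 = update_delta(k, v, empty_mem, empty_z)
delta2 = update_delta(k, v, *delta1)

# the linear rule doubles the stored association; the delta rule barely changes it
print(np.allclose(lin2[0], 2 * lin1[0]))             # True
print(np.allclose(delta2[0], delta1[0], atol=1e-3))  # True
```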

Training a transformer with recurrence usually trips up a lot of researchers, so to make it easy, just wrap the model with InfiniTransformerWrapper, which segments the sequence and carries the memories across segments for you.

import torch

from infini_transformer_pytorch import (
    InfiniTransformer,
    InfiniTransformerWrapper
)

# model and wrapper

model = InfiniTransformer(
    num_tokens = 256,
    dim = 512,
    depth = 8,
    dim_head = 128,
    heads = 8,
    use_mem_delta_rule = True
)

wrapper = InfiniTransformerWrapper(
    model,
    segment_length = 512,
    detach_mems_every_num_segments = 2 # greater than 1 so the network can learn how to 'write' to the fast weight memories
).cuda()

# mock input

seq = torch.randint(0, 256, (2, 10000)).cuda() # can be an arbitrarily long sequence

# training

loss = wrapper(
    seq,
    backward = True # will automatically segment and accumulate gradients when it detaches the memories
)

# after much data...

# calculating eval loss

with torch.no_grad():
    wrapper.eval()
    eval_loss = wrapper(seq)

# generating is as easy as

output = wrapper.generate(seq_len = 8192, prompt = seq[:, :1])

output.shape # (2, 8192 - 1)
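To make `segment_length` and `detach_mems_every_num_segments` more concrete, here is a pure-Python sketch of one plausible schedule. It is an assumption about how such a wrapper could work, not the wrapper's actual code, and `plan_segments` is a hypothetical helper. The sequence is cut into consecutive segments, and consecutive segments are grouped into windows of `detach_every`. Within a window the memories would stay attached, so the loss on a later segment can backpropagate into how an earlier segment wrote to memory; at each window boundary the memories would be detached and the gradients accumulated. With a value of 1, every memory write would be cut off from the loss that depends on it, which is why the example above uses a value greater than 1. The real wrapper may also shift inputs and targets by one token, so its exact segment count can differ.

```python
def plan_segments(seq_len, segment_length, detach_every):
    # hypothetical helper, not part of infini-transformer-pytorch
    # 1. cut [0, seq_len) into consecutive (start, end) segments; the last may be shorter
    segments = [
        (start, min(start + segment_length, seq_len))
        for start in range(0, seq_len, segment_length)
    ]
    # 2. group segments into windows; under this assumed schedule, memories would
    #    keep their gradient within a window and be detached at each window boundary
    windows = [
        segments[i:i + detach_every]
        for i in range(0, len(segments), detach_every)
    ]
    return segments, windows

segments, windows = plan_segments(10000, 512, 2)
print(len(segments), segments[-1])                  # 20 (9728, 10000)
print(len(windows), [len(w) for w in windows[:3]])  # 10 [2, 2, 2]
```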

Testing

Train an autoregressive model on enwik8

$ python train.py

Todo

  • detach_mems_every_num_segments hyperparameter is too confusing, get rid of it
  • experiment with enhanced recurrence, perhaps with a linear projection (talking heads on kv or linear projection on k, v separately) before sending the memories to the layer before
  • working example with enwik8

Citations

@inproceedings{Munkhdalai2024LeaveNC,
    title   = {Leave No Context Behind: Efficient Infinite Context Transformers with Infini-attention},
    author  = {Tsendsuren Munkhdalai and Manaal Faruqui and Siddharth Gopal},
    year    = {2024},
    url     = {https://api.semanticscholar.org/CorpusID:269033427}
}

License

MIT License

