Memformer - Pytorch (wip)

Implementation of Memformer, a Memory-augmented Transformer, in Pytorch. It includes memory slots, which are updated with attention and learned efficiently through Memory-Replay Backpropagation (MRBP) through time. The paper's other contribution is a simplified relative positional encoding that performs better with fewer parameters and less compute.
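As a toy illustration of the memory-write idea (a sketch only, not this library's internals): the memory slots act as attention queries over the encoder's hidden states, and the attention readout becomes the new memory. Shapes follow the Usage example below.

import torch

batch, num_slots, seq_len, dim = 1, 128, 1024, 512

mems    = torch.randn(batch, num_slots, dim)  # current memory slots
enc_out = torch.randn(batch, seq_len, dim)    # encoder hidden states

# single-head attention: queries = memory slots, keys / values = encoder outputs
attn     = torch.einsum('bmd,bnd->bmn', mems, enc_out) * (dim ** -0.5)
new_mems = torch.einsum('bmn,bnd->bmd', attn.softmax(dim = -1), enc_out)  # (1, 128, 512)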

Install

$ pip install memformer

Usage

import torch
from memformer import Memformer

model = Memformer(
    num_tokens = 256,        # vocabulary size
    dim = 512,               # model dimension
    depth = 2,               # number of transformer layers
    max_seq_len = 1024,      # maximum segment length
    num_memory_slots = 128,  # number of memory slots carried between segments
    num_mem_updates = 2      # memory update steps per forward pass
)

x1 = torch.randint(0, 256, (1, 1024))  # source tokens, segment 1
y1 = torch.randint(0, 256, (1, 1024))  # target tokens, segment 1

x2 = torch.randint(0, 256, (1, 1024))  # source tokens, segment 2
y2 = torch.randint(0, 256, (1, 1024))  # target tokens, segment 2

tgt_out1, mems1 = model(x1, y1)               # (1, 1024, 512), (1, 128, 512)
tgt_out2, mems2 = model(x2, y2, mems = mems1) # feed the updated memory into the next segment
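Since each forward pass returns the updated memory, a long sequence can be trained as a stream of segments, carrying the memory across calls. Below is a minimal training sketch under that assumption; it uses plain truncated backpropagation through time (detaching the memory between segments) rather than the paper's full MRBP, and because tgt_out are hidden states of size dim (per the shape comment above), a hypothetical to_logits head is added for the loss.

import torch
import torch.nn.functional as F
from memformer import Memformer

model = Memformer(
    num_tokens = 256,
    dim = 512,
    depth = 2,
    max_seq_len = 1024,
    num_memory_slots = 128,
    num_mem_updates = 2
)

to_logits = torch.nn.Linear(512, 256)  # hypothetical projection from dim to vocabulary logits

optim = torch.optim.Adam([*model.parameters(), *to_logits.parameters()], lr = 3e-4)

# hypothetical stream of (source, target) segments from one long sequence
segments = [
    (torch.randint(0, 256, (1, 1024)), torch.randint(0, 256, (1, 1024)))
    for _ in range(4)
]

mems = None
for src, tgt in segments:
    if mems is None:
        tgt_out, mems = model(src, tgt)
    else:
        tgt_out, mems = model(src, tgt, mems = mems)

    loss = F.cross_entropy(to_logits(tgt_out).transpose(1, 2), tgt)
    loss.backward()
    optim.step()
    optim.zero_grad()

    mems = mems.detach()  # truncate the graph so gradients stay within one segment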

Citations

@inproceedings{anonymous2021memformer,
    title     = {Memformer: The Memory-Augmented Transformer},
    author    = {Anonymous},
    booktitle = {Submitted to International Conference on Learning Representations},
    year      = {2021},
    url       = {https://openreview.net/forum?id=_adSMszz_g9},
    note      = {under review}
}

License

MIT

