Compositional Attention - Pytorch

Implementation of Compositional Attention from MILA. They reframe the "heads" of multi-head attention as "searches", and once the multi-headed / searched values are aggregated, an extra retrieval step (also using attention) is taken over the searched results. They then show that this variant of attention yields better out-of-distribution (OOD) results on a toy task. Their ESBN results still leave a lot to be desired, but I like the general direction of the paper.
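
For intuition, below is a minimal, self-contained sketch of the two-step idea in plain PyTorch. It is an illustrative re-implementation under my own assumptions (class and projection names are hypothetical), not the code in this repository: search heads compute ordinary attention maps, every search is applied to every retrieval's values, and a second attention, per search and per query position, softly decides which retrieval to keep.

import torch
from torch import nn, einsum

class CompositionalAttentionSketch(nn.Module):
    # illustrative only - names and details here are assumptions, not this repo's exact code
    def __init__(self, dim, dim_head = 64, num_searches = 8, num_retrievals = 2):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.num_searches = num_searches
        self.num_retrievals = num_retrievals
        self.dim_head = dim_head

        self.to_q  = nn.Linear(dim, dim_head * num_searches, bias = False)    # search queries
        self.to_k  = nn.Linear(dim, dim_head * num_searches, bias = False)    # search keys
        self.to_v  = nn.Linear(dim, dim_head * num_retrievals, bias = False)  # retrieval values
        self.to_rq = nn.Linear(dim, dim_head * num_searches, bias = False)    # retrieval queries, one per search
        self.to_rk = nn.Linear(dim_head, dim_head, bias = False)              # retrieval keys from retrieved values
        self.to_out = nn.Linear(dim_head * num_searches, dim, bias = False)

    def forward(self, x):
        b, n, _ = x.shape
        s, r, d = self.num_searches, self.num_retrievals, self.dim_head

        # step 1: "search" - standard attention maps, one per search head
        q = self.to_q(x).view(b, n, s, d).transpose(1, 2)                     # (b, s, n, d)
        k = self.to_k(x).view(b, n, s, d).transpose(1, 2)                     # (b, s, n, d)
        v = self.to_v(x).view(b, n, r, d).transpose(1, 2)                     # (b, r, n, d)

        attn = (einsum('b s i d, b s j d -> b s i j', q, k) * self.scale).softmax(dim = -1)

        # every search is applied to every retrieval's values
        retrieved = einsum('b s i j, b r j d -> b s r i d', attn, v)          # (b, s, r, n, d)

        # step 2: "retrieval" - each search, at each position, softly picks one of the retrievals
        rq = self.to_rq(x).view(b, n, s, d).transpose(1, 2)                   # (b, s, n, d)
        rk = self.to_rk(retrieved)                                            # (b, s, r, n, d)

        retrieval_attn = (einsum('b s i d, b s r i d -> b s r i', rq, rk) * self.scale).softmax(dim = 2)

        out = einsum('b s r i, b s r i d -> b s i d', retrieval_attn, retrieved)
        out = out.transpose(1, 2).reshape(b, n, s * d)                        # concatenate searches
        return self.to_out(out)

sketch = CompositionalAttentionSketch(dim = 1024)
sketch(torch.randn(1, 512, 1024)).shape  # torch.Size([1, 512, 1024])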

Install

$ pip install compositional-attention-pytorch

Usage

import torch
from compositional_attention_pytorch import CompositionalAttention

attn = CompositionalAttention(
    dim = 1024,            # input dimension
    dim_head = 64,         # dimension per attention 'head' - head is now either search or retrieval
    num_searches = 8,      # number of searches
    num_retrievals = 2,    # number of retrievals
    dropout = 0.,          # dropout for the search and retrieval attention
)

tokens = torch.randn(1, 512, 1024)  # input tokens - (batch, seq, dim)
mask = torch.ones((1, 512)).bool()  # boolean mask - (batch, seq)

out = attn(tokens, mask = mask) # (1, 512, 1024)
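
One plausible way to drop this into a pre-norm transformer layer is sketched below. It assumes only the constructor arguments and forward signature shown above; the Block wrapper itself is hypothetical and not part of this package.

import torch
from torch import nn
from compositional_attention_pytorch import CompositionalAttention

class Block(nn.Module):
    # hypothetical wrapper, not provided by this package
    def __init__(self, dim = 1024):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.attn = CompositionalAttention(
            dim = dim,
            dim_head = 64,
            num_searches = 8,
            num_retrievals = 2
        )

    def forward(self, x, mask = None):
        return x + self.attn(self.norm(x), mask = mask)  # pre-norm residual

block = Block(dim = 1024)
tokens = torch.randn(1, 512, 1024)
mask = torch.ones((1, 512)).bool()
out = block(tokens, mask = mask)  # (1, 512, 1024)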

Citations

@article{Mittal2021CompositionalAD,
    title   = {Compositional Attention: Disentangling Search and Retrieval},
    author  = {Sarthak Mittal and Sharath Chandra Raparthy and Irina Rish and Yoshua Bengio and Guillaume Lajoie},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2110.09419}
}
