lucidrains / bidirectional-cross-attention

A simple cross attention that updates both the source and target in one step

Bidirectional Cross Attention

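A simple cross attention that updates both the source and target in one step. The key insight is that one can use shared query / key attention: the similarity matrix between the two sequences is computed only once, and that attention matrix is used twice, normalized over one axis to update the source and over the other axis to update the target. It was used in a contracting project for predicting DNA / protein binding.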
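To make the shared query / key idea concrete, here is a minimal single-head NumPy sketch. It assumes one shared query/key projection and one value projection per sequence, plus an output projection back to each sequence's own feature dimension. It is not the library's implementation: head splitting, masking and the exact projection layout in `BidirectionalCrossAttention` may differ, and the names `bidirectional_cross_attention`, `softmax` and the `params` keys are hypothetical.

```python
import numpy as np

def softmax(x, axis):
    # numerically stable softmax along the given axis
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def bidirectional_cross_attention(x, ctx, params):
    """Single-head sketch. x: (n, dim), ctx: (m, ctx_dim)."""
    # each sequence gets a shared query/key projection and a value projection
    qk = x @ params["w_qk"]            # (n, d)
    v = x @ params["w_v"]              # (n, d)
    ctx_qk = ctx @ params["wc_qk"]     # (m, d)
    ctx_v = ctx @ params["wc_v"]       # (m, d)

    # one similarity matrix, computed once
    d = qk.shape[-1]
    sim = qk @ ctx_qk.T / np.sqrt(d)   # (n, m)

    # normalize over context positions: each source token attends to the context
    attn = softmax(sim, axis=1)
    # normalize over source positions: each context token attends to the source
    ctx_attn = softmax(sim, axis=0)

    out = attn @ ctx_v                 # (n, d)
    ctx_out = ctx_attn.T @ v           # (m, d)

    # project back to each sequence's own feature dimension
    return out @ params["w_out"], ctx_out @ params["wc_out"]

rng = np.random.default_rng(0)
dim, ctx_dim, d = 16, 12, 8
params = {
    "w_qk": rng.standard_normal((dim, d)),
    "w_v": rng.standard_normal((dim, d)),
    "wc_qk": rng.standard_normal((ctx_dim, d)),
    "wc_v": rng.standard_normal((ctx_dim, d)),
    "w_out": rng.standard_normal((d, dim)),
    "wc_out": rng.standard_normal((d, ctx_dim)),
}
x = rng.standard_normal((5, dim))
ctx = rng.standard_normal((7, ctx_dim))
x_out, ctx_out = bidirectional_cross_attention(x, ctx, params)
print(x_out.shape, ctx_out.shape)  # (5, 16) (7, 12)
```

The point of the design is that `sim` is computed once; only the softmax axis changes between the two directions, which is what lets both sequences be updated in a single step.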

Install

$ pip install bidirectional-cross-attention

Usage

import torch
from bidirectional_cross_attention import BidirectionalCrossAttention

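# inputs are (batch, sequence length, feature dimension)
video = torch.randn(1, 4096, 512)
audio = torch.randn(1, 8192, 386)

# boolean masks over sequence positions (all True here, so nothing is masked out)
video_mask = torch.ones((1, 4096)).bool()
audio_mask = torch.ones((1, 8192)).bool()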

joint_cross_attn = BidirectionalCrossAttention(
    dim = 512,
    heads = 8,
    dim_head = 64,
    context_dim = 386
)

video_out, audio_out = joint_cross_attn(
    video,
    audio,
    mask = video_mask,
    context_mask = audio_mask
)

# attended output should have the same shape as input

assert video_out.shape == video.shape
assert audio_out.shape == audio.shape
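The usage above passes `mask` and `context_mask` alongside the two sequences. Because both directions share one similarity matrix, one natural way to honour both padding masks is to combine them into a pairwise mask and apply it once, before either softmax. The NumPy sketch below assumes exactly that (an outer AND of the two masks, True meaning keep, and a large negative fill value); it illustrates the idea rather than reproducing the library's internals, and `masked_bidirectional_attn` and `MASK_VALUE` are hypothetical names.

```python
import numpy as np

# hypothetical large negative stand-in for -inf, so fully masked rows/columns stay finite
MASK_VALUE = -1e9

def softmax(z, axis):
    # numerically stable softmax along the given axis
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def masked_bidirectional_attn(sim, mask, context_mask):
    """sim: (n, m) shared similarity; mask: (n,) bool; context_mask: (m,) bool (True = keep)."""
    # a (source, context) pair is kept only if both positions are valid
    pair_mask = mask[:, None] & context_mask[None, :]   # (n, m)
    sim = np.where(pair_mask, sim, MASK_VALUE)
    # the same masked matrix feeds both directions
    attn = softmax(sim, axis=1)      # source -> context, rows sum to 1
    ctx_attn = softmax(sim, axis=0)  # context -> source, columns sum to 1
    return attn, ctx_attn

rng = np.random.default_rng(0)
sim = rng.standard_normal((3, 4))
mask = np.array([True, True, False])                # last source position is padding
context_mask = np.array([True, True, True, False])  # last context position is padding

attn, ctx_attn = masked_bidirectional_attn(sim, mask, context_mask)
# valid source tokens put no weight on the padded context token (both entries are 0)
print(attn[:2, 3])
# valid context tokens put no weight on the padded source token (all entries are 0)
print(ctx_attn[2, :3])
```

Rows and columns that are entirely padding come out uniform rather than NaN under this fill value; their outputs belong to padded positions and would normally be discarded.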

Todo

  • allow for cosine sim attention

Citations

@article{Hiller2024PerceivingLS,
    title   = {Perceiving Longer Sequences With Bi-Directional Cross-Attention Transformers},
    author  = {Markus Hiller and Krista A. Ehinger and Tom Drummond},
    journal = {ArXiv},
    year    = {2024},
    volume  = {abs/2402.12138},
    url     = {https://api.semanticscholar.org/CorpusID:267751060}
}

About

License: MIT License