lucidrains / x-transformers

A simple but complete full-attention transformer with a set of promising experimental features from various papers


[Bug] Error when `rotary_pos_emb` set to True in cross attention

BakerBunker opened this issue

import torch
from x_transformers import Encoder, CrossAttender

enc = Encoder(dim=512, depth=6)
# setting rotary_pos_emb=True here triggers the reported error
model = CrossAttender(
    dim=512,
    depth=6,
    rotary_pos_emb=True,
    attn_flash=True,
)

nodes = torch.randn(1, 1, 512)
node_masks = torch.ones(1, 1).bool()

neighbors = torch.randn(1, 5, 512)
neighbor_masks = torch.ones(1, 5).bool()

encoded_neighbors = enc(neighbors, mask=neighbor_masks)
model(
    nodes, context=encoded_neighbors, mask=node_masks, context_mask=neighbor_masks
)  # (1, 1, 512)

hmm, are the source and target sequences in some shared coordinate space? usually you cannot use rotary embeddings in cross attention
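For reference, a minimal working variant of the snippet above, keeping rotary embeddings in the self-attention Encoder (where query and key positions index the same sequence) and simply omitting rotary_pos_emb from the CrossAttender. This is a sketch based on the reproduction code, not an officially endorsed configuration.

import torch
from x_transformers import Encoder, CrossAttender

# rotary embeddings are fine here: queries and keys attend within the same sequence
enc = Encoder(dim=512, depth=6, rotary_pos_emb=True)

# no rotary_pos_emb: query positions (nodes) and context positions (neighbors)
# do not share a coordinate space, so relative rotations are meaningless
model = CrossAttender(dim=512, depth=6, attn_flash=True)

nodes = torch.randn(1, 1, 512)
node_masks = torch.ones(1, 1).bool()

neighbors = torch.randn(1, 5, 512)
neighbor_masks = torch.ones(1, 5).bool()

encoded_neighbors = enc(neighbors, mask=neighbor_masks)
out = model(
    nodes, context=encoded_neighbors, mask=node_masks, context_mask=neighbor_masks
)  # (1, 1, 512)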

Thank you for the explanation, it was my mistake to use rotary embeddings in cross attention

@BakerBunker no problem, i should have added an assert to prevent this in the cross attention setting
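A hypothetical sketch of what such a guard might look like; the function and argument names below are illustrative assumptions, not the actual x-transformers internals.

# hypothetical guard sketch -- names are illustrative, not x-transformers internals
def validate_attention_config(rotary_pos_emb: bool, only_cross: bool) -> None:
    # rotary embeddings encode relative offsets between query and key positions,
    # which only makes sense when both index into the same sequence; in a
    # cross-attention-only stack the queries and the context have unrelated positions
    assert not (rotary_pos_emb and only_cross), (
        'rotary positional embeddings cannot be used in a cross-attention-only setting, '
        'since query and context positions are not in a shared coordinate space'
    )

validate_attention_config(rotary_pos_emb=True, only_cross=True)  # raises AssertionError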