[Bug] Error when `rotary_pos_emb` set to True in cross attention
BakerBunker opened this issue
BakerBunker commented
import torch
from x_transformers import Encoder, CrossAttender

enc = Encoder(dim=512, depth=6)
model = CrossAttender(
    dim=512,
    depth=6,
    rotary_pos_emb=True,
    attn_flash=True,
)

nodes = torch.randn(1, 1, 512)
node_masks = torch.ones(1, 1).bool()
neighbors = torch.randn(1, 5, 512)
neighbor_masks = torch.ones(1, 5).bool()

encoded_neighbors = enc(neighbors, mask=neighbor_masks)
model(
    nodes, context=encoded_neighbors, mask=node_masks, context_mask=neighbor_masks
)  # (1, 1, 512)
Phil Wang commented
hmm, are the source and target sequences in some shared coordinate space? usually you cannot use rotary embeddings in cross attention
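For context on why this is usually disallowed: rotary embeddings (RoPE) rotate queries and keys by an angle proportional to their position, so the attention dot product depends only on the relative offset m − n between the two positions. That offset is only meaningful when queries and keys live in the same coordinate space; in cross attention the query sequence and the context sequence are indexed independently, so the offset carries no information. A minimal self-contained sketch of the relative-position property (not x_transformers internals):

```python
import torch

def rotate_half(x):
    # split the last dim in two halves and rotate: (x1, x2) -> (-x2, x1)
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary(x, pos, dim):
    # standard RoPE frequencies: theta_i = pos / 10000^(2i/dim)
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
    freqs = pos * inv_freq
    emb = torch.cat((freqs, freqs), dim=-1)
    return x * emb.cos() + rotate_half(x) * emb.sin()

dim = 8
q = torch.randn(dim)
k = torch.randn(dim)

# the dot product depends only on the relative offset (m - n):
# positions (3, 1) and (10, 8) both have offset 2, so the scores match
a = apply_rotary(q, torch.tensor(3.0), dim) @ apply_rotary(k, torch.tensor(1.0), dim)
b = apply_rotary(q, torch.tensor(10.0), dim) @ apply_rotary(k, torch.tensor(8.0), dim)
assert torch.allclose(a, b, atol=1e-4)
```

Since a query at position 3 and a context key at position 1 have no shared origin in cross attention, that offset of 2 is arbitrary, which is why rotary only makes sense when both sequences share a coordinate space.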
BakerBunker commented
Thank you for the explanation, it was my mistake to use rotary embeddings in cross attention
Phil Wang commented
@BakerBunker no problem, i should have added an assert to prevent this in the cross attention setting
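The guard mentioned above could look something like the following. This is a hypothetical sketch — the class and flag names do not match x_transformers internals — showing a fail-fast assertion at construction time rather than a confusing error deep in the forward pass:

```python
class SketchAttention:
    """Hypothetical sketch of a constructor-time guard.

    Names here (rotary_pos_emb, cross_attend) are illustrative and do
    not necessarily match the real x_transformers attention module.
    """

    def __init__(self, dim, rotary_pos_emb=False, cross_attend=False):
        # fail fast: rotary embeddings assume queries and keys share a
        # coordinate space, which does not hold for cross attention
        assert not (rotary_pos_emb and cross_attend), (
            'rotary positional embeddings are incompatible with cross '
            'attention: query and context positions are in different '
            'coordinate spaces'
        )
        self.dim = dim
        self.rotary_pos_emb = rotary_pos_emb
        self.cross_attend = cross_attend


# valid configuration constructs fine
attn = SketchAttention(512, rotary_pos_emb=True)

# invalid combination raises immediately with a descriptive message
try:
    SketchAttention(512, rotary_pos_emb=True, cross_attend=True)
except AssertionError as e:
    print(f'rejected: {e}')
```

Raising at construction time surfaces the misconfiguration at the line where the user wrote it, instead of mid-training.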