lucidrains / x-transformers

A simple but complete full-attention transformer with a set of promising experimental features from various papers

Seq len missing in rotary embedding

raganato opened this issue

def forward(self, t):

The forward of RotaryEmbedding lacks the seq_len input argument. I think we just need to add it, and then at line 1273 (rotary_pos_emb = self.rotary_pos_emb(pos)) pass it in as x.shape[1].
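
For illustration, a rough, self-contained sketch of the kind of change I mean (not the actual x-transformers code; the class internals, scale formula, and call site below are simplified assumptions): the xpos branch needs the sequence length to center its decay, so forward takes an extra seq_len argument and the caller passes x.shape[1].

    import torch
    from torch import nn

    class RotaryEmbedding(nn.Module):
        # illustrative only; the real class in x_transformers.py has more options
        def __init__(self, dim, scale_base = 512):
            super().__init__()
            inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
            self.register_buffer('inv_freq', inv_freq)

            # per-dimension xpos decay factors (assumed formula)
            scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
            self.register_buffer('scale', scale)
            self.scale_base = scale_base

        def forward(self, t, seq_len = None):  # seq_len added, as suggested
            # t: 1-D tensor of positions, e.g. torch.arange(seq_len)
            freqs = torch.einsum('i,j->ij', t.type_as(self.inv_freq), self.inv_freq)
            freqs = torch.cat((freqs, freqs), dim = -1)

            # the xpos scale is centered on the middle of the sequence,
            # which is exactly where the missing seq_len is needed
            power = (t - (seq_len // 2)) / self.scale_base
            scale = self.scale ** power.unsqueeze(-1)
            scale = torch.cat((scale, scale), dim = -1)
            return freqs, scale

    # at the call site (line 1273), something along the lines of:
    # rotary_pos_emb = self.rotary_pos_emb(pos, seq_len = x.shape[1])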

It should break with the following setting, i.e. when rotary_xpos is set to True and execution reaches line 450:
https://github.com/lucidrains/x-transformers/blob/main/x_transformers/x_transformers.py#L450

model = TransformerWrapper(
    num_tokens = 10,
    max_seq_len = 20,
    attn_layers = Decoder(
        dim = 512,
        depth = 2,
        heads = 8,
        rotary_xpos = True,   # modified rotary to extrapolate well beyond length at which it was trained
    )
)
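
For completeness, a forward pass like the following should reach the failing branch (assuming the usual TransformerWrapper call with a (batch, seq) tensor of token ids, as in the README):

    import torch

    x = torch.randint(0, 10, (1, 20))  # token ids matching num_tokens = 10, max_seq_len = 20 above
    logits = model(x)                  # with rotary_xpos = True this goes through the xpos branch of the rotary embedding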

@raganato oops, you are correct

should be fixed, thank you!