Seq len missing in rotary embedding
raganato opened this issue
raganato commented
The forward of RotaryEmbedding lacks a seq len input argument. I think we just need to add it, and then at line 1273 (rotary_pos_emb = self.rotary_pos_emb(pos)) pass it in as x.shape[1].
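For context, a minimal sketch of what that change could look like (this is a paraphrase for illustration, not the verbatim upstream code; the xpos buffer setup and the scale_base default are assumptions):

import torch
from torch import nn

class RotaryEmbedding(nn.Module):
    # sketch: a rotary embedding whose forward receives the sequence length explicitly
    def __init__(self, dim, use_xpos = False, scale_base = 512):
        super().__init__()
        inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        self.use_xpos = use_xpos
        self.scale_base = scale_base
        if use_xpos:
            # assumed per-dimension xpos scale, following the xpos paper
            scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
            self.register_buffer('scale', scale)

    def forward(self, seq_len, device):
        t = torch.arange(seq_len, device = device).type_as(self.inv_freq)
        freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
        freqs = torch.cat((freqs, freqs), dim = -1)

        if not self.use_xpos:
            return freqs, torch.ones(1, device = device)

        # the xpos branch (the line ~450 referenced below) is where seq_len is needed
        power = (t - (seq_len // 2)) / self.scale_base
        scale = self.scale ** power.unsqueeze(-1)
        scale = torch.cat((scale, scale), dim = -1)
        return freqs, scale

The call at line 1273 would then pass the input length, e.g. rotary_pos_emb = self.rotary_pos_emb(x.shape[1], x.device).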
Phil Wang commented
raganato commented
It should break with the following setting, i.e. when xpos is set to True and execution reaches line 450:
https://github.com/lucidrains/x-transformers/blob/main/x_transformers/x_transformers.py#L450
from x_transformers import TransformerWrapper, Decoder
import torch

model = TransformerWrapper(
    num_tokens = 10,
    max_seq_len = 20,
    attn_layers = Decoder(
        dim = 512,
        depth = 2,
        heads = 8,
        rotary_xpos = True  # modified rotary to extrapolate well beyond length at which it was trained
    )
)

x = torch.randint(0, 10, (1, 20))
model(x)  # the forward pass through the xpos branch is what should trigger the error
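Since the seq len is never handed to forward, the xpos branch at line 450 has no sequence length to compute its scaling power from, so the forward pass above should fail on that line.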