[BUG] [DOCS] XPOS
evelynmitchell opened this issue · comments
evelynmitchell commented
There's a dimensional error in the 3rd example of Xpos:
Code
import torch
from zeta import fixed_pos_embedding, apply_rotary_pos_emb
# Generate fixed positional embeddings
scale = torch.randn(10, 256)
sin, cos = fixed_pos_embedding(scale)
# Apply rotary positional embeddings to an input tensor
x = torch.randn(1, 10, 256)
output = apply_rotary_pos_emb(x, sin, cos, scale=0.5)
RuntimeError Traceback (most recent call last)
[<ipython-input-18-4d63cb090aa8>](https://localhost:8080/#) in <cell line: 10>()
8 # Apply rotary positional embeddings to an input tensor
9 x = torch.randn(1, 10, 256)
---> 10 output = apply_rotary_pos_emb(x, sin, cos, scale=0.5)
[/usr/local/lib/python3.10/dist-packages/zeta/nn/embeddings/xpos_relative_position.py](https://localhost:8080/#) in apply_rotary_pos_emb(x, sin, cos, scale)
69 """
70 sin, cos = map(lambda t: duplicate_interleave(t * scale), (sin, cos))
---> 71 return (x * cos) + (rotate_every_two(x) * sin)
72
73
RuntimeError: The size of tensor a (256) must match the size of tensor b (512) at non-singleton dimension 2
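The 512 in the error comes from `duplicate_interleave` (visible on line 70 of the traceback), which repeats each column of `sin`/`cos` twice, interleaved, and therefore doubles the last dimension: a (10, 256) scale yields sin/cos that become (10, 512) and cannot broadcast against `x` of shape (1, 10, 256). A minimal sketch of the shape mismatch, with `duplicate_interleave` written as in the torchscale reference implementation (assumed to match zeta's, since the traceback lines are identical):

```python
import torch

def duplicate_interleave(m):
    # (seq, d) -> (seq, 2*d): each column is repeated twice, interleaved
    dim0 = m.shape[0]
    return m.view(-1, 1).repeat(1, 2).view(dim0, -1)

sin = torch.randn(10, 256)  # shape of sin/cos in the docs example
doubled = duplicate_interleave(sin * 0.5)
print(doubled.shape)  # torch.Size([10, 512])
# (1, 10, 256) * (10, 512) -> the broadcast RuntimeError above
```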
The implementation in the original paper contains this error. It was copied into hf/transformers and later fixed there:
huggingface/transformers@052fa2f
https://github.com/huggingface/transformers/blob/edb170238febf7fc3e3278ed5b9ca0b2c40c70e3/src/transformers/models/gptj/modeling_flax_gptj.py#L122
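For reference, here is a self-contained sketch of how the call becomes dimensionally consistent: the scale tensor passed to `fixed_pos_embedding` needs HALF the embedding dim of `x`, because `duplicate_interleave` doubles it. The helper bodies below follow the torchscale reference implementation and are assumptions about zeta's internals (only `duplicate_interleave` and the `apply_rotary_pos_emb` body are confirmed by the traceback):

```python
import torch

# Assumed reference-style XPos helpers (torchscale-like; names from the traceback)
def fixed_pos_embedding(x):
    seq_len, dim = x.shape
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim).float() / dim))
    sinusoid = torch.einsum("i,j->ij", torch.arange(seq_len).float(), inv_freq)
    return torch.sin(sinusoid), torch.cos(sinusoid)

def duplicate_interleave(m):
    # (seq, d) -> (seq, 2*d)
    dim0 = m.shape[0]
    return m.view(-1, 1).repeat(1, 2).view(dim0, -1)

def rotate_every_two(x):
    # (-x2, x1, -x4, x3, ...): pairwise rotation along the last dim
    x1, x2 = x[..., ::2], x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)

def apply_rotary_pos_emb(x, sin, cos, scale=1):
    # body matches line 70-71 of the traceback
    sin, cos = map(lambda t: duplicate_interleave(t * scale), (sin, cos))
    return (x * cos) + (rotate_every_two(x) * sin)

# Dimensionally consistent version of the docs example:
scale = torch.randn(10, 128)           # half of 256, not 256
sin, cos = fixed_pos_embedding(scale)  # each (10, 128)
x = torch.randn(1, 10, 256)
out = apply_rotary_pos_emb(x, sin, cos, scale=0.5)
print(out.shape)  # torch.Size([1, 10, 256])
```

With a half-dim scale, sin/cos come out as (10, 128), `duplicate_interleave` brings them to (10, 256), and the broadcast against `x` succeeds.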
Kye Gomez commented
@evelynmitchell the link you posted is Flax, not PyTorch; the two implementations are very different.