lucidrains / DALLE-pytorch

Implementation / replication of DALL-E, OpenAI's Text to Image Transformer, in Pytorch

rotary embedding values

sklin93 opened this issue

The stated reason for using the value -10 here is that "image axial positions have range [-1, 1]". In the RotaryEmbedding class, however, given
img_axial_pos_emb = RotaryEmbedding(dim = rot_dim, freqs_for = 'pixel')
img_freqs_axial = img_axial_pos_emb(torch.linspace(-1, 1, steps = image_fmap_size))

freqs (which has its default value here) and t (which ranges over [-1, 1]) are combined with an einsum, and the results are not in the range [-1, 1]. For example, with head_dim=64 (i.e. rot_dim=21), img_freqs ranges from about -15.7 to 15.7. I think this makes -10 an unsuitable choice?
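To make the numbers concrete, here is a minimal pure-Python sketch of the calculation. It assumes that freqs_for = 'pixel' builds its frequencies as linspace(1, max_freq / 2, dim // 2) * pi with a default max_freq of 10, which is my reading of RotaryEmbedding and may not match every version. The helper names (linspace, pixel_freqs, rotary_angles) and image_fmap_size = 32 are made up for illustration.

```python
import math

def linspace(start, stop, steps):
    """Evenly spaced values from start to stop inclusive (mimics torch.linspace)."""
    if steps == 1:
        return [start]
    step = (stop - start) / (steps - 1)
    return [start + i * step for i in range(steps)]

def pixel_freqs(dim, max_freq=10.0):
    # ASSUMPTION: 'pixel' frequencies are linspace(1, max_freq / 2, dim // 2) * pi
    return [f * math.pi for f in linspace(1.0, max_freq / 2, dim // 2)]

def rotary_angles(positions, freqs):
    # Outer product t_i * f_j, i.e. what the einsum over t and freqs produces
    return [[t * f for f in freqs] for t in positions]

rot_dim = 64 // 3         # head_dim = 64 gives rot_dim = 21
image_fmap_size = 32      # hypothetical feature-map size, for illustration only

freqs = pixel_freqs(rot_dim)
positions = linspace(-1.0, 1.0, image_fmap_size)
angles = rotary_angles(positions, freqs)

lo = min(min(row) for row in angles)
hi = max(max(row) for row in angles)
print(f"min angle = {lo:.2f}, max angle = {hi:.2f}")
# prints: min angle = -15.71, max angle = 15.71
```

Under that assumption the largest frequency is 5 * pi, so positions in [-1, 1] map to angles in roughly [-15.7, 15.7], which matches the range I reported above.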