mlfoundations / open_lm

A repository for research on medium-sized language models.

Problem in position embedding

jmercat opened this issue

queries, keys, vals = self.pos_embed(queries, keys, vals)

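It seems to me that the rotary position embedding is being applied along the head dimension (dim -2) of q and k instead of along the sequence dimension (dim 1). Here q and k are laid out as (batch, seq, heads, head_dim), while the xformers rotary embedding treats dim -2 as the sequence axis.
I think the head and sequence dimensions should be swapped before calling the position embedding, and swapped back afterwards.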
(see https://github.com/facebookresearch/xformers/blob/748c159096d4f9fcfe3eaf22801e5aed4777210b/xformers/components/positional_embedding/rotary.py#L85)
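To make the consequence concrete, here is a minimal sketch in plain PyTorch. The helper `apply_rotary_dim_minus2` is hypothetical (it is not the xformers API); it only mimics the convention of reading positions from dim -2 and pairs dimension i with i + d/2. With the (batch, seq, heads, head_dim) layout, every sequence position receives the same rotation, so the embedding carries no positional information.

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dim into halves (x1, x2) and return (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_dim_minus2(x: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """Hypothetical minimal rotary embedding that treats dim -2 as position."""
    n_pos, dim = x.shape[-2], x.shape[-1]
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    pos = torch.arange(n_pos, dtype=torch.float32)
    freqs = torch.outer(pos, inv_freq)       # (n_pos, dim / 2)
    emb = torch.cat((freqs, freqs), dim=-1)  # (n_pos, dim)
    return x * emb.cos() + rotate_half(x) * emb.sin()


batch, seq, heads, head_dim = 1, 5, 2, 8
# Identical query vector at every sequence position and head.
q = torch.ones(batch, seq, heads, head_dim)

# Buggy: layout (batch, seq, heads, head_dim), so dim -2 is heads.
buggy = apply_rotary_dim_minus2(q)
# Fixed: move seq to dim -2, rotate, then move it back.
fixed = apply_rotary_dim_minus2(q.permute(0, 2, 1, 3)).permute(0, 2, 1, 3)

print(torch.allclose(buggy[:, 0], buggy[:, 1]))  # True: positions 0 and 1 look identical
print(torch.allclose(fixed[:, 0], fixed[:, 1]))  # False: positions are distinguished
```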

What I'm proposing is simply to rewrite RotaryWithCast as follows:

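from xformers.components.positional_embedding.rotary import RotaryEmbedding

class RotaryWithCast(RotaryEmbedding):
    def forward(self, q, k, v):
        # q, k arrive as (batch, seq, heads, head_dim); the xformers rotary
        # embedding reads positions from dim -2, so move seq there first.
        q, k = super().forward(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3))
        # Swap back to (batch, seq, heads, head_dim) for the attention call.
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        return q.to(v.dtype), k.to(v.dtype), v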
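One way to sanity-check the rewrite without xformers installed is to swap in a stand-in base class. `ToyRotary` below is hypothetical: it assumes the parent's `forward(q, k)` rotates along dim -2 and returns `(q, k)`, which is how the thread describes the xformers class. The check relies on the defining property of rotary embeddings: once positions run along the sequence axis, the score between positions m and n depends only on m - n.

```python
import torch


class ToyRotary(torch.nn.Module):
    """Hypothetical stand-in for RotaryEmbedding: rotates along dim -2."""

    def __init__(self, dim: int, base: float = 10000.0):
        super().__init__()
        self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))

    def _rotate(self, x):
        freqs = torch.outer(torch.arange(x.shape[-2]).float(), self.inv_freq)
        cos, sin = freqs.cos().repeat(1, 2), freqs.sin().repeat(1, 2)
        x1, x2 = x.chunk(2, dim=-1)
        return x * cos + torch.cat((-x2, x1), dim=-1) * sin

    def forward(self, q, k):
        return self._rotate(q), self._rotate(k)


class RotaryWithCast(ToyRotary):
    # Same body as the proposed fix, built on the toy base class.
    def forward(self, q, k, v):
        q, k = super().forward(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3))
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        return q.to(v.dtype), k.to(v.dtype), v


torch.manual_seed(0)
batch, seq, heads, head_dim = 1, 6, 2, 8
# Same q and k content at every position, so scores differ only via rotation.
q = torch.randn(head_dim).expand(batch, seq, heads, head_dim)
k = torch.randn(head_dim).expand(batch, seq, heads, head_dim)
# float32 keeps the comparison free of low-precision rounding.
v = torch.zeros(batch, seq, heads, head_dim)

q_rot, k_rot, _ = RotaryWithCast(head_dim)(q, k, v)

# scores[m, n] for head 0 should depend only on the offset m - n.
scores = q_rot[0, :, 0] @ k_rot[0, :, 0].T
print(q_rot.shape, q_rot.dtype)  # torch.Size([1, 6, 2, 8]) torch.float32
print(torch.allclose(scores[3, 1], scores[4, 2], atol=1e-4))  # same offset -2
```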

Here are the runs I made on a custom subset of the StarCoder data. The original 11M training run is in brown. My implementation, which uses a different positional encoding (including the proposed fix), is in orange.
[Screenshot from 2023-08-31 10-22-09: training curves for the two runs described above]

Good catch! The blow-up curves you are seeing are similar to the ones we were seeing before we introduced qk norm for the smaller models. Will do some testing with this fix on my end as well. Would you like to open a PR?

Wow, amazing catch! We really appreciate this.

We've added your name to the README because this is a very substantial bug catch. It's interesting that our first 1B/7B runs do pretty well even without proper position embeddings, but we should fix this going forward.

Great codebase, by the way. It's a pleasure to read.
Thanks for offering to include me. I could open a PR, but it's probably simpler for you to just include what I wrote (or a better version; I haven't tested whether calling .contiguous() on the permuted tensors would make a difference).
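On the .contiguous() question: `permute` returns a strided view over the same storage, while `.contiguous()` copies the data into the new memory order. Elementwise rotary math works on either, so any difference would come from downstream kernels that need contiguous inputs or from the cost of the extra copy. The thread doesn't settle which matters here; this sketch only shows the layout difference.

```python
import torch

q = torch.randn(2, 16, 4, 8)        # (batch, seq, heads, head_dim)
q_t = q.permute(0, 2, 1, 3)         # strided view, no data copied
print(q_t.is_contiguous())          # False
print(q_t.data_ptr() == q.data_ptr())  # True: shares q's storage

q_c = q_t.contiguous()              # copies into (batch, heads, seq, head_dim) order
print(q_c.is_contiguous())          # True
print(torch.equal(q_c, q_t))        # True: same values, different layout
```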

Looking into a way to implement this directly with the xformers API. Thanks so much, @jmercat!

Actually, moving that line before the call to .view() would be enough.

queries, keys, vals = self.pos_embed(queries, keys, vals)
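A quick shape check of that suggestion, assuming (hypothetically) that before the .view() call the queries have shape (batch, seq, n_heads * head_dim). It only shows which axis sits at dim -2 before and after the view.

```python
import torch

batch, seq, n_heads, head_dim = 2, 16, 4, 8
# Hypothetical pre-view query tensor: heads still flattened into the last dim.
queries = torch.randn(batch, seq, n_heads * head_dim)
print(queries.shape[-2] == seq)         # True: dim -2 is the sequence axis

heads_view = queries.view(batch, seq, n_heads, head_dim)
print(heads_view.shape[-2] == n_heads)  # True: after .view(), dim -2 is heads
```

Before the view, the last dimension is n_heads * head_dim rather than head_dim, so whether the rotary tables line up with that width is a separate question this check does not answer.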

The problem actually seems to be upstream in xformers. Opened an issue here: facebookresearch/xformers#841