facebookresearch / xformers

Hackable and optimized Transformers building blocks, supporting a composable construction.

Home Page: https://facebookresearch.github.io/xformers/


RotaryEmbedding applied to the incorrect channel dimension

sagadre opened this issue · comments

πŸ› Bug

Input tensors to attention must be in the format [B, M, H, K], where B is the batch size, M the sequence length, H the number of heads, and K the embedding size per head, as documented here.

Hence positional embeddings (e.g., rotary embeddings) should be applied along dim=1, the sequence dimension. However, in the RotaryEmbedding class, seq_dimension=-2 is passed, which for a 4D input corresponds to dim=2 (the head dimension), as seen here.

def forward(
    self, q: torch.Tensor, k: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    self._cos_cached, self._sin_cached = self._update_cos_sin_tables(
        k, seq_dimension=-2  # should be seq_dimension=1, or the argument should be omitted since the default value is correct
    )

    return (
        apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
        apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
    )

Additional context

Thanks to @jmercat, who found symptoms of this problem downstream of xformers!