CStanKonrad / long_llama

LongLLaMA is a large language model capable of handling long contexts. It is based on OpenLLaMA and fine-tuned with the Focused Transformer (FoT) method.


FoT attention and the scaling trick

StrangeTcy opened this issue · comments

In your paper, you say

Position Interpolation (PI, [Chen
et al., 2023] and [kaiokendev, 2023]) introduces a modification to the rotary positional encoding
scheme that enables fine-tuning for 32K context. In contrast to this work, our method does not
rely on positional encodings, following the findings from [Haviv et al., 2022]. Removing positional
encoding in memory allows us to extrapolate to 256k tokens, although the model was only trained on
sequences up to 8K, yielding theoretically unbounded context length.

Does that mean that one can't use both scaled positional embeddings and FoT attention?

commented

I think it's due to the applied FoT attention, which does not use scaled positional embeddings for the additional (memory) parts.

Hi, thanks for the question. In short, we have not tried combining scaled positional encodings with FoT attention, so we cannot comment on the resulting performance.

Originally, FoT was designed to let the model handle large databases consisting of millions of keys and values coming from multiple unrelated documents. In such a setup, it is not clear how to apply positional encodings. This is reflected in our experiments with smaller models, where we disable positional encodings in the memory layers (the other layers keep their positional encoding).
The LongLLaMA models differ slightly. All layers except the memory layers use positional encodings in the standard way. The memory layers also apply positional encodings to the local context in the standard way, whereas the memory keys are encoded as if they were at the beginning of the local context (a rough sketch follows below).
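To make that concrete, here is a minimal sketch of how positions could be assigned in a memory layer before attention. This is my own illustration, not the repository code: `apply_rope`, the tensor shapes, and the interleaved rotation convention are all assumptions; the actual model reuses the LLaMA rotary code.

```python
import torch

def rope_angles(positions, dim, base=10000.0):
    # Standard rotary frequencies: one angle per pair of channels.
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    return torch.outer(positions.float(), inv_freq)  # (seq_len, dim // 2)

def apply_rope(x, positions):
    # x: (seq_len, dim); rotate each channel pair by a position-dependent angle.
    # Generic RoPE sketch, not the exact LLaMA rotation layout.
    angles = rope_angles(positions, x.shape[-1])
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

local_len, mem_len, head_dim = 2048, 4096, 64
local_keys = torch.randn(local_len, head_dim)   # keys of the current context window
memory_keys = torch.randn(mem_len, head_dim)    # keys retrieved from the memory cache

# Local keys (and queries) keep their usual positions within the window.
local_positions = torch.arange(local_len)
# All memory keys are rotated as if they sat at the first position of the
# local context, so no positional information distinguishes them from each other.
memory_positions = torch.zeros(mem_len, dtype=torch.long)

keys_for_attention = torch.cat(
    [apply_rope(memory_keys, memory_positions),
     apply_rope(local_keys, local_positions)],
    dim=0,
)
```

The worked example below describes the same scheme in terms of token positions.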

In other words, let
$$t_0, t_1, t_2, t_3, \ldots t_{2047}, t_{2048}, \ldots, t_{4095}, \ldots$$
be some input.
LongLLaMA will process it in context windows. First, it will process
$$t_0, t_1, t_2, t_3, \ldots t_{2047}$$
and move the (key, value) pairs from the memory layers to the memory cache.
Then it will process
$$t_{2048}, \ldots, t_{4095}$$
In this step, non-memory layers process only 2048 embeddings,
whereas memory layers also see the previous embeddings (keys and values), but as if they were located at the same position as $t_{2048}$.

We do this in order to maintain compatibility with the LLaMA code.
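For completeness, here is a rough sketch of the window-by-window processing described above. The names (`model_step`, `process`) and the random tensors standing in for the forward pass are placeholders of mine; the sketch only illustrates the data flow of the memory cache, not the real implementation.

```python
import torch

def model_step(window_tokens, memory_cache):
    # Hypothetical forward pass: the real model runs all transformer layers;
    # memory layers additionally attend over `memory_cache` (if present),
    # treating the cached keys as if they were at the start of this window.
    head_dim, vocab_size = 64, 32000
    keys = torch.randn(len(window_tokens), head_dim)
    values = torch.randn(len(window_tokens), head_dim)
    logits = torch.randn(len(window_tokens), vocab_size)
    return logits, (keys, values)

def process(tokens, window=2048):
    cached_keys, cached_values = [], []
    all_logits = []
    for start in range(0, len(tokens), window):
        chunk = tokens[start:start + window]
        cache = None
        if cached_keys:
            cache = (torch.cat(cached_keys), torch.cat(cached_values))
        logits, (k, v) = model_step(chunk, cache)
        all_logits.append(logits)
        # After the window is processed, its memory-layer (key, value) pairs
        # are appended to the cache so that later windows can attend to them.
        cached_keys.append(k)
        cached_values.append(v)
    return torch.cat(all_logits)

# e.g. process(list(range(6000))) runs three windows: [0, 2048), [2048, 4096), [4096, 6000)
```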

I figured as much after re-reading the relevant parts of the paper, but the phrase "as if they were at the beginning of the local context" wasn't entirely clear to me until your explanation, so thanks for that.