ofirpress / attention_with_linear_biases

Code for the ALiBi method for transformer language models (ICLR 2022)

Modifying ALiBi for Encoder-Attention or Cross-Attention

ofirpress opened this issue

In our paper we only showed results on causal language models, which use causally masked (decoder) self-attention.

If you'd like to use ALiBi for seq2seq tasks such as translation, speech or T5, or if you'd like to use ALiBi for masked language models such as BERT, some modifications are required.

Encoder-Attention

Encoder-Attention is the non-masked self-attention that is performed in the encoder of seq2seq models such as translation models or T5. This is also the same kind of attention used in MLM models such as BERT.

We can't naively copy and paste the causal ALiBi code into these models, because it won't work. We use a trick to quickly calculate the bias matrix for causal language modeling, but this bias matrix is only correct for values in or below the main diagonal (since that's all that's used in causal language modeling).

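            import torch

            maxpos = args.tokens_per_sample              # longest sequence the model will see
            attn_heads = args.encoder_attention_heads

            # Query positions as a column, key positions as a row
            context_position = torch.arange(maxpos)[:, None].cuda()
            memory_position = torch.arange(maxpos)[None, :].cuda()
            # (maxpos, maxpos) matrix of key_pos - query_pos
            relative_position = memory_position - context_position
            # Absolute distance |j - i|, copied for every head: (attn_heads, maxpos, maxpos)
            relative_position = torch.abs(relative_position).unsqueeze(0).expand(attn_heads, -1, -1)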

This code correctly generates the full matrix of query/key distances needed for the bias. Note that this matrix is symmetric around the diagonal, since it computes the absolute distance between the query and key (so all distances are non-negative).
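To see why the causal trick breaks in encoder attention, here is a small sketch. It assumes the causal trick is a bias that depends only on the key position (slope * j). On row i that differs from the "true" bias -slope * (i - j) only by the constant slope * i, and softmax ignores a constant added to a whole row. Once the causal mask hides every key to the right of the diagonal, both biases give the same attention. Without the mask they don't. The values n = 6 and slope = 0.5 are arbitrary.

```python
import torch

torch.manual_seed(0)
n, slope = 6, 0.5
i = torch.arange(n)[:, None]   # query positions (rows)
j = torch.arange(n)[None, :]   # key positions (columns)

# ASSUMPTION: the causal "trick" bias depends only on the key position j.
# Row i differs from -slope * (i - j) by the constant slope * i.
trick_bias = (slope * j).float().expand(n, n)
full_bias = -slope * (j - i).abs().float()      # bidirectional: -slope * |i - j|

scores = torch.randn(n, n)
causal_mask = torch.triu(torch.full((n, n), float("-inf")), diagonal=1)

# Causal LM: with keys right of the diagonal masked, the two biases agree.
p_trick = torch.softmax(scores + trick_bias + causal_mask, dim=-1)
p_full = torch.softmax(scores + full_bias + causal_mask, dim=-1)
print(torch.allclose(p_trick, p_full, atol=1e-6))  # True

# Encoder attention (no mask): the trick bias now favours distant keys on the right.
q_trick = torch.softmax(scores + trick_bias, dim=-1)
q_full = torch.softmax(scores + full_bias, dim=-1)
print(torch.allclose(q_trick, q_full, atol=1e-6))  # False
```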

We're also going to need the code for generating the ALiBi slopes:

            import math

            def get_slopes(n):
                def get_slopes_power_of_2(n):
                    start = (2**(-2**-(math.log2(n)-3)))   # = 2^(-8/n)
                    ratio = start
                    return [start*ratio**i for i in range(n)]

                # In the paper, we only train models that have 2^a heads for some a. This function has
                # some good properties that only occur when the input is a power of 2. To maintain that even
                # when the number of heads is not a power of 2, we use this workaround.
                if math.log2(n).is_integer():
                    return get_slopes_power_of_2(n)
                else:
                    closest_power_of_2 = 2**math.floor(math.log2(n))
                    return get_slopes_power_of_2(closest_power_of_2) + get_slopes(2*closest_power_of_2)[0::2][:n-closest_power_of_2]
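A quick check of what get_slopes returns. Reading the code, start = ratio = 2^(-8/n), so for a power-of-2 head count the slope of head h (1-indexed) is 2^(-8h/n). For 8 heads that gives 1/2, 1/4, ..., 1/256. The loop below only verifies that closed form against the function. It is a sanity check, not part of the method.

```python
import math

def get_slopes(n):
    # Same logic as the get_slopes function above.
    def get_slopes_power_of_2(n):
        start = 2 ** (-2 ** -(math.log2(n) - 3))
        ratio = start
        return [start * ratio ** i for i in range(n)]

    if math.log2(n).is_integer():
        return get_slopes_power_of_2(n)
    closest_power_of_2 = 2 ** math.floor(math.log2(n))
    return (get_slopes_power_of_2(closest_power_of_2)
            + get_slopes(2 * closest_power_of_2)[0::2][:n - closest_power_of_2])

# For n = 2^a heads: a geometric sequence, slope_h = 2^(-8h/n), h = 1..n
for n in (8, 16):
    expected = [2 ** (-8 * h / n) for h in range(1, n + 1)]
    assert all(math.isclose(a, b) for a, b in zip(get_slopes(n), expected))

print(get_slopes(8))
# [0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625]
```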

There are 3 options for implementing encoder-attention ALiBi:

  1. Symmetric: In this option, the bias we assign to query/key pairs that are +N or -N tokens apart will be the same.
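                # Negative slopes, so the bias decreases linearly with distance in both directions
                self.slopes = torch.Tensor(get_slopes(attn_heads)).cuda()*-1
                # (attn_heads, maxpos, maxpos): -m_h * |j - i|
                self.alibi = self.slopes.unsqueeze(1).unsqueeze(1) * relative_position
                # Add a batch dimension: (1, attn_heads, maxpos, maxpos)
                self.alibi = self.alibi.view(1, attn_heads, maxpos, maxpos)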

Now just pass self.alibi to the attention function and add it after the query*key computation.

In fairseq, for example, the query*key computation is done as attn_weights = torch.bmm(q, k.transpose(1, 2)). To add the ALiBi values, use:

attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights += alibi[:,:,:tgt_len,:src_len].to(attn_weights)
attn_weights = attn_weights.view(bsz*self.num_heads, tgt_len, src_len)
  2. Nonsymmetric: Here we are going to make the model nonsymmetric by using the same ALiBi bias as in (1), but this time, we're going to let the first half of the heads only look left and the second half only look right. We'll do this by adding a mask to our attention.
    Note: This code hasn't been fully tested yet and might contain bugs.
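                # utils is fairseq.utils; fill_with_neg_inf returns a tensor filled with -inf.
                # First half of the heads: keys to the right of the query are masked (look left only).
                self._future_mask_right = torch.triu(utils.fill_with_neg_inf(torch.zeros([maxpos, maxpos])), 1).unsqueeze(0).repeat(attn_heads//2, 1, 1)
                # Second half of the heads: keys to the left of the query are masked (look right only).
                self._future_mask_left = torch.tril(utils.fill_with_neg_inf(torch.zeros([maxpos, maxpos])), -1).unsqueeze(0).repeat(attn_heads//2, 1, 1)

                # (1, attn_heads, maxpos, maxpos)
                self.nonsym_mask = torch.cat((self._future_mask_right, self._future_mask_left), dim = 0).unsqueeze(0).cuda()
                # Only attn_heads//2 distinct slopes: each is shared by one left-looking and one right-looking head
                self.slopes = torch.Tensor(get_slopes(attn_heads//2)).cuda()*-1

                context_position = torch.arange(maxpos)[:, None].cuda()
                memory_position = torch.arange(maxpos)[None, :].cuda()
                relative_position = memory_position - context_position
                relative_position = torch.abs(relative_position).unsqueeze(0).expand(attn_heads//2, -1,-1)

                self.alibi = self.slopes.unsqueeze(1).unsqueeze(1) * relative_position
                self.alibi = self.alibi.view(1, attn_heads//2, maxpos, maxpos)
                # Repeat the slopes for the second half of the heads: (1, attn_heads, maxpos, maxpos)
                self.alibi = self.alibi.repeat(1, 2, 1, 1).cuda()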

As before, add self.alibi to the attention weights, but this time also add the nonsym_mask tensor (in fairseq: attn_weights += nonsym_mask[:,:,:tgt_len,:src_len].to(attn_weights)).

  3. Nonsymmetric with no mask: In this approach, we don't use any masking. Instead, we make the positioning nonsymmetric by using different ALiBi slopes depending on whether the key is to the left or right of the query. Here, we use learned slopes, but you can also do this with non-learned slopes.
    Note: I haven't tested this code so it might contain bugs!
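import torch
from torch import nn

# maxpos and attn_heads are defined as in the earlier snippets.
# In an nn.Module, register these Parameters in __init__ and compute the bias
# below inside forward(), so that the gradients reach the learned slopes.
slopes_left = nn.Parameter(torch.empty(attn_heads))
nn.init.normal_(slopes_left, -2, 1)
slopes_right = nn.Parameter(torch.empty(attn_heads))
nn.init.normal_(slopes_right, -2, 1)

# Squash into (-1, 0) so every slope is negative and bounded
slopes_left = -torch.sigmoid(slopes_left)
slopes_right = -torch.sigmoid(slopes_right)

context_position = torch.arange(maxpos)[:, None]
memory_position = torch.arange(maxpos)[None, :]
relative_position = memory_position - context_position
relative_position = torch.abs(relative_position).unsqueeze(0).expand(attn_heads, -1,-1)

alibi_left = slopes_left.unsqueeze(1).unsqueeze(1) * relative_position
alibi_right = slopes_right.unsqueeze(1).unsqueeze(1) * relative_position

# Upper triangle (keys to the right) uses the right slopes, lower triangle the left slopes.
# The diagonal is 0 in both terms, so summing does not double-count it.
self.alibi = (torch.triu(alibi_right) + torch.tril(alibi_left)).unsqueeze(0)  # (1, attn_heads, maxpos, maxpos)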
  4. Check out the variation on option 3 from the LittleBird paper.
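For reference, here is a framework-free sketch that puts options 1 to 3 side by side in one function. It is a minimal sketch under assumptions: plain PyTorch on CPU, no fairseq. The name encoder_alibi and its arguments are hypothetical. For option 3, the caller passes non-negative slope magnitudes, for example torch.sigmoid of learned nn.Parameters computed inside forward(). The returned bias is added to the raw query*key scores before the softmax.

```python
import math
import torch

def get_slopes(n):
    # Same head-slope schedule as get_slopes above (start = ratio = 2^(-8/n)).
    def power_of_2(n):
        start = 2 ** (-2 ** -(math.log2(n) - 3))
        return [start * start ** i for i in range(n)]
    if math.log2(n).is_integer():
        return power_of_2(n)
    p = 2 ** math.floor(math.log2(n))
    return power_of_2(p) + get_slopes(2 * p)[0::2][:n - p]

def encoder_alibi(num_heads, maxpos, option="symmetric",
                  left_slopes=None, right_slopes=None):
    """Hypothetical helper: additive bias of shape (1, num_heads, maxpos, maxpos)."""
    pos = torch.arange(maxpos)
    distance = (pos[None, :] - pos[:, None]).abs().float()  # |j - i|

    if option == "symmetric":        # option 1: +N and -N get the same bias
        slopes = torch.tensor(get_slopes(num_heads))
        bias = -slopes[:, None, None] * distance
    elif option == "nonsymmetric":   # option 2: half the heads look left, half right
        assert num_heads % 2 == 0
        half = num_heads // 2
        slopes = torch.tensor(get_slopes(half))
        bias = (-slopes[:, None, None] * distance).repeat(2, 1, 1)
        neg_inf = torch.full((maxpos, maxpos), float("-inf"))
        hide_right = torch.triu(neg_inf, diagonal=1)   # used by heads 0 .. half-1
        hide_left = torch.tril(neg_inf, diagonal=-1)   # used by heads half .. num_heads-1
        mask = torch.cat([hide_right.expand(half, -1, -1),
                          hide_left.expand(half, -1, -1)])
        bias = bias + mask
    elif option == "no_mask":        # option 3: separate slopes per direction
        bias_left = -left_slopes[:, None, None] * distance
        bias_right = -right_slopes[:, None, None] * distance
        # keys to the right use right_slopes, keys to the left use left_slopes
        bias = torch.triu(bias_right) + torch.tril(bias_left)
    else:
        raise ValueError(f"unknown option: {option}")
    return bias.unsqueeze(0)

# Usage: build once at maxpos, slice to the current lengths, add before the softmax.
bias = encoder_alibi(8, 512, "nonsymmetric")
scores = torch.randn(2, 8, 100, 100)        # (batch, heads, tgt_len, src_len)
probs = torch.softmax(scores + bias[:, :, :100, :100], dim=-1)
```

In option 2, every row still has its diagonal unmasked, so no softmax row is entirely -inf.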

Cross-Attention

For translation models and models like T5 you will also need to implement cross-attention, which is the attention from the decoder to the encoder. The T5 model uses no positional information in cross-attention and I would recommend doing the same thing.
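A minimal sketch of that recommendation follows. The attention helper is hypothetical and deliberately simplified: there are no Q/K/V projections, and the encoder states are reused directly as keys and values. The optional bias argument is where an ALiBi tensor would go in encoder or decoder self-attention. Cross-attention calls the same helper with no bias at all, because a decoder position and an encoder position have no natural distance between them.

```python
import math
import torch

def attention(q, k, v, bias=None):
    # q: (batch, heads, q_len, d); k, v: (batch, heads, k_len, d)
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    if bias is not None:  # e.g. an ALiBi tensor in self-attention
        scores = scores + bias
    return torch.softmax(scores, dim=-1) @ v

batch, heads, src_len, tgt_len, d = 2, 4, 7, 5, 16
encoder_states = torch.randn(batch, heads, src_len, d)
decoder_states = torch.randn(batch, heads, tgt_len, d)

# Cross-attention: decoder queries attend over all encoder positions with no
# bias term and no position embeddings.
cross_out = attention(decoder_states, encoder_states, encoder_states)
print(cross_out.shape)  # torch.Size([2, 4, 5, 16])
```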

Implementations

NEW: lucidrains has implemented some of the above ideas in the x-transformers repo (lucidrains/x-transformers#88).

Hello @ofirpress,

I have 2 questions regarding moving from positional embeddings to ALiBi.

Assume we are using positional embeddings with a maximum of 512 tokens.
The transformer can then process any sequence of up to 512 tokens without any modification.

When using ALiBi, do you create the bias matrix "on the fly" on each forward pass to support different sequence lengths, or do you create the bias matrix once and just pass it to the transformer? The second option restricts the transformer to forwarding only sequences of length = maxpos.

Maybe define alibi once with maxpos and, when adding alibi, do:

seq_len = attention_scores.size()[-1]
attention_scores += self.alibi[:, :, :seq_len, :seq_len]

The second question is:
Can we utilize pretrained LMs like BERT with ALiBi, or do we need to train a new LM from scratch? More specifically, an LM like BERT was trained with positional embeddings: can we just refactor the code to move to ALiBi and then finetune from the pretrained model, or do we need to train the LM with ALiBi from scratch?

Thanks in advance

or do you create the bias matrix once and just pass it to the transformer

Right now that's what we do: we create the ALiBi tensor once and add it to the mask. If the current sequence length is shorter than maxpos, the mask is just cut down to the right dimensions.

If you have sequences of different lengths, there's always going to be a maxpos, the maximum length you expect your model to ever see. Just make ALiBi that size and cut it down when the sequences are shorter.
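A small sketch of "build once, cut down" for symmetric (encoder) ALiBi. The helper names symmetric_alibi and add_alibi are hypothetical. The 8-head slopes are written out by hand and match get_slopes(8). The final check shows why slicing is safe: each entry depends only on |i - j|, so the top-left block of the big tensor equals the tensor you would build directly at the shorter length.

```python
import torch

def symmetric_alibi(slopes, n):
    # -slope * |i - j| for every head: shape (1, heads, n, n)
    pos = torch.arange(n)
    distance = (pos[None, :] - pos[:, None]).abs().float()
    return (-slopes[:, None, None] * distance).unsqueeze(0)

slopes = torch.tensor([2.0 ** -h for h in range(1, 9)])  # same values as get_slopes(8)
maxpos = 512
alibi = symmetric_alibi(slopes, maxpos)  # built once, at the largest length

def add_alibi(attention_scores, alibi):
    # attention_scores: (batch, heads, seq_len, seq_len) with seq_len <= maxpos
    seq_len = attention_scores.size(-1)
    return attention_scores + alibi[:, :, :seq_len, :seq_len]

scores = torch.randn(2, 8, 37, 37)
biased = add_alibi(scores, alibi)

# The top-left 37x37 block is exactly the bias built directly for length 37.
print(torch.equal(alibi[:, :, :37, :37], symmetric_alibi(slopes, 37)))  # True
```

When query and key lengths differ (tgt_len vs src_len), slice the last two dimensions separately, as in the fairseq snippet earlier in the thread.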

Can we utilize pretrained LMs like BERT with ALiBi, or do we need to train a new LM from scratch?

I haven't experimented with this much, but I think that if you have a model that was trained with sinusoidal or learned embeddings, you would have to retrain it from scratch if you want to use ALiBi. It would be interesting to experiment with just finetuning with ALiBi; I have no idea whether that would work or not. If you do end up trying the finetuning method, tell me how it goes, I'm curious!

just make ALiBi that size and cut it down when the sequences are shorter.

Like that?

seq_len = attention_scores.size()[-1]
attention_scores += self.alibi[:, :, :seq_len, :seq_len]

If you do end up trying the finetuning method, tell me how it goes, I'm curious!

I will !

Do you know of any pretrained models with ALiBi?

Yup, the code you posted looks good.
We've released some trained models on WikiText-103 here.

Hi @ofirpress,

I just saw this issue and the paper you co-authored, Transformer Language Models without Positional Encodings Still Learn Positional Information. So, which of the 3 options did you go with for the MLM ALiBi experiment (Table 4)? 😉

I think it was option 2 but I just emailed Peter (who ran those experiments) and I'll tell you when he gets back to me.

Thanks!

FYI, the paper On Position Embeddings in BERT changed the position embeddings of BERT and then finetuned, and that seems to work, so it may be doable for ALiBi. They didn't quite get a better model in the end, though, and it's unclear whether training from scratch would work better.

Hi @EIFY and @ofirpress, I implemented and tested the first option (symmetric) and pre-trained BERT from scratch.

Thanks @peteriz! I've also heard from others that option 2 works well, so I would try both and see what leads to better results.

@EIFY - I am not aware of any results showing that it is possible to apply ALiBi to a model that wasn't trained with it. I think it's a super interesting question and I'm curious to see if it could be made possible.

I am experimenting with the following asymmetric implementation, which uses different offsets for the linear bias in the forward and backward directions:
https://github.com/EIFY/fairseq/blob/fc1fabc8612cd25cf3e15a5623ebddd59f1219bd/fairseq/utils.py#L968-L973
My rationale is that if the model can discern a difference of 1 * slopes in linear biases, it can probably discern a difference of 0.5 * slopes. Obviously, experiments will have the final say though 😅
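The linked implementation isn't reproduced in this thread. Below is one plausible reading of "different offsets forward and backward", offered purely as a hypothetical illustration: keys on one side keep the symmetric bias -m * d, and keys on the other side get -m * (d + offset). That way, two keys at the same distance N on opposite sides differ by offset * m. The function name, the offset parameter, and the choice of which side gets the extra offset are all made up for this sketch.

```python
import torch

def shifted_asymmetric_alibi(slopes, n, offset=0.5):
    # HYPOTHETICAL illustration, not the code at the linked commit:
    # keys to the left get -m * d, keys to the right get -m * (d + offset).
    pos = torch.arange(n)
    rel = pos[None, :] - pos[:, None]            # j - i (positive = key on the right)
    distance = rel.abs().float() + offset * (rel > 0).float()
    return -slopes[:, None, None] * distance     # (heads, n, n)

bias = shifted_asymmetric_alibi(torch.tensor([1.0]), 5)
# Row 2 (query at position 2): the key one step to the left gets -1.0,
# the key one step to the right gets -1.5.
print(bias[0, 2])
```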

Hi @ofirpress,

I'm trying to use ALiBi for machine translation.
It works well if I only apply ALiBi in the decoder (BLEU is about 55). But the performance gets worse if I apply ALiBi in the encoder and cross-attention (BLEU is about 10). I've tried the above 3 options for the encoder.
Do you have any suggestions?

Thank you in advance!

Hi @lyc2001:
Try applying ALiBi in the encoder with variation 1 or 2. Use decoder alibi as normal. Don't apply any kind of bias or position embedding to the cross-attention. That might work. Tell me how it goes.

Also: make sure you are not using any kind of positional embeddings at all when you use ALiBi.

Hi @ofirpress again,

I have tested a couple of variations of ALiBi w/ MLM, using RoBERTa as the baseline (https://github.com/EIFY/fairseq). I think you may be interested in the results 🙂

Great. I'm wondering if you also tried any of the options listed above?

Symmetrical ALiBi (option 1 above) behaves almost identically to shifted asymmetrical ALiBi in my WikiText-103 experiments.

(If you really want to know I can pick out the dotted line corresponding to it 😂)

@ofirpress Sorry, but correct me if I am wrong: the positional encoding is needed only in self-attention, and we do not need it in cross-attention? I am referring to the T5 cross-attention implementation.

@Arij-Aladel yes, in the original post Ofir says

The T5 model uses no positional information in cross-attention and I would recommend doing the same thing.

An extra data point: I've found that masking half the heads (option 2, asymmetric) worked well for my use case.

@Daniel-Liu-c0deb0t What kind of model are you building and what metrics are you optimizing?

People have asked me how to implement ALiBi for FIM models. Here are two ideas I have:

[Figure: proposed ALiBi biases for FIM; tokens in black, ALiBi biases in blue on the second row]

  1. Proposal one is to use regular ALiBi, but instead of constantly increasing the ALiBi bias, reset it every time we enter a new chunk. See the picture (tokens in black, ALiBi biases in blue on the second row). Since the prefix comes before the main text, the numbers there are descending. This is slightly suboptimal, since this biasing kind of implies that the "mid" section has length 0, but I don't think it's horrible.
  2. The second idea is to use the same biasing as before, but this time have 1/3 of the heads look only at the prefix, 1/3 look only at the suffix, and 1/3 look only at the mid. I think this would work well, based on similar ideas working for encoder-only (bidirectional) attention.

@EIFY well, my use case is pretty specialized (DNA error correction), but it's a BERT-like model. In my experiments, it's trained from scratch with ALiBi, and it is very slightly better than sinusoidal absolute position encoding.

Hi @ofirpress,

Thank you! My machine translation model works well now, but I'm facing another problem. When testing on data of about the same length as the training data, BLEU is 55, about the same as the vanilla Transformer. My training data has about 10 characters per sentence on average, and I'm trying to extrapolate to data with about 50 characters per sentence. There, BLEU drops to 8. The outputs are indeed much longer than those generated by the vanilla Transformer, but they keep repeating some words and sometimes never finish the sentence.

For example,

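我 可以 帶 時間 好好 休息 好 休息 你們 要 聽 媽媽 媽媽 媽媽 媽媽 媽媽 媽媽 媽媽 媽媽 媽媽 媽媽 媽媽 媽媽 媽媽 去 煮飯
(roughly: "I can take time to rest well, rest well, you all must listen to Mom Mom Mom ... go cook")
以後 你 一定 做 錯 是 我 怎麼 救 你 來不及 你 來不及 你 想 救 我 救 我 救 我 救 我 救 我 想 救 我 救 我 救 我 救 我 救 我 救 我 救 我 救 我 救 我 救 我 救 我 救 我 救 我 救 我 救 我
(roughly: "Later you will surely do wrong, it's me, how do I save you, too late, you too late, you want to save me save me save me ...")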

Do you have any suggestions?
Thank you in advance!

Extrapolation has its limits: it seems like, in your use case, training on 10 characters might just be too little.

Also, to make sure there's no bug, test what happens if you train on 10 and try extrapolating to length 11 or 12 at inference. There you should see the same or slightly better BLEU.

Hi @ofirpress,
Thank you for your reply! I'll try to apply it on shorter length sentences and see what happens.

Hi there!

I'm hoping to use ALiBi for T5 and was wondering if anyone could share their code. I've been having trouble finding exactly where to add the relative_position.

For 12 heads (get_slopes(12)), I'm seeing a jump in the slope values, e.g. the 0.7 shown below.

Is this expected behaviour?

[0.5,
 0.25,
 0.125,
 0.0625,
 0.03125,
 0.015625,
 0.0078125,
 0.00390625,
 0.7071067811865476,
 0.35355339059327384,
 0.17677669529663692,
 0.08838834764831849]
# ALiBi

import math
import torch

b_cache = {}  # n -> (n, n) distance matrix, built once per sequence length

def alibi_bias(n, device):
    # Unscaled causal ALiBi distances: row i is [-i, ..., -1] left of the diagonal
    # and 0 on/above it (those positions are removed by the causal mask anyway).
    # Multiply by the per-head slopes from get_slopes() before adding to the scores.
    if n in b_cache:
        bias = b_cache[n]
    else:
        bias = torch.zeros(n, n)
        for i in range(n):
            bias[i, :i] = -torch.arange(i, 0, -1)
        b_cache[n] = bias
    return bias.to(device)

def get_slopes(n):
    # ALiBi: When using ALiBi, we do not add position embeddings at any point in the network. The only
    # modification we apply is after the query-key dot product, where we add a static, non-learned bias:
    # softmax(q_i K^T + m * [-(i - 1), ..., -2, -1, 0]),
    # where scalar `m` is a head-specific slope fixed before training.
    def get_slopes_power_of_2(n):
        start = (2**(-2**-(math.log2(n)-3)))
        ratio = start
        return [start*ratio**i for i in range(n)]

    # In the paper, we only train models that have 2^a heads for some a. This function has
    # some good properties that only occur when the input is a power of 2. To maintain that even
    # when the number of heads is not a power of 2, we use this workaround.
    if math.log2(n).is_integer():
        return get_slopes_power_of_2(n)
    else:
        closest_power_of_2 = 2**math.floor(math.log2(n))
        return get_slopes_power_of_2(closest_power_of_2) + get_slopes(2*closest_power_of_2)[0::2][:n-closest_power_of_2]
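The two points above can be checked with a short sketch. This analysis comes from reading get_slopes, not from an answer in the thread. First, the "jump" for 12 heads is expected: the first 8 slopes are the usual 8-head sequence 2^-1 ... 2^-8, and the last 4 are every other slope of the 16-head sequence (2^-0.5, 2^-1.5, 2^-2.5, 2^-3.5), appended at the end rather than sorted in. Second, alibi_bias above returns unscaled distances, so it still needs the per-head slopes and a causal mask. The helpers causal_distances and causal_alibi are hypothetical names for a vectorized equivalent.

```python
import math
import torch

def get_slopes(n):
    # Same logic as get_slopes above.
    def get_slopes_power_of_2(n):
        start = 2 ** (-2 ** -(math.log2(n) - 3))
        return [start * start ** i for i in range(n)]
    if math.log2(n).is_integer():
        return get_slopes_power_of_2(n)
    p = 2 ** math.floor(math.log2(n))
    return get_slopes_power_of_2(p) + get_slopes(2 * p)[0::2][:n - p]

# 1) 12 heads: the 8-head slopes, then 4 extra slopes that fall between them.
slopes12 = get_slopes(12)
assert slopes12[:8] == get_slopes(8)
print([round(s, 4) for s in slopes12[8:]])  # [0.7071, 0.3536, 0.1768, 0.0884]

# 2) Vectorized version of alibi_bias: j - i below the diagonal, 0 elsewhere.
def causal_distances(n):
    pos = torch.arange(n)
    return (pos[None, :] - pos[:, None]).clamp(max=0).float()

def causal_alibi(num_heads, n):
    # Scale per head and hide future keys: shape (num_heads, n, n)
    slopes = torch.tensor(get_slopes(num_heads))
    future = torch.triu(torch.full((n, n), float("-inf")), diagonal=1)
    return slopes[:, None, None] * causal_distances(n) + future

bias = causal_alibi(12, 6)
print(bias.shape)  # torch.Size([12, 6, 6])
```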

Has anyone tried combining ALiBi with CrossBatch from the Focused Transformer paper?