Modifying ALiBi for Encoder-Attention or Cross-Attention #5
Comments
Hello @ofirpress, I have two questions about moving from positional embeddings to ALiBi. Assume we are using positional embeddings with a maximum of 512 tokens. When using ALiBi, do you create the bias matrix "on the fly" on each forward pass to support different sequence lengths, or do you create the bias matrix once and just pass it to the transformer? The second option would restrict the transformer to sequences of a fixed length, unless ALiBi is defined once with some maximum length.
My second question: can a model that was already trained with positional embeddings be switched over to ALiBi, or does it have to be retrained? Thanks in advance!
Right now that's what we do: we create the ALiBi tensor once and add it to the mask. If the current sequence length is shorter than maxpos, the mask is just cut down to the right dimensions. If you have sequences of different lengths, there's always going to be a maxpos, the maximum length you expect your model to ever see; just make ALiBi that size and cut it down when the sequences are shorter.
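(For concreteness, a minimal sketch of that build-once-and-slice approach; names such as build_alibi, maxpos, and seq_len are illustrative, not from the fairseq code.)
import torch

def build_alibi(maxpos, n_heads):
    # Per-head slopes form a geometric sequence 2^(-8/n), 2^(-16/n), ...
    # (exact for power-of-two head counts, as in the paper).
    slopes = torch.tensor([2.0 ** (-8.0 * (i + 1) / n_heads) for i in range(n_heads)])
    pos = torch.arange(maxpos)
    rel = (pos[None, :] - pos[:, None]).float()        # key index minus query index
    # For causal LM only values on or below the diagonal matter;
    # everything above the diagonal is removed by the causal mask anyway.
    return slopes[:, None, None] * rel[None, :, :]     # (n_heads, maxpos, maxpos)

# Build once at the maximum expected length, then cut down per batch:
# alibi = build_alibi(maxpos=1024, n_heads=8)
# bias = alibi[:, :seq_len, :seq_len]                  # add this to the attention mask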
I haven't experimented with this too much, but I think that if you have a model that was trained with sinusoidal or learned embeddings you would have to retrain it from scratch if you want to use ALiBi. It would be interesting to experiment with just finetuning with ALiBi; I have no idea if that would work or not. If you do end up trying the finetuning method, tell me how it goes, I'm curious!
Like this?
I will! Do you know of any pretrained models with ALiBi?
Yup, the code you posted looks good.
ALiBi: Attention with Linear Biases (Press et al., 2021). Three different types of ALiBi are available: symmetrical, non-symmetrical with mask, and non-symmetrical. For more details, please refer to ofirpress/attention_with_linear_biases#5.
Hi @ofirpress, I just saw this issue and the paper you co-authored, Transformer Language Models without Positional Encodings Still Learn Positional Information. So, which of the 3 options did you go with for the MLM ALiBi experiment (Table 4)? 😉
I think it was option 2, but I just emailed Peter (who ran those experiments) and I'll tell you when he gets back to me.
Thanks!
FYI, changing the position embeddings of BERT and then finetuning seems to work in On Position Embeddings in BERT, so it may be doable for ALiBi. They didn't quite get a better model in the end, though, and it's unclear whether training from scratch would work better.
Hi @EIFY and @ofirpress, I implemented and tested the first option (symmetric) and pre-trained BERT from scratch.
Thanks @peteriz! I've also heard from others that option 2 works well, so I would try both and see what leads to better results. @EIFY - I am not aware of any results showing that it is possible to apply ALiBi to a model that wasn't trained with it. I think it's a super interesting question and I'm curious to see if it could be made possible.
I am experimenting with the following asymmetrical implementation, which uses different offsets for the linear bias forward & backward:
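(The snippet itself isn't preserved in this thread. Purely as a hypothetical illustration of "different offsets forward and backward", not the commenter's actual code, such a bias could look like this:)
import torch

def asymmetric_alibi(seq_len, slopes, fwd_offset=1.0, bwd_offset=0.0):
    # slopes: tensor of shape (n_heads,); the offsets are hypothetical knobs that
    # shift the linear penalty differently for keys after vs. before the query.
    pos = torch.arange(seq_len)
    rel = (pos[None, :] - pos[:, None]).float()          # key index minus query index
    dist = torch.where(rel >= 0, rel + fwd_offset, -rel + bwd_offset)
    return -slopes[:, None, None] * dist[None]           # (n_heads, seq_len, seq_len)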
Hi @ofirpress, I'm trying to use ALiBi for machine translation. Thank you in advance!
Hi @lyc2001: also, make sure you are not using any kind of positional embeddings at all when you use ALiBi.
Hi @ofirpress again, I have tested a couple of variations of ALiBi w/ MLM, using RoBERTa as the baseline (https://github.com/EIFY/fairseq). I think you may be interested in the results 🙂
Great. I'm wondering if you also tried any of the options listed above?
Symmetrical ALiBi (option 1 above) behaves almost identically to shifted asymmetrical ALiBi in my WikiText-103 experiments. (If you really want to know, I can pick out the dotted line corresponding to it 😂)
@ofirpress Sorry, but correct me if I am wrong: positional encoding is needed only in self-attention; we do not need it in cross-attention. I am referring to the T5 cross-attention implementation.
@Arij-Aladel yes, in the original post Ofir says: "The T5 model uses no positional information in cross-attention and I would recommend doing the same thing."
An extra data point: I've found that masking half the heads (option 2, asymmetric) worked well for my use case.
@Daniel-Liu-c0deb0t What kind of model are you building and what metrics are you optimizing?
People have asked me how to implement ALiBi for FIM (fill-in-the-middle) models. Here are two ideas I have:
|
@EIFY well my use case is pretty specialized (DNA error correction), but it's a BERT-like model. In my experiments, it's trained from scratch with ALiBi and it is very slightly better than sinusoidal absolute position encoding.
Hi @ofirpress, thank you! My machine translation model works well now, but I'm facing another problem. When testing on data about the same length as the training data, BLEU is 55, about the same as the vanilla Transformer. My training data averages about 10 characters per sentence, and I'm trying to extrapolate to data that's about 50 characters per sentence. There, BLEU drops to 8. The outputs are indeed much longer than those generated by the vanilla Transformer, but they keep repeating some words and sometimes fail to end the sentence. For example: 我 可以 帶 時間 好好 休息 好 休息 你們 要 聽 媽媽 媽媽 媽媽 媽媽 媽媽 媽媽 媽媽 媽媽 媽媽 媽媽 媽媽 媽媽 媽媽 去 煮飯 (the output keeps repeating 媽媽, "mom"). Do you have any suggestions?
Extrapolation has its limits; it seems like in your use case, training on sequences of about 10 characters might just be too little. Also, to make sure there's no bug, test what happens if you train on length 10 and extrapolate to length 11 or 12 at inference. There you should see the same or slightly better BLEU.
Hi @ofirpress,
Hi there! I'm hoping to use ALiBi for T5 and was wondering if anyone could share their code. I've been having trouble finding the exact place to add the relative_position bias.
For a value of …, is this expected behaviour?
|
Using the symmetric solution from ofirpress/attention_with_linear_biases#5.
Has anyone tried combining ALiBi with CrossBatch from the Focused Transformer paper?
Hi @EIFY, I have been using your implementation of fairseq, and I had the following question:
|
Hi @VarunGumma, it has been a while and I didn't do much work with |
In our paper we only showed results on causal language models, which use causally masked (decoder) self-attention.
If you'd like to use ALiBi for seq2seq tasks such as translation, speech or T5, or if you'd like to use ALiBi for masked language models such as BERT, some modifications are required.
Encoder-Attention
Encoder-Attention is the non-masked self-attention that is performed in the encoder of seq2seq models such as translation models or T5. This is also the same kind of attention used in MLM models such as BERT.
We can't naively copy-paste the ALiBi code for these models because it won't work. We use a trick to quickly calculate the bias matrix for causal language modeling, but that bias matrix is only correct for values on or below the main diagonal (since that's all that's used in causal language modeling).
The code below generates the full bias matrix. Note that the bias matrix is symmetric around the diagonal, since it computes the absolute distance between the query and key (so all distances are positive).
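(A sketch of that computation; maxpos and attn_heads are illustrative values.)
import torch

maxpos = 512      # maximum sequence length the model is expected to see
attn_heads = 8    # number of attention heads

# Absolute distance between every query position and every key position.
context_position = torch.arange(maxpos)[:, None]
memory_position = torch.arange(maxpos)[None, :]
relative_position = torch.abs(memory_position - context_position)              # (maxpos, maxpos)
relative_position = relative_position.unsqueeze(0).expand(attn_heads, -1, -1)  # (heads, maxpos, maxpos)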
We're also going to need the code for generating the ALiBi slopes:
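(A lightly commented version of the slope routine, continuing from the distance matrix above; treat it as a sketch.)
import math
import torch

def get_slopes(n):
    # Slopes are a geometric sequence starting at 2^(-8/n) when n is a power of two;
    # otherwise, interleave the sequences of the two nearest powers of two.
    def get_slopes_power_of_2(n):
        start = 2 ** (-2 ** -(math.log2(n) - 3))
        return [start * start ** i for i in range(n)]

    if math.log2(n).is_integer():
        return get_slopes_power_of_2(n)
    closest_power_of_2 = 2 ** math.floor(math.log2(n))
    return (get_slopes_power_of_2(closest_power_of_2)
            + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2])

# Combine the (negated) slopes with the absolute-distance matrix to get the bias,
# which the attention module keeps around as self.alibi:
slopes = torch.Tensor(get_slopes(attn_heads)) * -1
alibi = slopes.unsqueeze(1).unsqueeze(1) * relative_position    # (heads, maxpos, maxpos)
alibi = alibi.view(1, attn_heads, maxpos, maxpos)               # stored as self.alibi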
There are 3 options for implementing encoder-attention ALiBi: symmetrical, non-symmetrical with mask, and non-symmetrical.
Option 1 (symmetrical): just pass self.alibi to the attention function and add it after the query*key computation.
In fairseq, for example, the query*key computation is done as:
attn_weights = torch.bmm(q, k.transpose(1, 2))
and the ALiBi values are then added right after it, as in the sketch below. Note: this code hasn't been fully tested yet and might contain bugs.
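(A sketch of that addition, assuming the (1, heads, maxpos, maxpos) alibi tensor built above and fairseq's flattened attention-weight layout; bsz, num_heads, tgt_len, and src_len are the usual local variables inside fairseq's MultiheadAttention.forward.)
# Inside the attention forward pass, right after the q*k computation:
attn_weights = torch.bmm(q, k.transpose(1, 2))                                   # (bsz * num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(bsz, num_heads, tgt_len, src_len)
attn_weights = attn_weights + alibi[:, :, :tgt_len, :src_len].to(attn_weights)   # alibi is self.alibi in the module
attn_weights = attn_weights.view(bsz * num_heads, tgt_len, src_len)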
Option 2 (non-symmetrical with mask): again, as before, add self.alibi to the attn_weights, but this time also add the nonsym_mask tensor. In fairseq:
attn_weights += nonsym_mask[:,:,:tgt_len,:src_len].to(attn_weights)
Note: I haven't tested this code so it might contain bugs!
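(The construction of nonsym_mask isn't preserved in this thread. Under the reading that half the heads attend only to previous tokens and the other half only to future tokens, a hedged sketch could be:)
import torch

# My reconstruction, not the original snippet: half the heads see only the past,
# the other half see only the future (the diagonal stays visible to both halves).
neg_inf = float("-inf")
mask_future = torch.triu(torch.full((maxpos, maxpos), neg_inf), diagonal=1)    # blocks keys after the query
mask_past = torch.tril(torch.full((maxpos, maxpos), neg_inf), diagonal=-1)     # blocks keys before the query

nonsym_mask = torch.cat(
    [mask_future.expand(attn_heads // 2, -1, -1),    # first half of heads: look left only
     mask_past.expand(attn_heads // 2, -1, -1)],     # second half of heads: look right only
    dim=0,
).unsqueeze(0)                                        # (1, attn_heads, maxpos, maxpos)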
Cross-Attention
For translation models and models like T5 you will also need to implement cross-attention, which is the attention from the decoder to the encoder. The T5 model uses no positional information in cross-attention and I would recommend doing the same thing.
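(For illustration, a minimal cross-attention sketch with no positional bias anywhere, which is what the T5-style recommendation amounts to; shapes and names are illustrative.)
import torch
import torch.nn.functional as F

def cross_attention(q, k, v):
    # q: decoder queries (bsz, heads, tgt_len, head_dim)
    # k, v: encoder keys/values (bsz, heads, src_len, head_dim)
    scores = torch.matmul(q, k.transpose(-2, -1)) / (q.size(-1) ** 0.5)
    return torch.matmul(F.softmax(scores, dim=-1), v)   # no ALiBi, no positional bias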
Implementations
NEW: lucidrains has implemented some of the above ideas in the x-transformers repo; see lucidrains/x-transformers#88.