Forked from Microsoft's Differential Transformer (https://github.com/microsoft/unilm/tree/master/Diff-Transformer) by Eric Hartford.
For an input sequence $X$, introspective attention computes a series of attention paths Path[$0$], ..., Path[$n$].

Key properties (illustrated by the sketch after this list):
- Path[$i$] has parallel access to Path[$0 \ldots i-1$]
- Dimensionality is preserved through the projections
- Gradients and scale are stabilized
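The following is a minimal sketch of that path structure, not the repository's actual implementation: the class and its internals are hypothetical, standing in for whatever `introspective_diffattn.py` defines, and it uses stock `nn.MultiheadAttention` purely for illustration.

```python
import torch
import torch.nn as nn


class IntrospectivePathSketch(nn.Module):
    """Hypothetical sketch of the path structure described above.

    Path i attends over the input plus the outputs of paths 0..i-1
    (parallel access to earlier paths), every path is projected back to
    embed_dim (dimensionality preserved), and each path output is
    normalized before reuse (scale/gradient stabilization).
    Requires PyTorch >= 2.4 for nn.RMSNorm.
    """

    def __init__(self, embed_dim: int, num_heads: int, num_paths: int):
        super().__init__()
        self.num_paths = num_paths
        self.attn = nn.ModuleList(
            [nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
             for _ in range(num_paths)]
        )
        self.norms = nn.ModuleList([nn.RMSNorm(embed_dim) for _ in range(num_paths)])
        self.out_proj = nn.Linear(num_paths * embed_dim, embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, embed_dim)
        paths = []
        for i in range(self.num_paths):
            # Path i reads keys/values from the input concatenated with all
            # earlier path outputs, so it sees Path[0..i-1] in parallel.
            context = torch.cat([x] + paths, dim=1)
            out, _ = self.attn[i](x, context, context, need_weights=False)
            paths.append(self.norms[i](out))  # stabilize the scale of each path
        # Concatenate the paths and project back to embed_dim.
        return self.out_proj(torch.cat(paths, dim=-1))
```

For example, `IntrospectivePathSketch(embed_dim=512, num_heads=8, num_paths=2)` maps a `(1, 16, 512)` input to a `(1, 16, 512)` output, so the module can drop into a standard residual block.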
- `introspective_diffattn.py` contains a naive implementation of introspective attention.
- `introspective_flashdiff_1.py` contains introspective attention implemented with FlashAttention, for packages that support different qk/v dimensions (e.g., our customized-flash-attention and xformers). Recommended for faster training and inference.
- `introspective_flashdiff_2.py` contains introspective attention implemented with FlashAttention, for packages that do not support different qk/v dimensions (e.g., stock flash-attention).
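One way to choose between the three variants at import time is sketched below. The class name `IntrospectiveAttn` is a placeholder; check each module for the symbol it actually exports, and note that the fallback order simply mirrors the recommendation above.

```python
# Sketch only: `IntrospectiveAttn` is a hypothetical class name used for
# illustration; the actual exported symbol may differ per module.
try:
    # FlashAttention backend that supports different qk/v dimensions
    # (e.g., customized-flash-attention or xformers), the fastest option.
    from introspective_flashdiff_1 import IntrospectiveAttn
except ImportError:
    try:
        # Stock flash-attention (equal qk/v dimensions).
        from introspective_flashdiff_2 import IntrospectiveAttn
    except ImportError:
        # Pure-PyTorch reference implementation; no FlashAttention required.
        from introspective_diffattn import IntrospectiveAttn
```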
We recommend using models with a sufficiently large number of heads to minimize the impact of halving the head count. For instance, using Diff Transformer with more than 8 heads (the minimum used in the paper, which matches the parameter count of a Transformer with 16 heads) is advisable.
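As a rough illustration of why halving the heads preserves the parameter count, the arithmetic below assumes the upstream Diff-Transformer convention that each differential head splits its query/key projection in two (`head_dim = embed_dim // num_heads // 2`); this fork is expected to follow the same convention, but verify against the code.

```python
# Assumes head_dim = embed_dim // num_heads // 2 (upstream convention):
# halving the heads keeps the per-head dimension, and hence the total
# projection parameter count, the same as the baseline Transformer.
embed_dim = 2048

baseline_heads = 16
baseline_head_dim = embed_dim // baseline_heads   # 128

diff_heads = baseline_heads // 2                  # 8 heads, each with two q/k slices
diff_head_dim = embed_dim // diff_heads // 2      # 128

assert baseline_head_dim == diff_head_dim
```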
This project is licensed under the MIT License - see the LICENSE file for details