Generalized Optimal Transport Attention with Trainable Priors (GOAT), available as a PyTorch multi-head attention module.
Install name: `goat-attention` (PyPI) · Import name: `goat`
- From PyPI (recommended): `uv add goat-attention`
- pip: `pip install goat-attention`
- From source (editable): `uv pip install -e .`
- From source (editable, pip): `pip install -e .`

Quickstart:

```python
import torch
from goat import GoatAttention
B, L, S, E, H = 2, 5, 7, 64, 8  # batch, query length, key/value length, embed dim, heads
xq = torch.randn(B, L, E)
xk = torch.randn(B, S, E)
xv = torch.randn(B, S, E)
attn = GoatAttention(
    embed_dim=E,
    num_heads=H,
    batch_first=True,
    pos_rank=2,
    abs_rank=4,
    enable_key_bias=True,
)
out, weights = attn(xq, xk, xv, is_causal=False, need_weights=True)
print(out.shape, None if weights is None else weights.shape)
```
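Continuing from the quickstart above, a minimal sketch of causal self-attention (query, key, and value all taken from the same tensor). It uses only the arguments already shown in the quickstart call; how `is_causal=True` interacts with `need_weights=True` is not documented here, so the sketch requests no weights.

```python
# Causal self-attention: reuse the `attn` module and shapes from the quickstart.
x = torch.randn(B, L, E)  # (batch, sequence length, embed dim) with batch_first=True
out, _ = attn(x, x, x, is_causal=True, need_weights=False)
print(out.shape)  # expected (B, L, E)
```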
GOAT uses spectral (Fourier) features to model positional relationships. The two main hyperparameters are:

| Parameter | Description | Typical Range |
|---|---|---|
| `pos_rank` | Number of Fourier frequencies for relative position encoding. Controls how finely the model can distinguish between different relative distances; higher values give more expressive distance modeling. | 2–16 |
| `abs_rank` | Number of Fourier frequencies for absolute position encoding, used by the learned "sink" bias u(j). Controls position-dependent attention patterns (e.g., attending to the sequence start). | 2–8 |
Guidelines:

- For language models (GPT-style): `pos_rank=4, abs_rank=4` is a good starting point.
- For vision transformers (ViT): `pos_rank=16, abs_rank=2` works well with 2D positional encoding.
- For long-context tasks: higher `pos_rank` may help capture fine-grained positional patterns.
- Set `enable_key_bias=False` to disable the sink term entirely (removes the `abs_rank` dependency); see the sketch after this list.
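The presets above, written out as a sketch. It uses only the constructor arguments shown in the quickstart (`embed_dim`, `num_heads`, `batch_first`, `pos_rank`, `abs_rank`, `enable_key_bias`); the values simply mirror the guidelines, not library defaults.

```python
from goat import GoatAttention

E, H = 64, 8  # embed dim and head count, as in the quickstart

# Language model (GPT-style) starting point.
lm_attn = GoatAttention(embed_dim=E, num_heads=H, batch_first=True,
                        pos_rank=4, abs_rank=4)

# Vision transformer (ViT) setting, paired with 2D positional encoding.
vit_attn = GoatAttention(embed_dim=E, num_heads=H, batch_first=True,
                         pos_rank=16, abs_rank=2)

# Sink term disabled; abs_rank is not used in this configuration.
no_sink = GoatAttention(embed_dim=E, num_heads=H, batch_first=True,
                        pos_rank=4, abs_rank=4, enable_key_bias=False)
```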
After installation:

```
goat info
goat smoke
```

See `docs/`.
To set up a development environment and run the tests:

```
uv pip install -e ".[dev]"
pytest
```

License: MIT (see LICENSE).
If you find GOAT useful, please cite:
```bibtex
@misc{goat,
  title         = {You Need Better Attention Priors},
  author        = {Elon Litman and Gabe Guo},
  year          = {2026},
  eprint        = {2601.15380},
  archivePrefix = {arXiv},
  primaryClass  = {cs.LG}
}
```