This repository implements a simple Multi-Head Attention Mechanism in PyTorch. It supports masking for use cases such as padding token exclusion and causal attention, making it suitable for various NLP and sequence-based tasks.
- Multi-Head Attention: Implementation of the core component of transformer-based models.
- Masking Support: Includes masking for padding and causal attention.
- Scalable Design: Configurable number of heads and embedding dimensions.
- Reproducible Results: Controlled parameter initialization with fixed seeds.
To get started:

- Clone the repository:

  ```bash
  git clone https://github.com/your-username/multi-head-attention-pytorch.git
  ```

- Navigate to the directory:

  ```bash
  cd multi-head-attention-pytorch
  ```

- Install dependencies:

  ```bash
  pip install torch
  ```
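Optionally, verify that PyTorch is importable in your environment (a quick sanity check, not part of the repository's own instructions):

```bash
python -c "import torch; print(torch.__version__)"
```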
Example usage:

```python
import torch
from mhma import MHMA

embed_dim = 128  # Embedding dimension
heads = 8        # Number of attention heads
dq = 16          # Query dimension per head
dk = 16          # Key dimension per head
dv = 16          # Value dimension per head
```
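The snippet below constructs the model and a dummy input batch so that the forward pass in the next block runs end to end. The `MHMA` constructor signature shown here is an assumption based on the parameters above; check `mhma.py` for the actual argument names and order.

```python
# Assumed constructor signature; adjust to match mhma.py
model = MHMA(embed_dim=embed_dim, heads=heads, dq=dq, dk=dk, dv=dv)

# Dummy input: batch of 32 sequences, 10 tokens each, embed_dim features per token
x = torch.randn(32, 10, embed_dim)
```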
```python
output = model(x)
print(output.shape)  # Expected: (32, 10, 128)
```
To apply a causal mask so that each position attends only to itself and earlier positions:

```python
seq_len = 10
mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0).unsqueeze(1)  # Causal mask
output = model(x, mask=mask)
```
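A padding mask can be built the same way to exclude padded tokens. The shape convention below (broadcastable over heads and query positions) mirrors the causal example and is an assumption; verify how `mhma.py` applies the mask before relying on it.

```python
# Hypothetical padding mask: 1 = real token, 0 = padding (batch of 32, seq_len of 10)
pad = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 0, 0, 0]] * 32)  # (batch, seq_len)
pad_mask = pad.unsqueeze(1).unsqueeze(2)                    # (batch, 1, 1, seq_len)
output = model(x, mask=pad_mask)
```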
The MHMA class implements the following (a minimal sketch of the attention step follows this list):
- Query, Key, Value Matrices: Learned projections for multi-head attention.
- Attention Scores: Scaled dot-product attention computation.
- Masking: Optional masking to exclude certain positions.
- Output Projections: Combination of attention outputs from multiple heads.
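For reference, here is a minimal, self-contained sketch of scaled dot-product attention with optional masking. It illustrates the computation described above but is not the repository's `MHMA` code; the tensor shapes and the mask convention (0 = excluded position) are assumptions.

```python
import math
import torch

def scaled_dot_product_attention(q, k, v, mask=None):
    """q, k, v: (batch, heads, seq_len, head_dim); mask broadcastable to (batch, heads, seq_len, seq_len)."""
    # Similarity scores, scaled by the square root of the per-head key dimension
    scores = q @ k.transpose(-2, -1) / math.sqrt(k.size(-1))
    if mask is not None:
        # Positions where the mask is 0 get -inf scores and thus zero attention weight
        scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = torch.softmax(scores, dim=-1)  # Attention distribution over key positions
    return weights @ v                       # Weighted sum of value vectors

# Example: 2 sequences, 8 heads, 10 tokens, 16-dimensional heads
q = k = v = torch.randn(2, 8, 10, 16)
causal = torch.tril(torch.ones(10, 10)).unsqueeze(0).unsqueeze(1)
out = scaled_dot_product_attention(q, k, v, mask=causal)  # (2, 8, 10, 16)
```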
Repository layout:

```
multi-head-attention-pytorch/
│
├── mhma.py            # Core Multi-Head Attention implementation
├── README.md          # Project documentation
└── requirements.txt   # Dependencies (if any)
```
Contributions are welcome! Feel free to open an issue or submit a pull request.