This project implements Adaptive Sparse Causal Attention (ASCA) for character-level language modeling on the enwik8 dataset. The research explores enhancing computational efficiency and performance by introducing content-based adaptive sparsity into attention mechanisms.
- Introduction
- Key Features
- Implementation Details
- Results and Discussion
- Installation
- Usage
- Project Structure
- References
- License
The enwik8 dataset is a widely recognized benchmark for evaluating compression capabilities in character-level language modeling. This project focuses on improving model efficiency and performance through a novel modification—Adaptive Sparse Causal Attention.
- Develop a character-level language model that captures sequential dependencies in data efficiently.
- Introduce a content-based sparsity mechanism to reduce computational burden.
- Demonstrate improved bits-per-character (BPC) performance over baseline transformer models.
- Implementation of Content-Based Adaptive Sparsity Module to dynamically prune attention weights.
- Integration with Causal Self-Attention while preserving autoregressive properties.
- Evaluation of model performance using BPC and sparsity metrics on the enwik8 dataset.
This module predicts and applies a sparsity mask to prune irrelevant attention connections dynamically. Key features include:
- Content-based mask prediction via a small neural network.
- Dynamic adjustment of sparsity during training.
- Reduced computational overhead with efficient pruning.
```python
class ContentBasedAdaptiveSparsity(nn.Module):
    def forward(self, x, att):
        # Predict and apply sparsity mask
        ...
        return sparse_att
```
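Below is a minimal sketch of how such a module could look. It is an illustration, not the exact code in `novel_model.py`: the gating network size, the learnable threshold, the soft sigmoid mask, and the renormalization step are all assumptions.

```python
import torch
import torch.nn as nn

class ContentBasedAdaptiveSparsity(nn.Module):
    """Sketch: predict a per-position keep score from token content and use it
    to prune low-relevance attention connections."""

    def __init__(self, n_embd, init_threshold=0.5):
        super().__init__()
        # Small network mapping each token's content to a keep score in (0, 1).
        self.gate = nn.Sequential(
            nn.Linear(n_embd, n_embd // 4),
            nn.ReLU(),
            nn.Linear(n_embd // 4, 1),
        )
        # Learnable threshold so the overall sparsity level can adapt during training.
        self.threshold = nn.Parameter(torch.tensor(init_threshold))

    def forward(self, x, att):
        # x:   (B, T, C) token representations
        # att: (B, n_head, T, T) causal attention weights (already softmaxed)
        keep = torch.sigmoid(self.gate(x))        # (B, T, 1) keep score per key position
        keep = keep.transpose(1, 2).unsqueeze(1)  # (B, 1, 1, T) broadcast over heads and queries
        # Soft mask: attention to keys whose keep score falls below the threshold is suppressed.
        mask = torch.sigmoid((keep - self.threshold) * 10.0)
        sparse_att = att * mask
        # Renormalize so each attention row still sums to 1.
        sparse_att = sparse_att / (sparse_att.sum(dim=-1, keepdim=True) + 1e-8)
        return sparse_att
```

Using a soft, sigmoid-shaped mask keeps the pruning decision differentiable, so the effective sparsity level can adjust over training, and renormalizing after pruning keeps each attention row a valid probability distribution.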
This module integrates adaptive sparsity into a standard causal self-attention mechanism:
- Ensures autoregressive properties with causal masking.
- Dynamically prunes attention connections based on input content.
```python
class AdaptiveSparseCausalSelfAttention(nn.Module):
    def forward(self, x, layer_past=None):
        # Apply causal attention with adaptive sparsity
        ...
        return y
```
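A hedged sketch of how the full layer might combine a standard (minGPT-style) causal self-attention with the `ContentBasedAdaptiveSparsity` sketch above; the projection layout, dropout, and hyperparameter names are illustrative assumptions rather than the repository's exact implementation.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class AdaptiveSparseCausalSelfAttention(nn.Module):
    """Sketch: causal self-attention with adaptive sparsity applied to the
    attention matrix before value aggregation."""

    def __init__(self, n_embd, n_head, block_size, dropout=0.1):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.c_attn = nn.Linear(n_embd, 3 * n_embd)   # joint q, k, v projection
        self.c_proj = nn.Linear(n_embd, n_embd)       # output projection
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)
        self.sparsity = ContentBasedAdaptiveSparsity(n_embd)
        # Causal mask so each position can only attend to itself and earlier positions.
        self.register_buffer(
            "bias",
            torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size),
        )

    def forward(self, x, layer_past=None):  # layer_past kept for API compatibility; caching not shown
        B, T, C = x.size()
        q, k, v = self.c_attn(x).split(C, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        att = (q @ k.transpose(-2, -1)) / math.sqrt(k.size(-1))          # (B, nh, T, T)
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.sparsity(x, att)          # prune low-relevance connections
        att = self.attn_dropout(att)
        y = att @ v                          # (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.resid_dropout(self.c_proj(y))
```

Because the sparsity mask only removes connections that the causal mask already allows, the layer stays strictly autoregressive.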
- Baseline Transformer Model: BPC ~5.53 on the enwik8 test set.
- Adaptive Sparse Causal Attention Model: BPC ~4.26, demonstrating better text compression at reduced computational cost (see the BPC conversion sketch after this list).
- Average sparsity: ~30% during training.
- Visualization shows dynamic adjustment of attention weights to retain critical connections.
- Improved BPC performance.
- Significant reduction in computational complexity.
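The BPC figures above are obtained from the mean character-level cross-entropy loss: BPC = loss (in nats) / ln 2. A minimal, repository-agnostic sketch of that conversion:

```python
import math
import torch.nn.functional as F

def bits_per_character(logits, targets):
    """Convert mean character-level cross-entropy (in nats) to bits per character."""
    # logits: (B, T, vocab_size); targets: (B, T) integer character ids
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
    return loss.item() / math.log(2)
```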
- Clone the repository:

  ```bash
  git clone https://github.com/Itssshikhar/SAEs.git
  cd SAEs
  ```

- Install dependencies:

  ```bash
  pip install -r requirements.txt
  ```

- Prepare the enwik8 dataset:

  ```bash
  python prepare_dataset.py
  ```
Train the baseline model:

```bash
python baseline_model.py
```

Train the Adaptive Sparse Causal Attention model:

```bash
python novel_model.py
```
Use the provided scripts to evaluate BPC and visualize sparsity.
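As an example of the kind of visualization produced, the sketch below renders a pruned attention tensor as a heatmap; the function name, the `sparse_att` argument, and the output filename are hypothetical and not taken from the repository's scripts.

```python
import matplotlib.pyplot as plt

def plot_sparsity(sparse_att, path="sparsity.png"):
    # sparse_att: (B, n_head, T, T) pruned attention weights from one layer.
    # Average over batch and heads to obtain a single T x T map.
    avg = sparse_att.mean(dim=(0, 1)).detach().cpu().numpy()
    plt.imshow(avg, cmap="viridis")
    plt.xlabel("key position")
    plt.ylabel("query position")
    plt.colorbar(label="attention weight after pruning")
    plt.savefig(path, bbox_inches="tight")
    plt.close()
```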
- `baseline_model.py`: Baseline transformer implementation.
- `novel_model.py`: Adaptive Sparse Causal Attention implementation.
- `prepare_dataset.py`: Script for preparing the enwik8 dataset.
- `configurator.py`: Configurations for models and training.
- `extraction.py`: Data extraction tools.
- `*.png` & `*.pdf`: Visualizations of sparsity and performance.
This project is licensed under the MIT License.