A PyTorch-based framework for distilling knowledge from large embedding models (teacher) into smaller, more efficient models (student) using learned attention pooling mechanisms and sparse representations.
This project implements knowledge distillation where smaller student models learn to mimic the embedding similarities produced by larger teacher models. The framework supports multiple pooling strategies and evaluation approaches for information retrieval tasks.
- Teacher Model: BGE-M3 (multilingual embedding model with sparse vector support)
- Student Models: ModernBert-based architectures with custom pooling heads
- Distillation Target: Similarity matrices between query-passage pairs using MSE loss
- Pooling Strategies: Dense attention-based pooling and sparse vocabulary-based pooling
- Evaluation: Comprehensive recall@k metrics with optional reranking
- Knowledge Distillation: Transfer similarity patterns from BGE-M3 to compact models
- Dual Pooling Architectures:
  - `ModernBertWithActivationHeadModel`: Dense attention-weighted pooling
  - `ModernBertWithSparseHeadModel`: Sparse vocabulary-based representations
- Multi-Dataset Support: GermanQuAD, GermanDPR, and mMARCO datasets
- Advanced Evaluation: Recall@k metrics with SGLang reranker integration
- Efficient Training: Frozen backbone with trainable pooling heads only
- Production Ready: Proper model saving/loading and HuggingFace integration
```bash
# Clone the repository
git clone <repository-url>
cd embedding-distillation

# Install dependencies (requires uv)
uv sync
```

Dense pooling flow:

Input → Tokenizer → ModernBert (frozen) → Activation Head → Attention-Weighted Pooling → Dense Embedding

Sparse pooling flow:

Input → Tokenizer → ModernBert (frozen) → Activation Head → Vocabulary Scatter → Sparse Embedding
- Teacher Forward: Generate sparse embeddings using BGE-M3
- Student Forward: Generate embeddings using student model + pooling head
- Similarity Computation: Calculate cosine similarities between query-passage pairs
- Loss Calculation: MSE loss between Z-score normalized similarity matrices (see the sketch below)
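A minimal, self-contained sketch of one such step is shown below; `encode_teacher` and `encode_student` are hypothetical stand-ins for the BGE-M3 encoder and the student model with its pooling head (the real logic lives in `trainer.py`):

```python
import torch
import torch.nn.functional as F

def distillation_step(queries, passages, encode_teacher, encode_student):
    """One similarity-distillation step over a batch of query/passage strings."""
    with torch.no_grad():                 # teacher is frozen
        q_t = encode_teacher(queries)     # [B, D_teacher]
        p_t = encode_teacher(passages)
    q_s = encode_student(queries)         # [B, D_student], gradients flow into the pooling head
    p_s = encode_student(passages)

    # Query-passage similarity matrices for teacher and student
    sim_t = q_t @ p_t.T
    sim_s = q_s @ p_s.T

    # Z-score normalize both matrices so their scales are comparable
    sim_t = (sim_t - sim_t.mean()) / sim_t.std()
    sim_s = (sim_s - sim_s.mean()) / sim_s.std()

    # MSE between the normalized similarity matrices
    return F.mse_loss(sim_s, sim_t)
```

In practice this is handled by `AttachedPooledEmbedderTrainer`, as in the quick-start example that follows.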
```python
from trainer import AttachedPooledEmbedderTrainer
from model_definition import ModernBertWithActivationHeadModel
from data_loading import load_germandpr, passthrough_collator
from transformers import TrainingArguments
import torch
# Load models (student_tokenizer, teacher_model and teacher_tokenizer are
# set up as in the model-configuration example further below)
student_model = ModernBertWithActivationHeadModel.from_pretrained("models/ettin-encoder-32m")
student_model.setup_for_training()
# Training setup
training_args = TrainingArguments(
output_dir="./output",
per_device_train_batch_size=32,
num_train_epochs=2,
learning_rate=2e-4,
save_strategy="steps",
save_steps=128,
)
trainer = AttachedPooledEmbedderTrainer(
model=student_model,
model_tokenizer=student_tokenizer,
bge_m3_model=teacher_model,
bge_m3_tokenizer=teacher_tokenizer,
train_dataset=load_germandpr(),
data_collator=passthrough_collator,
args=training_args,
)
trainer.train()
```

```python
from metrics_support import benchmark_model, prepare_germanquad_for_benchmark
from data_loading import load_germanquad
# Prepare evaluation data
eval_dataset = load_germanquad(split="test")
unique_queries, unique_passages, labels_per_query = prepare_germanquad_for_benchmark(eval_dataset)
# Benchmark model
recalls = benchmark_model(
model=student_model,
tokenizer=student_tokenizer,
eval_queries=unique_queries,
eval_passages=unique_passages,
labels_per_query=labels_per_query,
k_values=[1, 5, 10, 20, 50],
batch_size=32
)
print("Recall@k results:", recalls)from model_definition import ModernBertWithActivationHeadModel
# Dense model with attention-weighted pooling
model = ModernBertWithActivationHeadModel.from_pretrained("your-model-path")
model.setup_for_training()  # Freezes backbone, enables pooling head training
```

Forward Process:
- Compute attention scores for each token using linear head
- Apply ReLU activation to scores
- Weighted sum of hidden states using attention mask
- Normalize by sequence length (see the sketch below)
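A minimal sketch of that pooling step, assuming the activation head emits one score per token (the exact implementation is in `model_definition.py` and may differ in details):

```python
import torch

def attention_weighted_pool(hidden_states, scores, attention_mask):
    """Pool token embeddings using learned, ReLU-gated per-token scores.

    hidden_states:  [batch, seq_len, hidden]  frozen backbone outputs
    scores:         [batch, seq_len, 1]       linear activation-head output
    attention_mask: [batch, seq_len]          1 for real tokens, 0 for padding
    """
    weights = torch.relu(scores) * attention_mask.unsqueeze(-1)  # zero out padding
    pooled = (hidden_states * weights).sum(dim=1)                # attention-weighted sum
    lengths = attention_mask.sum(dim=1, keepdim=True).clamp(min=1)
    return pooled / lengths                                      # normalize by sequence length
```

The ReLU keeps the token weights non-negative, so padding and low-scoring tokens contribute little to the pooled embedding.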
```python
from model_definition import ModernBertWithSparseHeadModel
# Sparse model with vocabulary-based representations
model = ModernBertWithSparseHeadModel.from_pretrained("your-model-path")
model.setup_for_training()
```

Forward Process:
- Compute activation scores for each token
- Scatter scores to vocabulary dimensions based on input token IDs
- Use `scatter_reduce` with "amax" to handle token repetitions
- Output: `[batch_size, vocab_size]` sparse embeddings (see the sketch below)
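A rough sketch of the scatter step under the same assumptions (the real code in `model_definition.py` may handle masking and dtypes differently):

```python
import torch

def sparse_vocab_pool(scores, input_ids, attention_mask, vocab_size):
    """Scatter per-token activation scores into vocabulary space.

    scores:         [batch, seq_len]  activation-head output per token
    input_ids:      [batch, seq_len]  token ids, used as vocabulary indices
    attention_mask: [batch, seq_len]  1 for real tokens, 0 for padding
    Returns a [batch, vocab_size] sparse-style embedding.
    """
    scores = scores * attention_mask                       # padding positions contribute zeros
    out = torch.zeros(scores.size(0), vocab_size, dtype=scores.dtype, device=scores.device)
    # "amax" keeps the maximum score when the same token id occurs more than once
    out.scatter_reduce_(1, input_ids, scores, reduce="amax", include_self=False)
    return out
```

Taking the maximum rather than summing keeps repeated tokens from inflating the sparse vector.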
| Dataset | Language | Type | Format |
|---|---|---|---|
| GermanQuAD | German | QA | Question-context pairs (all positive) |
| GermanDPR | German | Retrieval | Query with positive/negative/hard negative contexts |
| mMARCO | Multilingual | Retrieval | Query-positive-negative triplets |
All datasets are normalized to:
```python
{
"query": "What is machine learning?",
"passage": "Machine learning is a subset of...",
"label": 1 # 1 for relevant, 0 for irrelevant
}
```

```python
from data_loading import load_germandpr, load_germanquad, load_mmarco_multilang
# Single datasets
train_data = load_germandpr()
eval_data = load_germanquad(split="test")
# Multilingual combination
multilang_data = load_mmarco_multilang(languages=["english", "german", "french"])
```

```python
from trainer import DetachedPooledEmbedderTrainer
# Uses separate activation head with detached gradients
trainer = DetachedPooledEmbedderTrainer(
model=backbone_model,
activation_head=separate_head,
# ... other params
)
```

```python
from trainer import AttachedPooledEmbedderTrainer
# Uses integrated model with built-in pooling head
trainer = AttachedPooledEmbedderTrainer(
model=integrated_model, # ModernBertWith*HeadModel
# ... other params
)
```

```python
from metrics_support import benchmark_model, sglang_reranker_fn
# Basic evaluation
recalls = benchmark_model(
model=model,
tokenizer=tokenizer,
eval_queries=queries,
eval_passages=passages,
labels_per_query=labels,
k_values=[1, 5, 10, 20, 50, 100]
)
# With reranking
recalls_reranked = benchmark_model(
# ... same params ...
rerank_fn=lambda q, p: sglang_reranker_fn(q, p, base_url="http://localhost:30000/v1"),
rerank_k=50
)
```

```python
from trainer import RecallEvaluationCallback
callback = RecallEvaluationCallback(
model_tokenizer=tokenizer,
eval_queries=eval_queries,
eval_passages=eval_passages,
labels_per_query=labels_per_query,
k_values=[1, 5, 10, 20, 50],
batch_size=32
)
trainer = AttachedPooledEmbedderTrainer(
# ... trainer params ...
callbacks=[callback]
)
```

| Parameter | Default | Description |
|---|---|---|
| `query_max_length` | 512 | Maximum tokens for queries |
| `passage_max_length` | 2048 | Maximum tokens for passages |
| `query_pad_to` | 16 | Padding multiple for queries |
| `passage_pad_to` | 64 | Padding multiple for passages |
| `batch_size` | 32 | Encoding batch size |
| `learning_rate` | 2e-4 | Learning rate for pooling head |
| `num_train_epochs` | 2 | Number of training epochs |
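For illustration, the length and padding defaults above map onto a standard HuggingFace tokenizer call roughly like this (a sketch; the framework's own encoding helpers live in `model_support.py` and may differ):

```python
from transformers import AutoTokenizer

# Assumes the student checkpoint from the examples above; adjust as needed.
student_tokenizer = AutoTokenizer.from_pretrained("models/ettin-encoder-32m")

query_batch = student_tokenizer(
    ["What is machine learning?"],
    truncation=True,
    max_length=512,          # query_max_length
    padding=True,
    pad_to_multiple_of=16,   # query_pad_to
    return_tensors="pt",
)

passage_batch = student_tokenizer(
    ["Machine learning is a subset of artificial intelligence ..."],
    truncation=True,
    max_length=2048,         # passage_max_length
    padding=True,
    pad_to_multiple_of=64,   # passage_pad_to
    return_tensors="pt",
)
```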
```python
from FlagEmbedding import BGEM3FlagModel
from model_definition import ModernBertWithActivationHeadModel

# Teacher model setup
teacher = BGEM3FlagModel(
"models/bge-m3",
pooling_method="mean",
# use_fp16=True # Optional for speed
)
# Student model setup
student = ModernBertWithActivationHeadModel.from_pretrained("models/ettin-encoder-32m")
student.setup_for_training()  # Freeze backbone, enable head training
```

```
├── data_loading.py # Dataset loaders and preprocessors
├── model_definition.py # Custom model architectures
├── model_support.py # Encoding utilities for different model types
├── trainer.py # Custom trainer classes and callbacks
├── metrics_support.py # Evaluation metrics and benchmarking
├── main.py # Example usage and testing
├── pyproject.toml # Dependencies and project config
└── README.md # This file
```

- `data_loading.py`: Pure functions for dataset loading and batch processing
- `model_definition.py`: HuggingFace-compatible model classes with custom pooling
- `model_support.py`: Composable encoding functions for different model architectures
- `trainer.py`: Training logic with similarity-based distillation loss
- `metrics_support.py`: Evaluation pipeline with recall@k and reranking support
Z-score normalized similarity matching:
```python
import torch.nn.functional as F

# Teacher similarities (frozen BGE-M3 sparse vectors)
sim_teacher = query_teacher_emb @ passage_teacher_emb.T
sim_teacher = (sim_teacher - sim_teacher.mean()) / sim_teacher.std()
# Student similarities (trainable pooling head)
sim_student = query_student_emb @ passage_student_emb.T
sim_student = (sim_student - sim_student.mean()) / sim_student.std()
# MSE loss on normalized similarities
loss = F.mse_loss(sim_student, sim_teacher)
```

```python
from metrics_support import sglang_reranker_fn
# Requires SGLang server running on specified port
reranked_indices = sglang_reranker_fn(
queries=["query1", "query2"],
top_k_passages=[["passage1", "passage2"], ["passage3", "passage4"]],
base_url="http://localhost:30000/v1",
model="BAAI/bge-reranker-v2-m3"
)
```

```python
# Extend the base model class with a custom pooling head
import torch.nn as nn
from transformers import ModernBertForMaskedLM

class CustomPoolingModel(ModernBertForMaskedLM):
    def __init__(self, config):
        super().__init__(config)
        self.custom_head = nn.Sequential(
            nn.Linear(config.hidden_size, 2 * config.hidden_size),
            nn.ReLU(),
            nn.Dropout(config.classifier_dropout),
            nn.Linear(2 * config.hidden_size, 1),
        )

    def setup_for_training(self):
        # Freeze backbone, enable custom head
        for p in self.model.parameters():
            p.requires_grad = False
        for p in self.custom_head.parameters():
            p.requires_grad = True
        return self
```

- Mixed Precision: Enable `use_fp16=True` for BGE-M3
- Padding Optimization: Use `pad_to_multiple_of` for efficient tensor operations
- Frozen Backbone: Only train pooling heads to reduce memory usage
- Batch Processing: Adjust batch size based on GPU memory
- Fast Math: Use `torch.set_float32_matmul_precision("high")` (see the sketch below)
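As a sketch, the mixed-precision and fast-math tips translate to the following (assuming the teacher checkpoint path used earlier):

```python
import torch
from FlagEmbedding import BGEM3FlagModel

# Faster float32 matmuls (TF32) on recent NVIDIA GPUs
torch.set_float32_matmul_precision("high")

# Run the frozen teacher in half precision to speed up encoding
teacher = BGEM3FlagModel("models/bge-m3", use_fp16=True)
```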
- Python 3.13+
- PyTorch 2.8+
- Transformers 4.48+ (required for ModernBert)
- Datasets <4.0
- FlagEmbedding 1.3.5+
- uv (package manager)
If you use this framework in your research, please cite:
```bibtex
@misc{embedding-distillation,
title={Embedding Model Distillation Framework with Learned Pooling},
author={Christian Bahls},
year={2025},
url={https://github.com/approximated-intelligence/embedding-distillation}
}
```

This project is licensed under the MIT License - see the LICENSE file for details.
- Fork the repository
- Create a feature branch (`git checkout -b feature/amazing-feature`)
- Commit your changes (`git commit -m 'Add amazing feature'`)
- Push to the branch (`git push origin feature/amazing-feature`)
- Open a Pull Request
- BGE-M3 for the teacher model and sparse embeddings
- HuggingFace Transformers for the model framework
- SGLang for efficient reranker serving
- ModernBert for the efficient transformer backbone