🧠 Model Distillation on the CIFAR-10 Dataset

🔍 What is Model Distillation?

Think of model distillation like file compression, but for AI models. Just as you can zip a large folder to save space, we shrink a large neural network (the "Teacher") into a much smaller one (the "Student") while trying to keep most of its accuracy.

This is especially useful when you want to run models on phones, edge devices, or any low-resource environments.


📋 Abstract

This repository presents a comprehensive implementation and analysis of knowledge distillation for image classification on CIFAR-10. We demonstrate how a lightweight student network can learn from a larger teacher network, achieving significant computational efficiency gains while maintaining competitive performance. Our implementation includes detailed performance metrics, model analysis, and visualization tools for understanding the distillation process.

🎯 Research Objectives

  1. Investigate the effectiveness of knowledge distillation for model compression
  2. Analyze the trade-offs between model size, inference speed, and accuracy
  3. Provide a reproducible framework for knowledge distillation experiments
  4. Demonstrate practical applications of model compression techniques

πŸ—οΈ Architecture Overview

Teacher Model (ResNet18)

  • Architecture: Deep convolutional network with residual connections
  • Parameters: 4,664,970 (~4.7M)
  • Size: 17.81 MB
  • Design Philosophy: Maximizes representational capacity for knowledge extraction

Student Model (Lightweight CNN)

  • Architecture: Compact 4-block convolutional network
  • Parameters: 106,826 (~107K)
  • Size: 0.41 MB
  • Compression Ratio: 43.7× parameter reduction, 43.4× size reduction
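
For concreteness, a 4-block student CNN in roughly this parameter range could look like the sketch below. The exact channel widths and layer layout are assumptions for illustration, not the repository's definition, so the parameter count only lands near the ~107K budget rather than matching it exactly.

```python
import torch.nn as nn

class StudentCNN(nn.Module):
    """Illustrative compact 4-block CNN for 32x32 CIFAR-10 inputs (hypothetical layout)."""
    def __init__(self, num_classes: int = 10):
        super().__init__()

        def block(in_ch, out_ch):
            # Conv -> BatchNorm -> ReLU, then downsample by 2.
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2),
            )

        self.features = nn.Sequential(
            block(3, 16),    # 32x32 -> 16x16
            block(16, 32),   # 16x16 -> 8x8
            block(32, 64),   # 8x8  -> 4x4
            block(64, 128),  # 4x4  -> 2x2
        )
        self.classifier = nn.Linear(128 * 2 * 2, num_classes)

    def forward(self, x):
        x = self.features(x)
        return self.classifier(x.flatten(1))
```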

📊 Experimental Results

Performance Metrics

| Model | Accuracy (%) | Size (MB) | Parameters | Inference Time (ms) | Efficiency Score* |
|---|---|---|---|---|---|
| Teacher | 83.13 | 17.81 | 4,664,970 | 111.29 | 0.75 |
| Student Baseline | 72.70 | 0.41 | 106,826 | 27.76 | 2.62 |
| Student Distilled | 70.86 | 0.41 | 106,826 | 25.82 | 2.74 |

*Efficiency Score = Accuracy (%) / Inference Time (ms); for example, the Teacher scores 83.13 / 111.29 ≈ 0.75.

Key Findings

🔍 Surprising Result: Distillation Performance Gap

  • Expected: Distilled student outperforms baseline student
  • Observed: Distilled student (70.86%) scores 1.84 percentage points below the baseline (72.70%)
  • Hypothesis: Potential optimization challenges or hyperparameter sensitivity

⚑ Computational Efficiency Gains

  • Size Reduction: 97.7% smaller model (17.81 MB → 0.41 MB)
  • Speed Improvement: 4.3× faster inference (111 ms → 26 ms)
  • Parameter Efficiency: 43.7× fewer parameters with only a 14.8% relative drop in accuracy (12.27 percentage points)

📈 Scaling Analysis

  • Teacher-to-student knowledge gap: 12.27-percentage-point accuracy difference
  • Size-to-performance ratio: Excellent for resource-constrained environments
  • Inference efficiency: Suitable for real-time applications

🧪 Methodology

Knowledge Distillation Framework

Our implementation follows the seminal work by Hinton et al. (2015) with the following components:

Loss Function

L_total = α × L_soft + (1 - α) × L_hard

Where:

  • L_soft: KL divergence between teacher and student soft predictions
  • L_hard: Cross-entropy loss with ground-truth labels
  • α = 0.7: Weighting factor (soft-knowledge emphasis)
  • T = 4.0: Temperature parameter for probability smoothing
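
A minimal PyTorch sketch of this combined loss is shown below. It follows the usual Hinton-style convention of scaling the soft term by T² so its gradients stay on the same scale as the hard term; whether the repository applies that exact scaling is an assumption here.

```python
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.7):
    # Soft targets: KL divergence between temperature-softened distributions.
    # The T*T factor is the standard Hinton et al. (2015) gradient rescaling (assumed).
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1),
        reduction="batchmean",
    ) * (T * T)
    # Hard targets: standard cross-entropy against the ground-truth labels.
    hard_loss = F.cross_entropy(student_logits, labels)
    return alpha * soft_loss + (1 - alpha) * hard_loss
```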

Training Configuration

  • Optimizer: Adam with weight decay (1e-4)
  • Learning Rate: 0.001 with Step LR scheduler
  • Batch Size: 128
  • Epochs: 15 for all models
  • Data Augmentation: Random crop, horizontal flip, normalization
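
This configuration maps onto standard PyTorch components roughly as follows. The normalization statistics, StepLR step size, and gamma are commonly used values assumed for illustration, and `StudentCNN` is the hypothetical model sketched earlier.

```python
import torch
from torch import optim
from torchvision import datasets, transforms

# Data augmentation as listed above: random crop, horizontal flip, normalization
# (common CIFAR-10 statistics assumed).
train_tf = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
])
train_set = datasets.CIFAR10(root="./data", train=True, download=True, transform=train_tf)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True, num_workers=2)

student = StudentCNN()  # hypothetical student from the earlier sketch
optimizer = optim.Adam(student.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)  # step size/gamma assumed
```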

Evaluation Metrics

  1. Accuracy: Top-1 classification accuracy on CIFAR-10 test set
  2. Model Size: Memory footprint including parameters and buffers
  3. Inference Time: Average forward pass time over 100 batches
  4. Parameter Count: Total trainable parameters
  5. Compression Ratio: Teacher size / Student size
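
The size and timing metrics can be measured with helpers along these lines (a sketch, not the repository's exact code):

```python
import time
import torch

def count_parameters(model):
    # Total trainable parameters.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def model_size_mb(model):
    # Memory footprint of parameters and buffers, in megabytes.
    param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
    buffer_bytes = sum(b.numel() * b.element_size() for b in model.buffers())
    return (param_bytes + buffer_bytes) / (1024 ** 2)

@torch.no_grad()
def avg_inference_time_ms(model, loader, device, n_batches=100):
    # Average forward-pass time over n_batches batches.
    model.eval().to(device)
    times = []
    for i, (images, _) in enumerate(loader):
        if i >= n_batches:
            break
        start = time.perf_counter()
        model(images.to(device))
        if device.type == "cuda":
            torch.cuda.synchronize()
        times.append((time.perf_counter() - start) * 1000)
    return sum(times) / len(times)
```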

🔬 Analysis and Discussion

Distillation Challenges Observed

1. Negative Transfer Phenomenon

The distilled student's underperformance compared to the baseline suggests:

  • Capacity Mismatch: Student network may be too small to effectively capture teacher knowledge
  • Temperature Sensitivity: T = 4.0 might be suboptimal for this architecture pair
  • Loss Balance: α = 0.7 may overemphasize soft targets

2. Potential Improvements

  • Hyperparameter Tuning: Systematic search for optimal T and α
  • Progressive Distillation: Multi-step knowledge transfer
  • Feature-level Distillation: Intermediate layer knowledge transfer
  • Attention Transfer: Focus on important spatial regions

Practical Implications

✅ Success Metrics

  • Deployment Viability: 97.7% size reduction enables mobile deployment
  • Real-time Processing: 4.3× speed improvement for latency-critical applications
  • Resource Efficiency: Significant reduction in computational requirements

⚠️ Limitations

  • Accuracy Trade-off: A 12.27-percentage-point accuracy drop may be significant for some applications
  • Distillation Effectiveness: Need for methodology refinement
  • Architecture Dependency: Results may vary with different model architectures

📈 Recommendations for Future Work

Immediate Improvements

  1. Hyperparameter Optimization: Grid search for T ∈ [1, 2, 3, 4, 5] and α ∈ [0.1, 0.3, 0.5, 0.7, 0.9] (see the sketch after this list)
  2. Student Architecture: Experiment with slightly larger student networks
  3. Training Duration: Extended training with learning rate scheduling
  4. Ensemble Distillation: Multiple teacher models for robust knowledge transfer
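
As referenced in item 1, the grid search could be organized as below. `train_student_with_distillation` is a hypothetical helper, assumed to train a fresh student with the given temperature and weighting and return its test accuracy; `teacher` is the pretrained teacher model.

```python
import itertools

temperatures = [1, 2, 3, 4, 5]
alphas = [0.1, 0.3, 0.5, 0.7, 0.9]

results = {}
for T, alpha in itertools.product(temperatures, alphas):
    # Hypothetical helper, not part of the repository: trains a fresh student
    # with distillation_loss(T=T, alpha=alpha) and returns test accuracy.
    results[(T, alpha)] = train_student_with_distillation(teacher, T=T, alpha=alpha)

best_T, best_alpha = max(results, key=results.get)
print(f"Best setting: T={best_T}, alpha={best_alpha} -> {results[(best_T, best_alpha)]:.2f}%")
```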

Advanced Techniques

  1. Attention-based Distillation: Transfer spatial attention maps
  2. Progressive Knowledge Transfer: Curriculum learning approach
  3. Multi-level Feature Distillation: Intermediate layer supervision
  4. Adversarial Distillation: GAN-based knowledge transfer

🚀 Usage Instructions

Prerequisites

pip install torch torchvision matplotlib pandas numpy
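
The repository's own entry point is not reproduced here; as an assumed workflow, the pieces sketched in earlier sections could be combined into a distillation training loop like this (`teacher`, `student`, `train_loader`, `optimizer`, `scheduler`, and `distillation_loss` all come from those sketches).

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
teacher.eval().to(device)   # teacher is assumed to be pretrained on CIFAR-10
student.train().to(device)

for epoch in range(15):
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        with torch.no_grad():
            teacher_logits = teacher(images)  # frozen teacher provides soft targets
        student_logits = student(images)
        loss = distillation_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.7)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    scheduler.step()
```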

📊 Visualization and Analysis

The implementation includes comprehensive visualization tools:

  • Accuracy Comparison: Bar charts showing model performance
  • Size Analysis: Model compression visualization
  • Training Curves: Learning progression over epochs
  • Efficiency Metrics: Speed vs. accuracy trade-offs
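
For example, the accuracy and size comparisons could be produced with a simple matplotlib script like the following, using the numbers from the results table (the repository's plotting code may differ).

```python
import matplotlib.pyplot as plt

# Results from the table above.
models = ["Teacher", "Student Baseline", "Student Distilled"]
accuracy = [83.13, 72.70, 70.86]
size_mb = [17.81, 0.41, 0.41]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
ax1.bar(models, accuracy, color=["tab:blue", "tab:orange", "tab:green"])
ax1.set_ylabel("Top-1 Accuracy (%)")
ax1.set_title("Accuracy Comparison")
ax2.bar(models, size_mb, color=["tab:blue", "tab:orange", "tab:green"])
ax2.set_ylabel("Model Size (MB)")
ax2.set_title("Size Analysis")
plt.tight_layout()
plt.show()
```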

🔗 References

  1. Hinton, G., Vinyals, O., and Dean, J. (2015). Distilling the Knowledge in a Neural Network. arXiv:1503.02531.
  2. Knowledge Distillation in PyTorch (tutorial).

🤝 Contributing

Contributions are welcome! Please feel free to submit pull requests or open issues for:

  • Bug fixes
  • Performance improvements
  • New distillation techniques
  • Additional evaluation metrics
  • Documentation improvements

For questions, suggestions, or collaborations, please open an issue or reach out through the repository's discussion section.


This research demonstrates the practical viability of knowledge distillation for model compression while highlighting areas for methodological improvement. The significant computational gains achieved make this approach valuable for deployment in resource-constrained environments, despite the observed accuracy trade-offs.
