The Fastest, Memory-efficient Python Library for CKA with Built-in Visualization

ryusudol/Centered-Kernel-Alignment

pytorch-cka

Figure: bar chart of benchmark results.

3500% faster CKA computation across all layers of two distinct ResNet-18 models on CIFAR-10 using NVIDIA H100 GPUs

  • ⚡️ Fastest among CKA libraries thanks to vectorized ops & GPU acceleration
  • 📦 Efficient memory management with explicit deallocation
  • 🧠 Supports HuggingFace models, DataParallel, and DDP
  • 🎨 Customizable visualizations: heatmaps and line charts

📦 Installation

Requires Python 3.10+

# Using pip
pip install pytorch-cka

# Using uv
uv add pytorch-cka

👟 Quick Start

Basic Usage

from torch.utils.data import DataLoader
from cka import CKA

pretrained_model = ...  # e.g. pretrained ResNet-18
fine_tuned_model = ...  # e.g. fine-tuned ResNet-18

layers = ["layer1", "layer2", "layer3", "fc"]

dataloader = DataLoader(..., batch_size=128)

cka = CKA(
    model1=pretrained_model,
    model2=fine_tuned_model,
    model1_name="ResNet-18 (pretrained)",
    model2_name="ResNet-18 (fine-tuned)",
    model1_layers=layers,
    model2_layers=layers,
    device="cuda"
)

# Most convenient usage (auto context manager)
cka_matrix = cka(dataloader)
cka_result = cka.export(cka_matrix)

# Or explicit control
with cka:
    cka_matrix = cka.compare(dataloader)
    cka_result = cka.export(cka_matrix)
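
The calls above consume a dataloader batch by batch. As a rough illustration of how a similarity score can be accumulated over mini-batches without keeping every activation in memory, here is a minimal NumPy sketch of streaming linear CKA. The class `StreamingLinearCKA` and its methods are hypothetical names made up for this example; this is not pytorch-cka's API and not necessarily how the library computes its scores internally.

```python
import numpy as np


class StreamingLinearCKA:
    """Hypothetical helper: accumulates sufficient statistics for linear CKA."""

    def __init__(self, d1: int, d2: int) -> None:
        self.n = 0                      # number of samples seen so far
        self.sum_x = np.zeros(d1)       # running feature sums (for centering)
        self.sum_y = np.zeros(d2)
        self.xtx = np.zeros((d1, d1))   # running uncentered second moments
        self.yty = np.zeros((d2, d2))
        self.xty = np.zeros((d1, d2))

    def update(self, x: np.ndarray, y: np.ndarray) -> None:
        """Add one mini-batch of activations, shapes (b, d1) and (b, d2)."""
        self.n += x.shape[0]
        self.sum_x += x.sum(axis=0)
        self.sum_y += y.sum(axis=0)
        self.xtx += x.T @ x
        self.yty += y.T @ y
        self.xty += x.T @ y

    def compute(self) -> float:
        """Center the accumulated moments and return linear CKA."""
        mean_x = self.sum_x / self.n
        mean_y = self.sum_y / self.n
        cxx = self.xtx - self.n * np.outer(mean_x, mean_x)
        cyy = self.yty - self.n * np.outer(mean_y, mean_y)
        cxy = self.xty - self.n * np.outer(mean_x, mean_y)
        # np.linalg.norm on a 2-D array defaults to the Frobenius norm.
        return float(
            np.linalg.norm(cxy) ** 2 / (np.linalg.norm(cxx) * np.linalg.norm(cyy))
        )


# Synthetic stand-ins for two layers' flattened activations.
rng = np.random.default_rng(0)
x = rng.normal(size=(256, 8))
y = x @ rng.normal(size=(8, 5)) + 0.1 * rng.normal(size=(256, 5))

acc = StreamingLinearCKA(d1=8, d2=5)
for start in range(0, 256, 64):         # four mini-batches of 64 samples
    acc.update(x[start:start + 64], y[start:start + 64])
streamed = acc.compute()
```

Only the feature-by-feature moment matrices are stored, so memory use depends on the layer widths rather than on the dataset size. The result matches a single full-batch computation up to floating-point error.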

Visualization

Heatmap

from cka import plot_cka_heatmap

fig, ax = plot_cka_heatmap(
    cka_matrix,
    layers1=layers,
    layers2=layers,
    model1_name="ResNet-18 (pretrained)",
    model2_name="ResNet-18 (fine-tuned)",
    annot=False,          # Show values in cells
    cmap="inferno",       # Colormap
    mask_upper=False,     # Mask upper triangle (symmetric matrices)
)
Figures: self-comparison heatmap; cross-model comparison heatmap.
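
To make the `mask_upper` comment concrete: when a model is compared with itself the CKA matrix is symmetric, so the upper triangle repeats the lower one. The NumPy sketch below shows one way such a mask can be built. The matrix values are made up for illustration, and this is an assumption about the general idea, not a description of what `plot_cka_heatmap` does internally.

```python
import numpy as np

# Made-up symmetric CKA matrix for three layers (illustrative values only).
cka_values = np.array([
    [1.0, 0.8, 0.3],
    [0.8, 1.0, 0.5],
    [0.3, 0.5, 1.0],
])

# True strictly above the diagonal (k=1 keeps the diagonal itself visible).
upper = np.triu(np.ones_like(cka_values, dtype=bool), k=1)

# Masked entries are skipped by plotting functions that accept masked arrays.
masked = np.ma.masked_array(cka_values, mask=upper)
```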

Trend Plot

import torch
from cka import plot_cka_trend

# Plot the diagonal: CKA between corresponding layers of the two models
# (self-similarity when a model is compared with itself)
diagonal = torch.diag(cka_matrix)
fig, ax = plot_cka_trend(
    diagonal,
    labels=["Self-similarity"],
    xlabel="Layer",
    ylabel="CKA Score",
)
Figures: cross-model CKA score trends; multiple trends comparison.

📚 References

Kornblith, Simon, et al. "Similarity of Neural Network Representations Revisited." ICML 2019.
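
For readers new to the metric, the linear variant of CKA defined in the paper above can be sketched in a few lines of NumPy. The function name `linear_cka` is hypothetical and chosen for this example; it is a reference-style sketch of the formula, not pytorch-cka's implementation.

```python
import numpy as np


def linear_cka(x: np.ndarray, y: np.ndarray) -> float:
    """Linear CKA between activations x (n, d1) and y (n, d2) on the same n inputs."""
    # Center each feature column so the score ignores constant offsets.
    x = x - x.mean(axis=0, keepdims=True)
    y = y - y.mean(axis=0, keepdims=True)
    cross = np.linalg.norm(y.T @ x, ord="fro") ** 2
    norm_x = np.linalg.norm(x.T @ x, ord="fro")
    norm_y = np.linalg.norm(y.T @ y, ord="fro")
    return float(cross / (norm_x * norm_y))


rng = np.random.default_rng(0)
acts = rng.normal(size=(64, 16))

# A random orthogonal matrix, used to rotate the feature space.
q, _ = np.linalg.qr(rng.normal(size=(16, 16)))

same = linear_cka(acts, acts)               # identical representations
rotated = linear_cka(acts, 2.0 * acts @ q)  # rotated and rescaled copy
```

Both `same` and `rotated` come out as 1 up to floating-point error, reflecting the invariance of linear CKA to orthogonal transformations and isotropic scaling described in the paper.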
