3500% faster CKA computation across all layers of two distinct ResNet-18 models on CIFAR-10, benchmarked on NVIDIA H100 GPUs
- ⚡️ Fastest among CKA libraries thanks to vectorized ops & GPU acceleration
- 📦 Efficient memory management with explicit deallocation
- 🧠 Supports HuggingFace models, DataParallel, and DDP
- 🎨 Customizable visualizations: heatmaps and line charts
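The "explicit deallocation" pattern in the feature list can be illustrated with a minimal, hypothetical sketch (the `free_activations` helper below is illustrative only, not this library's actual internals): drop tensor references, then ask CUDA to release its cached blocks.

```python
import torch

def free_activations(activations: dict) -> None:
    # Drop references so the tensors become garbage-collectable
    for name in list(activations):
        del activations[name]
    # Return cached allocator blocks to the driver (no-op without CUDA)
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
```

Without an explicit step like this, cached activations from large models can keep GPU memory occupied between comparisons.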
Requires Python 3.10+
```shell
# Using pip
pip install pytorch-cka

# Using uv
uv add pytorch-cka
```

```python
from torch.utils.data import DataLoader

from cka import CKA

pretrained_model = ...  # e.g. pretrained ResNet-18
fine_tuned_model = ...  # e.g. fine-tuned ResNet-18
layers = ["layer1", "layer2", "layer3", "fc"]
dataloader = DataLoader(..., batch_size=128)

cka = CKA(
    model1=pretrained_model,
    model2=fine_tuned_model,
    model1_name="ResNet-18 (pretrained)",
    model2_name="ResNet-18 (fine-tuned)",
    model1_layers=layers,
    model2_layers=layers,
    device="cuda",
)

# Most convenient usage (auto context manager)
cka_matrix = cka(dataloader)
cka_result = cka.export(cka_matrix)

# Or explicit control
with cka:
    cka_matrix = cka.compare(dataloader)
    cka_result = cka.export(cka_matrix)
```

Heatmap
```python
from cka import plot_cka_heatmap

fig, ax = plot_cka_heatmap(
    cka_matrix,
    layers1=layers,
    layers2=layers,
    model1_name="ResNet-18 (pretrained)",
    model2_name="ResNet-18 (random init)",
    annot=False,       # Set True to show CKA values in cells
    cmap="inferno",    # Matplotlib colormap
    mask_upper=False,  # Set True to mask the upper triangle (for symmetric matrices)
)
```

*Example heatmaps: self-comparison (left) and cross-model (right).*
Trend Plot
```python
import torch

from cka import plot_cka_trend

# Plot the diagonal (layer-wise self-similarity across corresponding layers)
diagonal = torch.diag(cka_matrix)
fig, ax = plot_cka_trend(
    diagonal,
    labels=["Self-similarity"],
    xlabel="Layer",
    ylabel="CKA Score",
)
```

*Example trend plots: cross-model CKA score trends (left) and multiple trends in one chart (right).*
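For context, the per-layer activations that feed the CKA matrix are typically captured with PyTorch forward hooks. The sketch below is a hypothetical illustration of that mechanism (the `capture_activations` helper is not part of this library's API):

```python
import torch
from torch import nn

def capture_activations(model: nn.Module, layer_names: list[str],
                        x: torch.Tensor) -> dict[str, torch.Tensor]:
    activations: dict[str, torch.Tensor] = {}
    handles = []
    for name, module in model.named_modules():
        if name in layer_names:
            def hook(mod, inp, out, name=name):
                # Flatten to (batch, features) so CKA can compare layers
                activations[name] = out.detach().flatten(start_dim=1)
            handles.append(module.register_forward_hook(hook))
    try:
        model(x)  # one forward pass fills the dict
    finally:
        for h in handles:
            h.remove()  # always detach hooks, even on error
    return activations
```

One forward pass per batch fills the dictionary for every requested layer, which is what makes computing all layer pairs in a single sweep possible.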
Kornblith, Simon, et al. "Similarity of Neural Network Representations Revisited." ICML 2019.
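The linear CKA measure from this paper can be sketched in a few lines, as a reference point independent of the library's batched implementation (feature matrices are `(n_samples, n_features)`):

```python
import torch

def linear_cka(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Center each feature across the batch dimension
    x = x - x.mean(dim=0, keepdim=True)
    y = y - y.mean(dim=0, keepdim=True)
    # CKA(X, Y) = ||Y^T X||_F^2 / (||X^T X||_F * ||Y^T Y||_F)
    numerator = (y.T @ x).norm(p="fro") ** 2
    denominator = (x.T @ x).norm(p="fro") * (y.T @ y).norm(p="fro")
    return numerator / denominator
```

Identical representations score 1.0, and the measure is invariant to orthogonal transformations and isotropic scaling of the features, which is why it is well suited to comparing layers of independently trained networks.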



