A clean, from-scratch PyTorch implementation of Image-based Joint-Embedding Predictive Architecture (I-JEPA) from Assran et al., CVPR 2023, trained on STL-10 and benchmarked directly against a companion MAE implementation using the same ViT-Base/16 backbone. No timm, no pre-trained weights — every component including the ViT encoder is built from the ground up.
The central question: does predicting in representation space produce better features than predicting in pixel space?
I-JEPA linear probe: 78.97% | MAE full fine-tune: 72.66%
I-JEPA wins with a frozen encoder. MAE requires updating all weights to reach its score.
Both models are trained under identical conditions:
- Same backbone — ViT-Base/16 at 224×224
- Same dataset — STL-10 (100k unlabeled for pre-training, 5k labeled for evaluation)
- Same number of pre-training epochs — 50
- Same optimizer — AdamW with β=(0.9, 0.95)
- No hand-crafted view augmentations during pre-training in either case
The one intentional difference is the evaluation protocol. I-JEPA is evaluated with a frozen linear probe — only a single linear layer is trained; the encoder never changes. MAE uses full fine-tuning — all encoder weights are updated on the labeled data. A linear probe is the harder test: it cannot compensate for weaknesses in the representation. Winning under the harder protocol is a direct measure of representation quality.
|  | MAE | I-JEPA |
|---|---|---|
| Loss target | Pixel values (L1) | EMA representations (MSE) |
| Pre-training loss — epoch 1 | 0.43 | 0.44 |
| Pre-training loss — epoch 50 | 0.21 | 0.32 |
| Masking | 75% random patches | 4 target blocks (15–20% scale) |
| Augmentations during pre-training | None | None |
| Method | Evaluation protocol | Epochs | Accuracy | Macro F1 |
|---|---|---|---|---|
| Random baseline | Full fine-tune | 15 | 46.19% | — |
| MAE | Full fine-tune | 15 | 72.66% | — |
| I-JEPA (this repo) | Linear probe (frozen encoder) | 15 | 77.58% | 77.42% |
Linear probe = encoder completely frozen, only a `Linear(768, 10)` head is trained.
Full fine-tune = entire encoder updated during downstream training.
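The whole probe protocol fits in a short function. A minimal sketch, assuming the frozen encoder maps (B, 3, 224, 224) to (B, 196, 768) patch tokens that are mean-pooled before the head — the pooling choice and names here are assumptions, not necessarily what `Fine_Tuning.ipynb` does:

```python
import torch
import torch.nn as nn

def train_linear_probe(target_encoder, loader, epochs=50, device="cuda"):
    """Freeze the encoder; train only a Linear(768, 10) head on top of
    mean-pooled patch tokens. `target_encoder` is assumed to map
    (B, 3, 224, 224) -> (B, 196, 768)."""
    target_encoder.eval().to(device)
    for p in target_encoder.parameters():
        p.requires_grad = False                        # encoder is never updated

    probe = nn.Linear(768, 10).to(device)
    opt = torch.optim.AdamW(probe.parameters(), lr=1e-3)

    for _ in range(epochs):
        for images, labels in loader:                  # 5k labeled STL-10 images
            with torch.no_grad():
                tokens = target_encoder(images.to(device))   # (B, 196, 768)
            logits = probe(tokens.mean(dim=1))               # mean-pool -> (B, 768)
            loss = nn.functional.cross_entropy(logits, labels.to(device))
            opt.zero_grad(); loss.backward(); opt.step()
    return probe
```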
I-JEPA has three components that are active during pre-training and one that survives into evaluation.
```
Single image (B, 3, 224, 224)
│
├─ MultiBlockMasking ─────────────────────────────────────────────────────
│    ├─ 1 context block (scale 0.85–1.0, fixed aspect ratio)
│    └─ 4 target blocks (scale 0.15–0.20, random aspect 0.75–1.5)
│         └─ target patches removed from context → no overlap guaranteed
│
├─ Context Encoder (ViT-Base, depth=12, dim=768 — trained by gradients)
│    └─ sees only context patches → list of (Nc, 768) per sample
│
├─ Target Encoder (same ViT-Base — updated by EMA only, never by gradients)
│    └─ sees full image → (196, 768) per sample → per-token LayerNorm
│
└─ Predictor (narrow ViT, depth=6, dim=384 — trained by gradients)
     input:  [projected context tokens | mask_token + positional_embed[target_idx]]
     output: predicted (Nt, 768) per target block
                 │
                 └── MSE loss vs target encoder output

Backprop: context encoder + predictor
EMA:      θ_target ← m·θ_target + (1−m)·θ_context   (m: 0.996 → 1.0)

─── After pre-training: context encoder and predictor are discarded ───
─── Only the target encoder is kept for downstream tasks ──────────────
```
The predictor is intentionally narrow (384-d vs 768-d encoder). This forces it to compress context information rather than copy it, and prevents the trivial shortcut of passing context patches directly to the output.
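A sketch of how the `input:` row of the diagram can be assembled, using hypothetical names (`proj`, `mask_token`, `pos_embed`); the repo's Predictor may organise this differently:

```python
import torch
import torch.nn as nn

DIM_ENC, DIM_PRED, N_PATCHES = 768, 384, 196

# Hypothetical module names — illustrative only.
proj = nn.Linear(DIM_ENC, DIM_PRED)                  # bottleneck: 768 -> 384
mask_token = nn.Parameter(torch.zeros(1, 1, DIM_PRED))
pos_embed = torch.randn(N_PATCHES, DIM_PRED)         # stand-in for the sinusoidal table

def build_predictor_input(ctx_tokens, target_idx):
    """ctx_tokens: (B, Nc, 768) context features; target_idx: (B, Nt) patch indices
    of one target block. Returns the (B, Nc + Nt, 384) predictor input sequence."""
    B, Nt = target_idx.shape
    ctx = proj(ctx_tokens)                                          # (B, Nc, 384)
    queries = mask_token.expand(B, Nt, -1) + pos_embed[target_idx]  # (B, Nt, 384)
    return torch.cat([ctx, queries], dim=1)   # predictor attends across both parts
```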
The target encoder's weights are never directly optimised by the loss. Everything it knows was accumulated through the EMA process, which gives it a stable, slowly evolving view of the data.
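The update itself reduces to a few lines of in-place arithmetic per optimizer step. A sketch assuming both encoders expose identically ordered parameters:

```python
import torch

@torch.no_grad()
def ema_update(target_encoder, context_encoder, m):
    """θ_target ← m·θ_target + (1−m)·θ_context, applied parameter-wise."""
    for t, c in zip(target_encoder.parameters(), context_encoder.parameters()):
        t.mul_(m).add_(c, alpha=1.0 - m)

def momentum_at(step, total_steps, m_start=0.996, m_end=1.0):
    """Linear annealing of the EMA momentum over pre-training."""
    return m_start + (m_end - m_start) * step / total_steps
```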
The structured multi-block masking is a core design choice that drives the quality of the learned representations. Random masking (MAE-style) creates easy local prediction tasks that can be solved with texture interpolation. I-JEPA's target blocks are large enough to cover object parts, so the model must reason about object structure to predict them.
Green = context region · Colours = the 4 target blocks to predict · White = excluded from both
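A sketch of the block-sampling logic on the 14×14 patch grid; the exact rounding and resampling details in `MultiBlockMasking` may differ:

```python
import math
import random

GRID = 14  # 224 / 16 patches per side -> 196 patches total

def sample_block(scale=(0.15, 0.20), aspect=(0.75, 1.5)):
    """Sample one rectangular block, returned as a set of flat patch indices."""
    area = random.uniform(*scale) * GRID * GRID          # target size in patches
    ar = random.uniform(*aspect)                         # aspect ratio h / w
    h = max(1, min(GRID, round(math.sqrt(area * ar))))
    w = max(1, min(GRID, round(math.sqrt(area / ar))))
    top, left = random.randint(0, GRID - h), random.randint(0, GRID - w)
    return {r * GRID + c for r in range(top, top + h) for c in range(left, left + w)}

# 4 target blocks plus one large context block with fixed aspect ratio;
# removing the target patches from the context guarantees no overlap.
targets = [sample_block() for _ in range(4)]
context = sample_block(scale=(0.85, 1.0), aspect=(1.0, 1.0))
context -= set().union(*targets)
```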
What does the predictor actually learn? After training, we can run the predictor on held-out images and retrieve the gallery patch whose target-encoder feature is closest to the predicted feature — a nearest-neighbour imagination in representation space.
Each row is one image. Columns left to right:
| Column | Content |
|---|---|
| Original | Unmodified image |
| Context | What the context encoder sees — targets and excluded patches greyed out |
| Target 1–4 | Context pixels + nearest-neighbour reconstruction for that one target block |
| Reconstruction | Context pixels + NN reconstruction for all four target blocks combined |
For every target patch the predictor outputs a predicted representation; the gallery patch (from 200 test images × 196 patches) whose target-encoder feature is cosine-nearest is retrieved and pasted into the figure. The predictor has no access to the target pixels — it must infer what belongs there purely from the context representation.
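A sketch of the retrieval step, assuming the gallery features are pre-extracted with the target encoder, flattened to (200 × 196, 768), and L2-normalised (names are placeholders):

```python
import torch
import torch.nn.functional as F

def nearest_gallery_patches(pred, gallery_feats):
    """pred: (Nt, 768) predictor outputs for one target block.
    gallery_feats: (39200, 768) L2-normalised target-encoder features.
    Returns, per predicted token, the index of the cosine-nearest gallery patch."""
    pred = F.normalize(pred, dim=-1)       # cosine similarity = dot of unit vectors
    sims = pred @ gallery_feats.T          # (Nt, 39200) similarity matrix
    return sims.argmax(dim=-1)             # each index maps back to (image, patch)
```

Each returned index decodes as `(index // 196, index % 196)`, i.e. which of the 200 test images and which of its 196 patches to paste.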
To generate this figure after pre-training:
```bash
python predict_and_visualize.py
```

t-SNE of the target encoder features on the full STL-10 test set (8,000 images, 10 classes) after 50 epochs of pre-training with no labels used:
Loss, learning rate schedule (15-epoch warmup → cosine decay), and EMA momentum annealing (0.996 → 1.0) over 50 epochs.
Confusion matrix and per-class accuracy on the 8k STL-10 test set. The frozen encoder separates all 10 classes cleanly — there is no single dominant failure mode, which is characteristic of a well-structured representation space rather than a classifier that memorised the training distribution.
```bash
git clone https://github.com/Alpsource/Visual-Representation-Learning-JEPA
cd Visual-Representation-Learning-JEPA
pip install -r requirements.txt
```

Requirements: Python ≥ 3.10, PyTorch ≥ 2.0. Tested on a single NVIDIA Quadro A6000 (24 GB VRAM).
```bash
jupyter notebook Self_Supervised_Learning.ipynb
```

Trains on the 100k unlabeled STL-10 images for 50 epochs (~32 hours on an A6000). STL-10 is downloaded automatically on first run. Saves the target encoder to `checkpoints/`. Generates training curves, masking examples, and a t-SNE plot.
```bash
jupyter notebook Fine_Tuning.ipynb
```

Loads the pre-trained target encoder, trains a linear head on the 5k labeled images with the encoder frozen, and evaluates on the 8k test images. Produces a confusion matrix, per-class accuracy breakdown, and comparison figures.
```bash
python predict_and_visualize.py
```

Loads `checkpoints/ijepa_checkpoint_ep50.pth`, builds a feature gallery from 200 test images, and generates `images/ijepa_predictions.png` — original, context-only, per-target-block NN reconstruction, and the full reconstruction shown above.
|  | Pre-training | Fine-tuning |
|---|---|---|
| Backbone | ViT-Base/16, 224×224 | ViT-Base/16, 224×224 |
| Encoder | depth=12, dim=768, heads=12 | frozen |
| Predictor | depth=6, dim=384, heads=12 | — |
| Batch / effective | 64 / 256 (×4 accum) | 256 |
| Peak LR | 1e-3 | 1e-3 |
| LR schedule | 15-epoch warmup → cosine → 1e-6 | 5-epoch warmup → cosine → 1e-5 |
| Weight decay | 0.04 → 0.40 (linear) | 0.05 |
| EMA momentum | 0.996 → 1.000 (linear) | — |
| Epochs | 50 | 50 |
```
├── models.py                        # All architecture code
│      PatchEmbed · ContextEncoder · TargetEncoder
│      Predictor · IJEPA · MultiBlockMasking · LinearClassifier
│
├── Self_Supervised_Learning.ipynb   # Pre-training
├── Fine_Tuning.ipynb                # Linear probe + MAE comparison
├── predict_and_visualize.py         # Predictor visualisation (context · per-block NN reconstruction · full reconstruction)
├── requirements.txt
└── images/                          # Generated on first run
```
The entire architecture — patch embedding, multi-head attention, sinusoidal 2D positional embeddings, variable-length sequence batching, EMA update — is implemented from scratch in models.py using only PyTorch primitives. No timm, no pre-trained ViT weights, no external model libraries.
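As an example of one such component, the fixed 2D sinusoidal positional table can be built in about a dozen lines. This is one common construction; the channel layout in `models.py` may differ:

```python
import torch

def sincos_2d_pos_embed(dim, grid_size):
    """Fixed 2D sinusoidal positional embeddings, shape (grid_size**2, dim).
    Half the channels encode the row coordinate, half the column."""
    assert dim % 4 == 0
    coords = torch.arange(grid_size, dtype=torch.float32)
    omega = 1.0 / (10000 ** (torch.arange(dim // 4, dtype=torch.float32) / (dim // 4)))
    out = coords[:, None] * omega[None, :]                 # (grid, dim/4)
    emb_1d = torch.cat([out.sin(), out.cos()], dim=1)      # (grid, dim/2)
    row = emb_1d[:, None, :].expand(-1, grid_size, -1)     # (grid, grid, dim/2)
    col = emb_1d[None, :, :].expand(grid_size, -1, -1)     # (grid, grid, dim/2)
    return torch.cat([row, col], dim=-1).reshape(grid_size**2, dim)

pos_embed = sincos_2d_pos_embed(768, 14)   # (196, 768) for ViT-Base/16 at 224px
```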
Three bugs that are easy to introduce and silently destroy results:
1. Attention mask convention.
`torch.nn.functional.scaled_dot_product_attention` interprets a boolean `attn_mask` as `True` = attend. Padding masks built during variable-length batch processing typically use `True` = padding. Passing the mask without inverting it causes the encoder to attend to zero-padding and ignore real tokens — the model trains on garbage with no error message.
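A minimal reproduction of the correct convention (shapes are illustrative):

```python
import torch
import torch.nn.functional as F

B, H, N, D = 2, 12, 50, 64
q = k = v = torch.randn(B, H, N, D)

padding_mask = torch.zeros(B, N, dtype=torch.bool)   # True = this key is padding
padding_mask[0, 40:] = True                          # sample 0 carries 10 padded tokens

# SDPA wants True = "may attend", so invert the padding convention and
# broadcast to (B, 1, 1, N) so it applies to every head and every query row.
attn_mask = ~padding_mask[:, None, None, :]
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
```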
2. Loss normalization.
The MSE loss must use `.mean()`, not `.sum()`. With a batch of 64, 4 target blocks per image, ~30 patches per block, and 768 embedding dimensions, `.sum()` produces values around 50,000 that vary wildly with batch composition. The gradient is meaningless. `.mean()` gives a stable value around 1–2 at initialization.
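The scale difference is easy to demonstrate with random tensors of the shapes quoted above (with real features the magnitudes differ, but the scaling behaviour is the same):

```python
import torch
import torch.nn.functional as F

pred = torch.randn(64, 4 * 30, 768)   # batch × (4 blocks × ~30 patches) × embed dim
tgt  = torch.randn_like(pred)

print(F.mse_loss(pred, tgt, reduction="mean"))  # ≈ 2: stable, shape-independent
print(F.mse_loss(pred, tgt, reduction="sum"))   # ≈ 1.2e7: grows with every element
```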
3. LR schedule with gradient accumulation.
`scheduler.step()` must be called once per optimizer step, not once per batch. With `ACCUM_STEPS=4`, there are 390 optimizer steps per epoch (not 1562). If `TOTAL_STEPS` and `WARMUP_STEPS` are set using raw batch counts, the warmup phase (23,430 steps) is longer than the entire training run (19,500 optimizer steps). The learning rate never reaches its peak value.
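A sketch of the correct accounting with the numbers above; the model, loss, and warmup-only schedule are stand-ins:

```python
import torch

ACCUM_STEPS       = 4
BATCHES_PER_EPOCH = 1562                              # 100k images / batch size 64
EPOCHS            = 50

steps_per_epoch = BATCHES_PER_EPOCH // ACCUM_STEPS    # 390 optimizer steps per epoch
TOTAL_STEPS     = steps_per_epoch * EPOCHS            # 19,500 — not 78,100 raw batches
WARMUP_STEPS    = steps_per_epoch * 15                # warmup measured in optimizer steps

model     = torch.nn.Linear(8, 8)                     # stand-in for the real model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, betas=(0.9, 0.95))
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lambda s: min(1.0, s / WARMUP_STEPS)              # linear warmup; cosine decay omitted
)

for i in range(BATCHES_PER_EPOCH):                    # one epoch
    loss = model(torch.randn(4, 8)).pow(2).mean() / ACCUM_STEPS
    loss.backward()
    if (i + 1) % ACCUM_STEPS == 0:
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()                              # once per optimizer step, not per batch
```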
|  | MAE | I-JEPA (this repo) |
|---|---|---|
| Prediction target | Pixels | Representations |
| Masking | 75% random | Structured multi-block |
| Decoder / predictor | 8-layer ViT, 512-dim | 6-layer ViT, 384-dim |
| Augmentations | None | None |
| Downstream evaluation | Full fine-tune | Linear probe |
| STL-10 accuracy | 72.66% | 78.97% |
```bibtex
@inproceedings{assran2023ijepa,
  title     = {Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture},
  author    = {Assran, Mahmoud and Duval, Quentin and Misra, Ishan and Bojanowski, Piotr and
               Vincent, Pascal and Rabbat, Michael and LeCun, Yann and Ballas, Nicolas},
  booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  year      = {2023}
}
```





