
Self-Supervised Learning with I-JEPA

A clean, from-scratch PyTorch implementation of Image-based Joint-Embedding Predictive Architecture (I-JEPA) from Assran et al., CVPR 2023, trained on STL-10 and benchmarked directly against a companion MAE implementation using the same ViT-Base/16 backbone. No timm, no pre-trained weights — every component including the ViT encoder is built from the ground up.

The central question: does predicting in representation space produce better features than predicting in pixel space?

I-JEPA linear probe: 78.97%  |  MAE full fine-tune: 72.66%
I-JEPA wins with a frozen encoder. MAE requires updating all weights to reach its score.


What makes this comparison meaningful

Both models are trained under identical conditions:

  • Same backbone — ViT-Base/16 at 224×224
  • Same dataset — STL-10 (100k unlabeled for pre-training, 5k labeled for evaluation)
  • Same number of pre-training epochs — 50
  • Same optimizer — AdamW with β=(0.9, 0.95)
  • No hand-crafted view augmentations during pre-training in either case

The one intentional difference is the evaluation protocol. I-JEPA is evaluated with a frozen linear probe — only a single linear layer is trained; the encoder never changes. MAE uses full fine-tuning — all encoder weights are updated on the labeled data. A linear probe is the harder test: it cannot compensate for weaknesses in the representation. Winning under the harder protocol is a direct measure of representation quality.
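The linear-probe protocol can be sketched in a few lines. The `encoder` below is a tiny stand-in module so the shapes stay small, not this repo's ViT; only the `Linear(768, 10)` head and the freezing logic mirror the setup described above:

```python
import torch
import torch.nn as nn

# Stand-in for the frozen target encoder (the real one is a ViT-Base producing 768-d features).
encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 768))

for p in encoder.parameters():
    p.requires_grad = False            # encoder is completely frozen
encoder.eval()

head = nn.Linear(768, 10)              # the only trainable parameters
opt = torch.optim.AdamW(head.parameters(), lr=1e-3, betas=(0.9, 0.95))

x, y = torch.randn(4, 3, 32, 32), torch.randint(0, 10, (4,))
with torch.no_grad():
    feats = encoder(x)                 # features are detached; no gradient reaches the encoder
loss = nn.functional.cross_entropy(head(feats), y)
loss.backward()                        # gradients flow into the head only
opt.step()
```

Because the encoder contributes no trainable parameters, the probe's accuracy is a direct readout of how linearly separable the frozen features already are.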


Results

Pre-training (50 epochs, STL-10 unlabeled, 100k images)

|                                   | MAE                | I-JEPA                         |
|-----------------------------------|--------------------|--------------------------------|
| Loss target                       | Pixel values (L1)  | EMA representations (MSE)      |
| Pre-training loss — epoch 1       | 0.43               | 0.44                           |
| Pre-training loss — epoch 50      | 0.21               | 0.32                           |
| Masking                           | 75% random patches | 4 target blocks (15–20% scale) |
| Augmentations during pre-training | None               | None                           |

Downstream classification (STL-10, 5k labeled, 8k test)

| Method             | Evaluation protocol           | Epochs | Accuracy | Macro F1 |
|--------------------|-------------------------------|--------|----------|----------|
| Random baseline    | Full fine-tune                | 15     | 46.19%   | n/a      |
| MAE                | Full fine-tune                | 15     | 72.66%   | n/a      |
| I-JEPA (this repo) | Linear probe (frozen encoder) | 15     | 77.58%   | 77.42%   |

Linear probe = encoder completely frozen, only a Linear(768, 10) head is trained.
Full fine-tune = entire encoder updated during downstream training.


Architecture

I-JEPA has three trained components, all active during pre-training; only one of them, the target encoder, survives into evaluation.

Single image  (B, 3, 224, 224)
  │
  ├─ MultiBlockMasking ──────────────────────────────────────────────────────
  │    ├─ 1 context block  (scale 0.85–1.0, fixed aspect ratio)
  │    └─ 4 target blocks  (scale 0.15–0.20, random aspect 0.75–1.5)
  │         └─ target patches removed from context → no overlap guaranteed
  │
  ├─ Context Encoder  (ViT-Base, depth=12, dim=768 — trained by gradients)
  │    └─ sees only context patches → list of (Nc, 768) per sample
  │
  ├─ Target Encoder  (same ViT-Base — updated by EMA only, never by gradients)
  │    └─ sees full image → (196, 768) per sample → per-token LayerNorm
  │
  └─ Predictor  (narrow ViT, depth=6, dim=384 — trained by gradients)
       input:  [projected context tokens | mask_token + positional_embed[target_idx]]
       output: predicted (Nt, 768) per target block
                       │                          │
               MSE loss vs target encoder output ──┘

  Backprop:  context encoder + predictor
  EMA:       θ_target ← m·θ_target + (1−m)·θ_context   (m: 0.996 → 1.0)

  ─── After pre-training: context encoder and predictor are discarded ───
  ─── Only the target encoder is kept for downstream tasks ───────────────

The predictor is intentionally narrow (384-d vs 768-d encoder). This forces it to compress context information rather than copy it, and prevents the trivial shortcut of passing context patches directly to the output.
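The input assembly described in the diagram can be sketched as follows. Dimensions come from the hyperparameter table; the module and variable names are illustrative, not this repo's actual API:

```python
import torch
import torch.nn as nn

ENC_DIM, PRED_DIM = 768, 384
proj_in = nn.Linear(ENC_DIM, PRED_DIM)            # project context tokens into the narrow predictor
proj_out = nn.Linear(PRED_DIM, ENC_DIM)           # project predictions back to encoder width
mask_token = nn.Parameter(torch.zeros(1, 1, PRED_DIM))
pos_embed = torch.randn(1, 196, PRED_DIM)         # predictor-width positional embeddings (stand-in)

ctx = torch.randn(1, 120, ENC_DIM)                # context tokens from the context encoder
target_idx = torch.tensor([5, 6, 19, 20])         # patch indices of one target block
queries = mask_token + pos_embed[:, target_idx]   # one query per target patch: "what is here?"
x = torch.cat([proj_in(ctx), queries], dim=1)     # (1, 120 + 4, PRED_DIM) predictor input
# ... run the predictor's transformer blocks on x, then read off the query positions:
pred = proj_out(x[:, -target_idx.numel():])       # (1, 4, ENC_DIM) predicted representations
```

The positional embedding added to each mask token tells the predictor *where* the missing patch is; the context tokens tell it what surrounds that location.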

The target encoder's weights are never directly optimised by the loss. Everything it represents is accumulated through the EMA process, which provides stable, slowly moving prediction targets and helps prevent representational collapse.
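The EMA update in the formula above is a few lines of in-place arithmetic. The modules below are stand-ins; the repo's actual implementation may differ in detail:

```python
import copy
import torch
import torch.nn as nn

context = nn.Linear(8, 8)                  # stand-in for the context encoder
target = copy.deepcopy(context)            # the target encoder starts as an exact copy
for p in target.parameters():
    p.requires_grad = False                # never touched by the optimizer

@torch.no_grad()
def ema_update(target, context, m):
    # theta_target <- m * theta_target + (1 - m) * theta_context
    # m is annealed from 0.996 toward 1.0 over the course of training
    for pt, pc in zip(target.parameters(), context.parameters()):
        pt.mul_(m).add_(pc, alpha=1 - m)
```

With `m` close to 1, the target encoder is a heavily smoothed trailing average of the context encoder, which is what keeps the prediction targets stable from step to step.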


Masking strategy

The structured multi-block masking is a core design choice that drives the quality of the learned representations. Random masking (MAE-style) creates easy local prediction tasks that can be solved with texture interpolation. I-JEPA's target blocks are large enough to cover object parts, so the model must reason about object structure to predict them.

I-JEPA masking strategy

Green = context region  ·  Colours = the 4 target blocks to predict  ·  White = excluded from both
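One way to sample a single rectangular target block on the 14×14 patch grid, using the scale and aspect-ratio ranges from the diagram above, is sketched below. The exact sampling procedure (rounding, clipping, rejection rules) is an assumption, not necessarily what `MultiBlockMasking` does:

```python
import math
import random

def sample_block(grid=14, scale=(0.15, 0.20), aspect=(0.75, 1.5), rng=random):
    """Sample one rectangular block of patches; return its flat patch indices."""
    area = rng.uniform(*scale) * grid * grid          # block area in patches
    a = rng.uniform(*aspect)                          # aspect ratio h/w
    h = max(1, min(grid, round(math.sqrt(area * a))))
    w = max(1, min(grid, round(math.sqrt(area / a))))
    top, left = rng.randint(0, grid - h), rng.randint(0, grid - w)
    return {(top + i) * grid + (left + j) for i in range(h) for j in range(w)}

# 4 target blocks; the context is everything else, so context and targets never overlap.
targets = [sample_block() for _ in range(4)]
context = set(range(14 * 14)) - set().union(*targets)
```

Removing the target indices from the context set is what guarantees the "no overlap" property in the diagram: the context encoder can never see a patch it is asked to predict.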


Predictor visualisation

What does the predictor actually learn? After training, we can run the predictor on held-out images and retrieve the gallery patch whose target-encoder feature is closest to the predicted feature — a nearest-neighbour imagination in representation space.

I-JEPA predictor visualisation

Each row is one image. Columns left to right:

| Column         | Content                                                                     |
|----------------|-----------------------------------------------------------------------------|
| Original       | Unmodified image                                                            |
| Context        | What the context encoder sees — targets and excluded patches greyed out     |
| Target 1–4     | Context pixels + nearest-neighbour reconstruction for that one target block |
| Reconstruction | Context pixels + NN reconstruction for all four target blocks combined      |

For every target patch the predictor outputs a predicted representation; the gallery patch (from 200 test images × 196 patches) whose target-encoder feature is cosine-nearest is retrieved and pasted into the corresponding position. The predictor has no access to the target pixels — it must infer what belongs there purely from the context representation.
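The retrieval step reduces to a cosine nearest-neighbour lookup over the gallery. Random tensors stand in for the real features here; shapes match the description above:

```python
import torch
import torch.nn.functional as F

# Gallery of target-encoder features: 200 images x 196 patches, 768-d each.
gallery = F.normalize(torch.randn(200 * 196, 768), dim=-1)

def nearest_patch(pred):
    """pred: (Nt, 768) predicted features -> index of the cosine-nearest gallery patch."""
    sims = F.normalize(pred, dim=-1) @ gallery.T   # cosine similarity, shape (Nt, 39200)
    return sims.argmax(dim=-1)

idx = nearest_patch(torch.randn(30, 768))          # ~30 patches in one target block
```

Normalising both sides first makes the dot product equal to cosine similarity, so a single matrix multiply scores every gallery patch at once.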

To generate this figure after pre-training:

python predict_and_visualize.py

Learned representations

t-SNE of the target encoder features on the full STL-10 test set (8,000 images, 10 classes) after 50 epochs of pre-training with no labels used:

t-SNE of I-JEPA features


Training curves

Pre-training loss, LR schedule, and EMA momentum

Loss, learning rate schedule (15-epoch warmup → cosine decay), and EMA momentum annealing (0.996 → 1.0) over 50 epochs.
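Both schedules can be written as pure functions of the optimizer step. The constants follow the hyperparameter table (390 optimizer steps per epoch comes from 100k images / batch 64 ≈ 1562 batches, divided by the 4× accumulation); the function names are illustrative:

```python
import math

STEPS_PER_EPOCH = 390                 # optimizer steps, not raw batches (4x accumulation)
WARMUP = 15 * STEPS_PER_EPOCH
TOTAL = 50 * STEPS_PER_EPOCH
PEAK_LR, MIN_LR = 1e-3, 1e-6

def lr_at(step):
    if step < WARMUP:                                  # linear warmup to the peak LR
        return PEAK_LR * step / WARMUP
    t = (step - WARMUP) / (TOTAL - WARMUP)             # cosine decay down to MIN_LR
    return MIN_LR + 0.5 * (PEAK_LR - MIN_LR) * (1 + math.cos(math.pi * t))

def ema_momentum_at(step, m0=0.996, m1=1.0):           # linear anneal of the EMA momentum
    return m0 + (m1 - m0) * step / TOTAL
```

Counting in optimizer steps rather than batches matters here; see the implementation notes below for the failure mode when the two are confused.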


Comparison with MAE

I-JEPA vs MAE fine-tuning comparison


Evaluation breakdown

Confusion matrix on STL-10 test set    Per-class accuracy

Confusion matrix and per-class accuracy on the 8k STL-10 test set. The frozen encoder separates all 10 classes cleanly — there is no single dominant failure mode, which is characteristic of a well-structured representation space rather than a classifier that memorised the training distribution.


Installation

git clone https://github.com/Alpsource/Visual-Representation-Learning-JEPA
cd Visual-Representation-Learning-JEPA
pip install -r requirements.txt

Requirements: Python ≥ 3.10, PyTorch ≥ 2.0. Tested on a single NVIDIA Quadro A6000 (24 GB VRAM).


Running

Pre-training

jupyter notebook Self_Supervised_Learning.ipynb

Trains on the 100k unlabeled STL-10 images for 50 epochs (~32 hours on an A6000). STL-10 is downloaded automatically on first run. Saves the target encoder to checkpoints/. Generates training curves, masking examples, and a t-SNE plot.

Evaluation (linear probe)

jupyter notebook Fine_Tuning.ipynb

Loads the pre-trained target encoder, trains a linear head on 5k labeled images with the encoder frozen, and evaluates on 8k test images. Produces a confusion matrix, per-class accuracy breakdown, and comparison figures.

Predictor visualisation

python predict_and_visualize.py

Loads checkpoints/ijepa_checkpoint_ep50.pth, builds a feature gallery from 200 test images, and generates images/ijepa_predictions.png — original, context-only, per-target-block NN reconstruction, and full reconstruction shown above.


Hyperparameters

|                   | Pre-training                    | Fine-tuning                    |
|-------------------|---------------------------------|--------------------------------|
| Backbone          | ViT-Base/16, 224×224            | ViT-Base/16, 224×224           |
| Encoder           | depth=12, dim=768, heads=12     | frozen                         |
| Predictor         | depth=6, dim=384, heads=12      | n/a                            |
| Batch / effective | 64 / 256 (×4 accum)             | 256                            |
| Peak LR           | 1e-3                            | 1e-3                           |
| LR schedule       | 15-epoch warmup → cosine → 1e-6 | 5-epoch warmup → cosine → 1e-5 |
| Weight decay      | 0.04 → 0.40 (linear)            | 0.05                           |
| EMA momentum      | 0.996 → 1.000 (linear)          | n/a                            |
| Epochs            | 50                              | 50                             |

Repository layout

├── models.py                          # All architecture code
│     PatchEmbed · ContextEncoder · TargetEncoder
│     Predictor · IJEPA · MultiBlockMasking · LinearClassifier
│
├── Self_Supervised_Learning.ipynb     # Pre-training
├── Fine_Tuning.ipynb                  # Linear probe + MAE comparison
├── predict_and_visualize.py           # Predictor visualisation (context · per-block NN reconstruction · full reconstruction)
├── requirements.txt
└── images/                            # Generated on first run

Implementation notes

The entire architecture — patch embedding, multi-head attention, sinusoidal 2D positional embeddings, variable-length sequence batching, EMA update — is implemented from scratch in models.py using only PyTorch primitives. No timm, no pre-trained ViT weights, no external model libraries.
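As an example of one such component, a fixed 2D sin-cos positional embedding for a 14×14 patch grid can be built like this. The channel layout (first half rows, second half columns) is a common convention and an assumption here, not necessarily the exact layout in models.py:

```python
import torch

def sincos_2d(grid, dim):
    """Fixed 2D sin-cos positional embeddings for a grid x grid patch layout.
    Half the channels encode the row coordinate, half the column."""
    assert dim % 4 == 0
    half = dim // 2
    freqs = 1.0 / (10000 ** (torch.arange(0, half, 2).float() / half))
    coords = torch.arange(grid).float()

    def encode(pos):                              # (grid,) -> (grid, half)
        angles = pos[:, None] * freqs[None, :]
        return torch.cat([angles.sin(), angles.cos()], dim=-1)

    row = encode(coords)[:, None, :].expand(grid, grid, half)
    col = encode(coords)[None, :, :].expand(grid, grid, half)
    return torch.cat([row, col], dim=-1).reshape(grid * grid, dim)

pos = sincos_2d(14, 768)                          # (196, 768), one embedding per patch
```

Because the embeddings are deterministic functions of position, the predictor can address any target patch by index, which is exactly what the mask-token-plus-positional-embedding queries rely on.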

Three bugs that are easy to introduce and silently destroy results:

1. Attention mask convention. torch.nn.functional.scaled_dot_product_attention interprets a boolean attn_mask as True = attend. Padding masks built during variable-length batch processing typically use True = padding. Passing the mask without inverting causes the encoder to attend to zero-padding and ignore real tokens — the model trains on garbage with no error message.

2. Loss normalization. The MSE loss must use .mean(), not .sum(). With a batch of 64, 4 target blocks per image, ~30 patches per block, and 768 embedding dimensions, .sum() produces values around 50,000 that vary wildly with batch composition. The gradient is meaningless. .mean() gives a stable value around 1–2 at initialization.

3. LR schedule with gradient accumulation. scheduler.step() must be called once per optimizer step, not once per batch. With ACCUM_STEPS=4, there are 390 optimizer steps per epoch (not 1562). If TOTAL_STEPS and WARMUP_STEPS are set using raw batch counts, the warmup phase (23,430 steps) is longer than the entire training run (19,500 optimizer steps). The learning rate never reaches its peak value.
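The first pitfall is the easiest to demonstrate in code. PyTorch's `scaled_dot_product_attention` treats a boolean `attn_mask` value of `True` as "attend to this position", so a padding mask built with the opposite convention must be inverted before it is passed in:

```python
import torch
import torch.nn.functional as F

B, H, N, D = 2, 4, 6, 8
q, k, v = (torch.randn(B, H, N, D) for _ in range(3))

# Padding mask as typically built during variable-length batching: True = padding.
is_padding = torch.zeros(B, N, dtype=torch.bool)
is_padding[:, 4:] = True                        # last two tokens are zero-padding

# WRONG: F.scaled_dot_product_attention(q, k, v, attn_mask=is_padding[:, None, None, :])
# would attend ONLY to the padding tokens.
# RIGHT: invert so True means "attend", then broadcast over heads and queries.
attn_mask = (~is_padding)[:, None, None, :]     # (B, 1, 1, N)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
```

Both versions run without error and produce tensors of the same shape, which is why this bug is silent: only the downstream accuracy reveals it.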


Companion repository

|                       | MAE                  | I-JEPA (this repo)     |
|-----------------------|----------------------|------------------------|
| Prediction target     | Pixels               | Representations        |
| Masking               | 75% random           | Structured multi-block |
| Decoder / predictor   | 8-layer ViT, 512-dim | 6-layer ViT, 384-dim   |
| Augmentations         | None                 | None                   |
| Downstream evaluation | Full fine-tune       | Linear probe           |
| STL-10 accuracy       | 72.66%               | 78.97%                 |

Reference

@inproceedings{assran2023ijepa,
  title     = {Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture},
  author    = {Assran, Mahmoud and Duval, Quentin and Misra, Ishan and Bojanowski, Piotr and
               Vincent, Pascal and Rabbat, Michael and LeCun, Yann and Ballas, Nicolas},
  booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  year      = {2023}
}
