77 commits
bd0855c
Initial work on SparK
johnsutor Dec 5, 2024
42bd849
feat(spark): More work on spark
johnsutor Dec 17, 2024
7e73b83
Merge branch 'spark' of github.com:johnsutor/lightly
gabrielfruet Feb 5, 2026
724819c
refactor: should not nest modules.
gabrielfruet Feb 5, 2026
b0023f0
refactor: removed empty file
gabrielfruet Feb 5, 2026
8120990
refactor: adhered to correct directory structure.
gabrielfruet Feb 5, 2026
a7138d0
feat: put everything into the sparse spark module.
gabrielfruet Feb 5, 2026
9b869e9
refactor: removed redundant super calls with class.
gabrielfruet Feb 5, 2026
36cb94c
refactor: removed spark code.
gabrielfruet Feb 5, 2026
0afc280
refactor: removing redundant super calls
gabrielfruet Feb 5, 2026
280a479
refactor: remove empty file
gabrielfruet Feb 6, 2026
7e0f2bc
refactor: porting original code. starting from scratch
gabrielfruet Feb 6, 2026
9a9dff6
refactor: fixing type hint problems.
gabrielfruet Feb 6, 2026
407a4c0
refactor: removing unecessary redundant super
gabrielfruet Feb 6, 2026
47eecf3
fix: indentation
gabrielfruet Feb 6, 2026
7a956b3
feat: working module
gabrielfruet Feb 6, 2026
9d0f5c4
refactor: using library already implemente masking
gabrielfruet Feb 9, 2026
4aa69f7
feat: using patchify
gabrielfruet Feb 9, 2026
8b95a51
refactor: putting densification into a single module
gabrielfruet Feb 9, 2026
014f724
typo: raito -> ratio
gabrielfruet Feb 9, 2026
de01b0a
feat: encapsulated logic to single dnesifier module
gabrielfruet Feb 9, 2026
445bda7
refactor: cleaning code.
gabrielfruet Feb 9, 2026
741d531
refactor: letting sparse encoder be repsonsible for sizes and etc
gabrielfruet Feb 9, 2026
fb5c90d
feat: resnet18
gabrielfruet Feb 9, 2026
deeb4ef
refactor: removing unused code
gabrielfruet Feb 9, 2026
64df2e3
fix: bool tensor is inconvenient
gabrielfruet Feb 9, 2026
45b6eb8
refactor: documenting
gabrielfruet Feb 9, 2026
fb7903a
refactor: masking as a module
gabrielfruet Feb 9, 2026
138380f
refactor: removing unused variables
gabrielfruet Feb 9, 2026
6b3dbdf
refactor: removing unecessary module dependency
gabrielfruet Feb 9, 2026
ed8691a
refactor: loss as module
gabrielfruet Feb 9, 2026
3dcd5c3
refactor: spark visualization decoding logic as module
gabrielfruet Feb 9, 2026
1c316da
refactor: remove unused
gabrielfruet Feb 9, 2026
3ba1839
refactor
gabrielfruet Feb 9, 2026
70d32dd
refactor: removed big module and refactored timm funcs
gabrielfruet Feb 9, 2026
2f73d15
feat: example script
gabrielfruet Feb 9, 2026
b48d4eb
refactor: removed unused code and added opyrights
gabrielfruet Feb 9, 2026
25687d3
doc: improved documentation and type hinting
gabrielfruet Feb 12, 2026
96739fe
fix: type hinting
gabrielfruet Feb 12, 2026
ed70f9a
tests: testing active ex
gabrielfruet Feb 13, 2026
4fb16d7
tests: sp conv forward test
gabrielfruet Feb 13, 2026
6188aa4
refactor: using fixture instead of context manager
gabrielfruet Feb 13, 2026
7cfc508
format: formatting
gabrielfruet Feb 13, 2026
0098425
feat: moved patch recon loss to loss module
gabrielfruet Feb 14, 2026
dea0173
fix: no need of this sparse argument since its always sparse.
gabrielfruet Feb 14, 2026
83bdd60
format
gabrielfruet Feb 14, 2026
4b7f119
feat: init module access
gabrielfruet Feb 14, 2026
86386b7
feat: removed sparse encoder since it adds no necessary logic.
gabrielfruet Feb 14, 2026
92ebd99
refactor: removing dense model to sparse
gabrielfruet Feb 14, 2026
93c7ee5
fix: renamin to unet decoder
gabrielfruet Feb 14, 2026
2710f6b
format
gabrielfruet Feb 14, 2026
3a59d6a
fix: removed inplace operation
gabrielfruet Feb 24, 2026
34ed5f6
fix: using proper eps
gabrielfruet Feb 24, 2026
2ff68e9
docs: rst for loss
gabrielfruet Feb 24, 2026
de92784
feat: annotations
gabrielfruet Feb 24, 2026
ad4c9ac
typos
gabrielfruet Feb 24, 2026
8b2eca5
typing issues
liopeer Mar 2, 2026
b16c6b3
formatting
liopeer Mar 2, 2026
1b9f53a
context manager for global tensor + typing
liopeer Mar 2, 2026
4e9cde1
format
gabrielfruet Mar 2, 2026
25e7fbb
Merge branch 'feat/1462-spark-implementation' of github.com:gabrielfr…
gabrielfruet Mar 2, 2026
6abe8fc
refactor: better naming
gabrielfruet Mar 2, 2026
c9c99d2
refactor: lenght, not levels
gabrielfruet Mar 2, 2026
73da36e
fix: remove unecessary target transform
gabrielfruet Mar 2, 2026
708a18e
test: testing loss, comparing with reference and distributed testing
gabrielfruet Mar 2, 2026
0c10ea0
doc: where i took the loss from
gabrielfruet Mar 2, 2026
bbe0b07
refactor: make example simpler
gabrielfruet Mar 2, 2026
bde5292
typo: spark masking output
gabrielfruet Mar 4, 2026
73b0bb1
fix: imports according to #1895
gabrielfruet Mar 4, 2026
86ea8a9
fix: format
gabrielfruet Mar 4, 2026
edbfdcf
fix: tests now uses context manager for sparse mask
gabrielfruet Mar 4, 2026
e28391a
fix: using Tensor instead of torch.Tensor
gabrielfruet Mar 7, 2026
a1f82a0
doc: referencing original author
gabrielfruet Mar 7, 2026
9ce0f80
test: spark masking test
gabrielfruet Mar 7, 2026
e66d84a
doc: removed unecessary comment
gabrielfruet Mar 7, 2026
86934d2
feat: testing densify block
gabrielfruet Mar 7, 2026
a984e0d
fix: removed unecessary none handling
gabrielfruet Mar 7, 2026
3 changes: 3 additions & 0 deletions docs/source/lightly.loss.rst
@@ -56,6 +56,9 @@ lightly.loss
.. autoclass:: lightly.loss.regularizer.co2.CO2Regularizer
   :members:

.. autoclass:: lightly.loss.sparse_spark.SparKPatchReconLoss
   :members:

.. autoclass:: lightly.loss.swav_loss.SwaVLoss
   :members:

182 changes: 182 additions & 0 deletions examples/pytorch_lightning/spark.py
@@ -0,0 +1,182 @@
# This example requires the following dependencies to be installed:
# pip install "lightly[timm]"

# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with a single GPU.

import pytorch_lightning as pl
import timm
import torch
import torchvision
from pytorch_lightning import LightningModule
from torch import Tensor
from torch.nn import Module
from torchvision.transforms import v2

import lightly.models.utils as model_utils
from lightly.loss.sparse_spark import SparKPatchReconLoss
from lightly.models.modules import sparse_spark
from lightly.models.modules.sparse_spark import (
    SparKDensifier,
    SparKMasker,
    SparKMaskingOutput,
    SparKOutputDecoder,
    UNetDecoder,
    sparse_layer_context,
)


def _get_downsample_ratio_from_timm_model(model: Module) -> int:
    if not hasattr(model, "feature_info"):
        raise ValueError(
            "The provided model does not have the required 'feature_info' attribute."
        )
    return model.feature_info[-1]["reduction"]


def _get_enc_feat_map_chs_from_timm_model(model: Module) -> list[int]:
    if not hasattr(model, "feature_info"):
        raise ValueError(
            "The provided model does not have the required 'feature_info' attribute."
        )
    return [fi["num_chs"] for fi in model.feature_info]


class SparseSparK(LightningModule):
    def __init__(
        self,
        input_size: int = 416,
        mask_ratio: float = 0.6,
        densify_norm: str = "bn",
        sbn: bool = False,
    ):
        super().__init__()
        backbone = timm.create_model(
            model_name="resnet18", drop_path_rate=0.05, features_only=True
        )
        downsample_ratio = _get_downsample_ratio_from_timm_model(backbone)
        enc_feat_map_chs = _get_enc_feat_map_chs_from_timm_model(backbone)
        self.sparse_encoder = sparse_spark.dense_model_to_sparse(
            m=backbone, sbn=sbn, verbose=True
        )
        self.fmap_h = input_size // downsample_ratio
        self.fmap_w = input_size // downsample_ratio
        self.dense_decoder = UNetDecoder(
            up_sample_ratio=downsample_ratio,
            width=enc_feat_map_chs[-1],
        )
        self.masker = SparKMasker(
            feature_map_size=(self.fmap_h, self.fmap_w),
            downsample_ratio=downsample_ratio,
            mask_ratio=mask_ratio,
        )
        self.densifier = SparKDensifier(
            encoder_in_channels=enc_feat_map_chs,
            decoder_in_channel=self.dense_decoder.width,
            densify_norm_str=densify_norm.lower(),
            sbn=sbn,
        )
        self.downsample_ratio = downsample_ratio
        # Loss module for patch reconstruction.
        self.recon_loss_fn = SparKPatchReconLoss()
        # Output decoder for visualization (pass minimal spatial properties).
        self.output_decoder = SparKOutputDecoder(
            fmap_h=self.fmap_h,
            fmap_w=self.fmap_w,
            downsample_ratio=downsample_ratio,
        )

    def forward(
        self,
        inp_bchw: Tensor,
        vis: bool = False,
    ):
        # Step 1. Mask the input image.
        mask_out: SparKMaskingOutput = self.masker(inp_bchw)
        masked_bchw, per_level_mask = mask_out
        active_b1fHfW = per_level_mask[0]
        active_b1hw = per_level_mask[-1]
        # Step 2. Encode: get hierarchical encoded sparse features (a list
        # containing 4 feature maps at 4 scales). sparse_layer_context provides
        # the mask to the sparse encoder and densifier.
        with sparse_layer_context(active_mask=active_b1fHfW):
            fea_bcffs: list[Tensor] = self.sparse_encoder(masked_bchw)
            # Step 3. Densify: get hierarchical dense features for decoding.
            to_dec = self.densifier(fea_bcffs)
        # Step 4. Decode and reconstruct.
        rec_bchw = self.dense_decoder(to_dec)
        inp, rec = (
            model_utils.patchify(inp_bchw, self.downsample_ratio),
            model_utils.patchify(rec_bchw, self.downsample_ratio),
        )  # inp and rec: (B, L = f*f, N = C*downsample_ratio**2)

        recon_loss, mean, var = self.recon_loss_fn(
            inp_patches=inp, rec_patches=rec, active_mask=active_b1fHfW
        )

        if vis:
            return self.output_decoder(
                rec_patches=rec,
                mean=mean,
                var=var,
                inp_bchw=inp_bchw,
                active_mask_full=active_b1hw,
            )
        else:
            return recon_loss

    def training_step(self, batch: tuple[Tensor, Tensor], batch_index: int) -> Tensor:
        img, _ = batch
        recon_loss = self.forward(img)
        # Log the training loss to logger and progress bar (per step and per epoch).
        self.log(
            "train_loss",
            recon_loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return recon_loss

    def configure_optimizers(self):
        return torch.optim.SGD(
            self.parameters(), lr=0.03, momentum=0.9, weight_decay=1e-4
        )


model = SparseSparK(input_size=416)

dataset = torchvision.datasets.Caltech101(
    "datasets/caltech101",
    download=True,
    transform=v2.Compose(
        [
            v2.Resize((416, 416)),
            v2.RGB(),
            v2.ToTensor(),
            v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    ),
)
# Or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder")

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
    drop_last=True,
    num_workers=8,
)

trainer = pl.Trainer(
    max_epochs=30,
)

trainer.fit(
    model=model,
    train_dataloaders=dataloader,
)
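The `(B, L = f*f, N = C*downsample_ratio**2)` comment in the example describes the layout produced by `model_utils.patchify`. A minimal re-implementation sketch of that layout for illustration (this is not lightly's actual function; row-major patch ordering is assumed):

```python
import torch


def patchify_sketch(x: torch.Tensor, ratio: int) -> torch.Tensor:
    # Split a (B, C, H, W) image into non-overlapping ratio x ratio patches,
    # returning (B, L, N) with L = (H // ratio) * (W // ratio) patches of
    # N = C * ratio**2 values each.
    B, C, H, W = x.shape
    f_h, f_w = H // ratio, W // ratio
    x = x.reshape(B, C, f_h, ratio, f_w, ratio)
    x = x.permute(0, 2, 4, 1, 3, 5)  # (B, f_h, f_w, C, ratio, ratio)
    return x.reshape(B, f_h * f_w, C * ratio * ratio)


img = torch.randn(2, 3, 416, 416)
# 32 is ResNet-18's total stride, matching the example's downsample_ratio.
patches = patchify_sketch(img, ratio=32)
print(patches.shape)  # torch.Size([2, 169, 3072]) since (416 // 32) ** 2 = 169
```

With `input_size=416` and a stride-32 backbone this gives the 13x13 = 169 patch grid that the masker and loss operate on.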
3 changes: 2 additions & 1 deletion lightly/loss/__init__.py
@@ -1,4 +1,4 @@
"""The lightly.loss package provides loss functions for self-supervised learning. """
"""The lightly.loss package provides loss functions for self-supervised learning."""

# Copyright (c) 2020. Lightly AG and its affiliates.
# All Rights Reserved
@@ -16,6 +16,7 @@
from lightly.loss.negative_cosine_similarity import NegativeCosineSimilarity
from lightly.loss.ntx_ent_loss import NTXentLoss
from lightly.loss.pmsn_loss import PMSNCustomLoss, PMSNLoss
from lightly.loss.sparse_spark import SparKPatchReconLoss
from lightly.loss.swav_loss import SwaVLoss
from lightly.loss.sym_neg_cos_sim_loss import SymNegCosineSimilarityLoss
from lightly.loss.tico_loss import TiCoLoss
89 changes: 89 additions & 0 deletions lightly/loss/sparse_spark.py
@@ -0,0 +1,89 @@
from __future__ import annotations

import torch
import torch.distributed as dist
from torch import Tensor, nn

from lightly.utils.dist import gather


class SparKPatchReconLoss(nn.Module):
    """Computes the per-patch normalized reconstruction loss for masked regions.

    Original paper: https://github.com/keyu-tian/SparK

    Calculates the L2 loss between reconstructed and original patches, normalized
    per patch to account for varying feature statistics. The loss is computed only
    on masked (inactive) regions.

    Args:
        eps: Small value for numerical stability. Default: 1e-6.
        gather_distributed: If True, gathers the loss terms from all processes
            before averaging. Default: False.
    """

    def __init__(self, eps: float = 1e-6, gather_distributed: bool = False) -> None:
        super().__init__()
        if gather_distributed and not dist.is_available():
            raise ValueError(
                "gather_distributed is True but torch.distributed is not available. "
                "Please set gather_distributed=False or install a torch version with "
                "distributed support."
            )
        self.eps = eps
        self.gather_distributed = gather_distributed

    def forward(
        self,
        inp_patches: Tensor,
        rec_patches: Tensor,
        active_mask: Tensor,
    ) -> tuple[Tensor, Tensor, Tensor]:
        """Compute the reconstruction loss and per-patch statistics.

        Normalizes the original patches by their per-patch mean and standard
        deviation, then computes the L2 loss between the normalized original and
        reconstructed patches. The loss is averaged only over masked
        (active_mask=False) patches.

        Args:
            inp_patches: Original patches of shape (B, L, N) where B=batch,
                L=number of patches, N=patch dimension.
            rec_patches: Reconstructed patches of shape (B, L, N).
            active_mask: Boolean mask of shape (B, 1, f, f) indicating active
                (unmasked) regions. Must have 4 dimensions (2D spatial mask).

        Returns:
            Tuple of:
                - recon_loss: Scalar tensor with the reconstruction loss averaged
                  over masked regions.
                - mean: Per-patch mean of shape (B, L, 1).
                - var: Per-patch standard deviation of shape (B, L, 1).

        Raises:
            ValueError: If active_mask does not have 4 dimensions.
        """
        if active_mask.ndim != 4:
            raise ValueError(
                "active_mask must be non-flattened with shape (B, 1, f, f)"
            )

        mean = inp_patches.mean(dim=-1, keepdim=True)
        var = (inp_patches.var(dim=-1, keepdim=True) + self.eps) ** 0.5

        inp_normalized = (inp_patches - mean) / var

        l2_loss = ((rec_patches - inp_normalized) ** 2).mean(dim=2)

        non_active = active_mask.logical_not().int().view(active_mask.shape[0], -1)

        local_numerator = (l2_loss * non_active).sum()
        local_denominator = non_active.sum()

        if self.gather_distributed and dist.is_available() and dist.is_initialized():
            global_numerator = torch.cat(
                gather(local_numerator.unsqueeze(0)), dim=0
            ).sum()
            global_denominator = torch.cat(
                gather(local_denominator.unsqueeze(0)), dim=0
            ).sum()
        else:
            global_numerator = local_numerator
            global_denominator = local_denominator

        recon_loss = global_numerator / (global_denominator + self.eps)
        return recon_loss, mean, var
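The non-distributed path of this forward pass can be exercised in isolation. A short sketch with arbitrary illustrative sizes (B=2 images, f=4 so L=16 patches, N=12 values per patch; these numbers are not from the PR) that mirrors the computation above:

```python
import torch

torch.manual_seed(0)
B, f, N, eps = 2, 4, 12, 1e-6
inp = torch.randn(B, f * f, N)   # original patches, (B, L, N)
rec = torch.randn(B, f * f, N)   # "reconstructed" patches, (B, L, N)
active = torch.rand(B, 1, f, f) > 0.6  # True = kept (unmasked) patch

# Normalize each original patch by its own mean and standard deviation.
mean = inp.mean(dim=-1, keepdim=True)
std = (inp.var(dim=-1, keepdim=True) + eps) ** 0.5
l2 = ((rec - (inp - mean) / std) ** 2).mean(dim=2)  # (B, L) per-patch loss

# Average only over the masked (inactive) patches.
non_active = active.logical_not().int().view(B, -1)
loss = (l2 * non_active).sum() / (non_active.sum() + eps)
print(loss.shape)  # torch.Size([]) -- a scalar
```

Normalizing per patch means the decoder is asked to reconstruct local contrast rather than absolute pixel statistics, which is why `mean` and `var` are also returned: the visualization decoder needs them to undo the normalization.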
7 changes: 6 additions & 1 deletion lightly/models/modules/__init__.py
@@ -8,7 +8,6 @@
# Copyright (c) 2021. Lightly AG and its affiliates.
# All Rights Reserved


from lightly.models.modules.heads import (
    BarlowTwinsProjectionHead,
    BYOLPredictionHead,
@@ -30,6 +29,12 @@
    SwaVPrototypes,
)
from lightly.models.modules.nn_memory_bank import NNMemoryBankModule
from lightly.models.modules.sparse_spark import (
    SparKDensifier,
    SparKMasker,
    SparKOutputDecoder,
    dense_model_to_sparse,
)
from lightly.utils import dependency as _dependency

if _dependency.torchvision_vit_available():