
Commit 2bb640f

kashif and sayakpaul authored

[Research] Latent Perceptual Loss (LPL) for Stable Diffusion XL (huggingface#11573)

* initial
* added readme
* fix formatting
* added logging
* formatting
* use config
* debug
* better
* handle SNR
* floats have no item()
* remove debug
* formatting
* add paper link
* acknowledge reference source
* rename script

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

1 parent 2dc9d2a commit 2bb640f

3 files changed: +1994 -0 lines changed

Lines changed: 157 additions & 0 deletions

@@ -0,0 +1,157 @@
# Latent Perceptual Loss (LPL) for Stable Diffusion XL

This directory contains an implementation of Latent Perceptual Loss (LPL) for training Stable Diffusion XL models, based on the paper [Boosting Latent Diffusion with Perceptual Objectives](https://huggingface.co/papers/2411.04873) (Berrada et al., 2025). LPL is a perceptual loss that operates in the latent space of a VAE. It improves the quality and consistency of generated images by bridging the disconnect between the diffusion model and the autoencoder decoder. The implementation follows the reference code provided by Tariq Berrada.

## Overview

LPL addresses a key limitation of latent diffusion models (LDMs): the disconnect between diffusion-model training and the autoencoder decoder. Because LDMs are trained in the latent space, they receive no direct feedback about how well their outputs decode into high-quality images. This can lead to:

- Loss of fine details in generated images
- Inconsistent image quality
- Structural artifacts
- Reduced sharpness and realism

LPL works by comparing intermediate features of the VAE decoder between the predicted and target latents; a schematic form of the combined objective is given after the list below. This helps the model learn better perceptual features and can lead to:

- Improved image quality and consistency (6-20% FID improvement)
- Better preservation of fine details
- More stable training, especially at high noise levels
- Better handling of structural information
- Sharper and more realistic textures
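Schematically, the combined training objective with LPL enabled can be written as follows (this is a reading of the setup described here and in the loss module below, not the paper's exact notation):

```math
\mathcal{L} \;=\; \mathcal{L}_{\text{diffusion}}
\;+\; \lambda_{\text{LPL}}\;\mathbf{1}[\,t < \tau\,]
\sum_{l} w_l\, d\big(\phi_l(D(\hat{z}_0)),\, \phi_l(D(z_0))\big)
```

where $D$ is the VAE decoder (only its first blocks are evaluated), $\phi_l$ are its intermediate features, $\hat{z}_0$ and $z_0$ are the predicted and target clean latents, $d$ is an MSE or L1 distance on normalized features, $\lambda_{\text{LPL}}$ corresponds to `lpl_weight`, $\tau$ to `lpl_t_threshold`, and $w_l$ to the optional power-law layer weights.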
## Implementation Details

The LPL implementation follows the paper's methodology and includes several key features:

1. **Feature Extraction**: Extracts intermediate features from the VAE decoder, including:
   - Middle block features
   - Up block features (configurable number of blocks)
   - Proper gradient checkpointing for memory efficiency
   - Features are extracted only for timesteps below the threshold (high SNR)

2. **Feature Normalization**: Multiple normalization options, as validated in the paper (see the helper sketch after this list):
   - `default`: Normalize each feature map independently
   - `shared`: Cross-normalize features using target statistics (recommended)
   - `batch`: Batch-wise normalization

3. **Outlier Handling**: Optional removal of outliers in feature maps using:
   - Quantile-based filtering (2% quantiles)
   - Morphological operations (opening/closing)
   - Adaptive thresholding based on the standard deviation

4. **Loss Types**:
   - MSE loss (default)
   - L1 loss
   - Optional power-law weighting (2^(-i) for layer i)
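For concreteness, the `default` and `shared` modes correspond to the following helpers, reproduced in abridged form from the loss module included in this commit:

```python
import torch


def normalize_tensor(feat, eps=1e-10):
    # "default": scale each feature map by its own channel-wise L2 norm
    norm = torch.sqrt(torch.sum(feat**2, dim=1, keepdim=True))
    return feat / (norm + eps)


def cross_normalize(pred, target, eps=1e-10):
    # "shared": scale both prediction and target by the *target's* channel-wise
    # L2 norm, so the prediction is measured against the target's statistics
    norm = torch.sqrt(torch.sum(target**2, dim=1, keepdim=True))
    return pred / (norm + eps), target / (norm + eps)
```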
## Usage

To use LPL in your training, add the following arguments to your training command (each flag is described under Key Parameters below):

```bash
python examples/research_projects/lpl/train_sdxl_lpl.py \
  --use_lpl \
  --lpl_weight 1.0 \
  --lpl_t_threshold 200 \
  --lpl_loss_type mse \
  --lpl_norm_type shared \
  --lpl_pow_law \
  --lpl_num_blocks 4 \
  --lpl_remove_outliers \
  --lpl_scale \
  --lpl_start 0 \
  # ... other training arguments ...
```
### Key Parameters

- `lpl_weight`: Controls the strength of the LPL loss relative to the main diffusion loss. Higher values (1.0-2.0) improve quality but may slow training.
- `lpl_t_threshold`: LPL is applied only for timesteps below this threshold (high SNR). Lower values (100-200) restrict the loss to the least noisy timesteps.
- `lpl_loss_type`: Choose between MSE (default) and L1 loss. MSE is recommended for most cases.
- `lpl_norm_type`: Feature normalization strategy. `shared` is recommended, as it gave the best results in the paper.
- `lpl_pow_law`: Whether to use power-law weighting (2^(-i) for layer i). Recommended for better feature balance.
- `lpl_num_blocks`: Number of up blocks to use for feature extraction (1-4). More blocks capture more features but use more memory.
- `lpl_remove_outliers`: Whether to remove outliers in feature maps. Recommended for stable training.
- `lpl_scale`: Whether to scale the LPL loss by noise-level weights. Helps focus on the more important timesteps.
- `lpl_start`: Training step at which to start applying LPL. Can be used to warm up training.
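The exact wiring lives in `train_sdxl_lpl.py`. As a rough sketch of how the flags interact (variable names such as `pred_latents`, `target_latents`, `timesteps`, `snr_weights`, `diffusion_loss`, and `global_step` are illustrative, not the script's actual identifiers):

```python
# Illustrative sketch only; see train_sdxl_lpl.py for the real integration.
lpl_fn = LatentPerceptualLoss(
    vae,
    loss_type=args.lpl_loss_type,          # "mse" or "l1"
    pow_law=args.lpl_pow_law,
    norm_type=args.lpl_norm_type,          # "default", "shared", or "batch"
    num_mid_blocks=args.lpl_num_blocks,
    remove_outliers=args.lpl_remove_outliers,
)

loss = diffusion_loss  # the usual denoising objective
if args.use_lpl and global_step >= args.lpl_start:
    # LPL is applied only at low-noise (high-SNR) timesteps.
    lpl_mask = timesteps < args.lpl_t_threshold
    if lpl_mask.any():
        lpl = lpl_fn.get_loss(pred_latents[lpl_mask], target_latents[lpl_mask])
        if args.lpl_scale:
            lpl = lpl * snr_weights[lpl_mask]  # re-weight by noise level
        loss = loss + args.lpl_weight * lpl.mean()
```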
## Recommendations

1. **Starting Point** (based on the paper's results):

   ```bash
   --use_lpl \
   --lpl_weight 1.0 \
   --lpl_t_threshold 200 \
   --lpl_loss_type mse \
   --lpl_norm_type shared \
   --lpl_pow_law \
   --lpl_num_blocks 4 \
   --lpl_remove_outliers \
   --lpl_scale
   ```

2. **Memory Efficiency**:
   - Use `--gradient_checkpointing` (enabled by default)
   - Reduce `lpl_num_blocks` if memory is constrained (2-3 blocks still give good results)
   - Consider using `--lpl_scale` to focus on the more important timesteps
   - Features are extracted only for timesteps below the threshold, which also saves memory

3. **Quality vs. Speed**:
   - Higher `lpl_weight` (1.0-2.0) for better quality
   - Lower `lpl_t_threshold` (100-200) for faster training
   - Use `lpl_remove_outliers` for more stable training
   - `lpl_norm_type shared` provides the best quality/speed trade-off
## Technical Details

### Feature Extraction

The LPL implementation extracts features from the VAE decoder in the following order:

1. Middle block output
2. Up block outputs (configurable number of blocks)

Each feature map is processed with:

1. Optional outlier removal (2% quantiles, morphological operations; sketched after this list)
2. Feature normalization (shared statistics recommended)
3. Loss calculation (MSE or L1)
4. Optional power-law weighting (2^(-i) for layer i)
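The outlier removal in step 1 works roughly as follows (abridged from `remove_outliers` in the loss module shipped with this commit; the fixed `kernel` here stands in for the module's resolution-dependent opening/closing sizes):

```python
import torch.nn as nn


def mask_outliers(feat, quant=0.02, kernel=3):
    # Keep activations inside [q_low - 2*std, q_high + 2*std], computed per channel.
    flat = feat.flatten(-2, -1)
    k_lo = max(int(flat.shape[-1] * quant), 1)
    k_hi = int(flat.shape[-1] * (1 - quant))
    q_lo = flat.kthvalue(k_lo, dim=-1).values[..., None, None]
    q_hi = flat.kthvalue(k_hi, dim=-1).values[..., None, None]
    margin = 2 * flat.std(-1)[..., None, None]
    mask = (q_lo - margin < feat) & (feat < q_hi + margin)

    # Morphological cleanup of the keep-mask, implemented with max-pooling:
    pool = nn.MaxPool2d(kernel_size=kernel, stride=1, padding=(kernel - 1) // 2)
    mask = pool(mask.float())      # closing (dilate)
    mask = (-pool(-mask)).bool()   # opening (erode back)
    return mask, feat * mask
```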
### Loss Calculation

For each feature map:

1. Features are normalized according to the chosen strategy
2. Loss is calculated between normalized features
3. Outliers are masked out (if enabled)
4. Loss is weighted by layer depth (if power law enabled)
5. Final loss is averaged across all layers
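Reading these steps off the loss module included in this commit, the per-sample loss over $L$ feature maps (index $i = 0$ for the mid block) is:

```math
\mathcal{L}_{\mathrm{LPL}}
= \frac{1}{L}\sum_{i=0}^{L-1} \frac{w_i}{C_i}\sum_{c=1}^{C_i}
\frac{\sum_{h,w} m_{i,c,h,w}\; d\!\left(\hat{x}_{i,c,h,w},\, \hat{y}_{i,c,h,w}\right)}
{\sum_{h,w} m_{i,c,h,w}},
\qquad
w_i = \begin{cases} 2^{-\min(i,\,3)} & \text{with power-law weighting} \\ 1 & \text{otherwise} \end{cases}
```

where $\hat{x}_i$ and $\hat{y}_i$ are the normalized predicted and target features of layer $i$, $m_i$ is the outlier mask (all ones when outlier removal is disabled), $C_i$ is the number of channels, and $d$ is the squared or absolute difference.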
### Memory Considerations

- Gradient checkpointing is used by default
- Features are extracted only for timesteps below the threshold
- Outlier removal is done in-place to save memory
- Feature normalization is done efficiently using vectorized operations
- Memory usage scales linearly with the number of blocks used

## Results

Based on the paper's findings, LPL provides:

- 6-20% improvement in FID scores
- Better preservation of fine details
- More realistic textures and structures
- Improved consistency across different resolutions
- Better performance on both small and large datasets

## Citation

If you use this implementation in your research, please cite:

```bibtex
@inproceedings{berrada2025boosting,
  title={Boosting Latent Diffusion with Perceptual Objectives},
  author={Tariq Berrada and Pietro Astolfi and Melissa Hall and Marton Havasi and Yohann Benchetrit and Adriana Romero-Soriano and Karteek Alahari and Michal Drozdzal and Jakob Verbeek},
  booktitle={The Thirteenth International Conference on Learning Representations},
  year={2025},
  url={https://openreview.net/forum?id=y4DtzADzd1}
}
```
Lines changed: 215 additions & 0 deletions
@@ -0,0 +1,215 @@
# Copyright 2025 Berrada et al.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


def normalize_tensor(in_feat, eps=1e-10):
    # "default" normalization: scale each feature by its own channel-wise L2 norm.
    norm_factor = torch.sqrt(torch.sum(in_feat**2, dim=1, keepdim=True))
    return in_feat / (norm_factor + eps)


def cross_normalize(input, target, eps=1e-10):
    # "shared" normalization: scale both tensors by the target's channel-wise L2 norm.
    norm_factor = torch.sqrt(torch.sum(target**2, dim=1, keepdim=True))
    return input / (norm_factor + eps), target / (norm_factor + eps)


def remove_outliers(feat, down_f=1, opening=5, closing=3, m=100, quant=0.02):
    # Build a per-channel mask that drops extreme activations, then clean it up
    # with morphological closing/opening implemented via max-pooling.
    opening = int(np.ceil(opening / down_f))
    closing = int(np.ceil(closing / down_f))
    if opening == 2:
        opening = 3
    if closing == 2:
        closing = 1

    # replace quantile with kth value here.
    feat_flat = feat.flatten(-2, -1)
    k1, k2 = int(feat_flat.shape[-1] * quant), int(feat_flat.shape[-1] * (1 - quant))
    q1 = feat_flat.kthvalue(k1, dim=-1).values[..., None, None]
    q2 = feat_flat.kthvalue(k2, dim=-1).values[..., None, None]

    # widen the quantile bounds by twice the per-channel standard deviation
    m = 2 * feat_flat.std(-1)[..., None, None].detach()
    mask = (q1 - m < feat) * (feat < q2 + m)

    # dilate the mask.
    mask = nn.MaxPool2d(kernel_size=closing, stride=1, padding=(closing - 1) // 2)(mask.float())  # closing
    mask = (-nn.MaxPool2d(kernel_size=opening, stride=1, padding=(opening - 1) // 2)(-mask)).bool()  # opening
    feat = feat * mask
    return mask, feat


class LatentPerceptualLoss(nn.Module):
    def __init__(
        self,
        vae,
        loss_type="mse",
        grad_ckpt=True,
        pow_law=False,
        norm_type="default",
        num_mid_blocks=4,
        feature_type="feature",
        remove_outliers=True,
    ):
        super().__init__()
        self.vae = vae
        self.decoder = self.vae.decoder
        # Store scaling factors as tensors on the correct device
        device = next(self.vae.parameters()).device

        # Get scaling factors with proper defaults and handle None values
        scale_factor = getattr(self.vae.config, "scaling_factor", None)
        shift_factor = getattr(self.vae.config, "shift_factor", None)

        # Convert to tensors with proper defaults
        self.scale = torch.tensor(1.0 if scale_factor is None else scale_factor, device=device)
        self.shift = torch.tensor(0.0 if shift_factor is None else shift_factor, device=device)

        self.gradient_checkpointing = grad_ckpt
        self.pow_law = pow_law
        self.norm_type = norm_type.lower()
        self.outlier_mask = remove_outliers
        self.last_feature_stats = []  # Store feature statistics for logging

        assert feature_type in ["feature", "image"]
        self.feature_type = feature_type

        assert self.norm_type in ["default", "shared", "batch"]
        assert num_mid_blocks >= 0 and num_mid_blocks <= 4
        self.n_blocks = num_mid_blocks

        assert loss_type in ["mse", "l1"]
        if loss_type == "mse":
            self.loss_fn = nn.MSELoss(reduction="none")
        elif loss_type == "l1":
            self.loss_fn = nn.L1Loss(reduction="none")

    def get_features(self, z, latent_embeds=None, disable_grads=False):
        # Run the latent through the decoder's conv_in, mid block, and the first
        # `n_blocks` up blocks, collecting the intermediate activations.
        with torch.set_grad_enabled(not disable_grads):
            if self.gradient_checkpointing and not disable_grads:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                features = []
                upscale_dtype = next(iter(self.decoder.up_blocks.parameters())).dtype
                sample = z
                sample = self.decoder.conv_in(sample)

                # middle
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.decoder.mid_block),
                    sample,
                    latent_embeds,
                    use_reentrant=False,
                )
                sample = sample.to(upscale_dtype)
                features.append(sample)

                # up
                for up_block in self.decoder.up_blocks[: self.n_blocks]:
                    sample = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(up_block),
                        sample,
                        latent_embeds,
                        use_reentrant=False,
                    )
                    features.append(sample)
                return features
            else:
                features = []
                upscale_dtype = next(iter(self.decoder.up_blocks.parameters())).dtype
                sample = z
                sample = self.decoder.conv_in(sample)

                # middle
                sample = self.decoder.mid_block(sample, latent_embeds)
                sample = sample.to(upscale_dtype)
                features.append(sample)

                # up
                for up_block in self.decoder.up_blocks[: self.n_blocks]:
                    sample = up_block(sample, latent_embeds)
                    features.append(sample)
                return features

    def get_loss(self, input, target, get_hist=False):
        # `input` and `target` are latents in the model's scaled space; undo the
        # scaling/shift before feeding them to the decoder.
        if self.feature_type == "feature":
            inp_f = self.get_features(self.shift + input / self.scale)
            tar_f = self.get_features(self.shift + target / self.scale, disable_grads=True)
            losses = []
            self.last_feature_stats = []  # Reset feature stats

            for i, (x, y) in enumerate(zip(inp_f, tar_f, strict=False)):
                my = torch.ones_like(y).bool()
                outlier_ratio = 0.0

                if self.outlier_mask:
                    # outlier filtering is applied only to the later, higher-resolution decoder features
                    with torch.no_grad():
                        if i == 2:
                            my, y = remove_outliers(y, down_f=2)
                            outlier_ratio = 1.0 - my.float().mean().item()
                        elif i in [3, 4, 5]:
                            my, y = remove_outliers(y, down_f=1)
                            outlier_ratio = 1.0 - my.float().mean().item()

                # Store feature statistics before normalization
                with torch.no_grad():
                    stats = {
                        "mean": y.mean().item(),
                        "std": y.std().item(),
                        "outlier_ratio": outlier_ratio,
                    }
                    self.last_feature_stats.append(stats)

                # normalize feature tensors
                if self.norm_type == "default":
                    x = normalize_tensor(x)
                    y = normalize_tensor(y)
                elif self.norm_type == "shared":
                    x, y = cross_normalize(x, y, eps=1e-6)

                term_loss = self.loss_fn(x, y) * my
                # reduce loss term: masked spatial mean per channel, optionally
                # down-weighted by 2^(-i) for deeper layers, then channel mean
                loss_f = 2 ** (-min(i, 3)) if self.pow_law else 1.0
                term_loss = term_loss.sum((2, 3)) * loss_f / my.sum((2, 3))
                losses.append(term_loss.mean((1,)))

            if get_hist:
                return losses
            else:
                loss = sum(losses)
                return loss / len(inp_f)
        elif self.feature_type == "image":
            # Pixel-space alternative: decode both latents and compare images directly.
            inp_f = self.vae.decode(input / self.scale).sample
            tar_f = self.vae.decode(target / self.scale).sample
            return F.mse_loss(inp_f, tar_f)

    def get_first_conv(self, z):
        sample = self.decoder.conv_in(z)
        return sample

    def get_first_block(self, z):
        sample = self.decoder.conv_in(z)
        sample = self.decoder.mid_block(sample)
        for resnet in self.decoder.up_blocks[0].resnets:
            sample = resnet(sample, None)
        return sample

    def get_first_layer(self, input, target, target_layer="conv"):
        # Cheaper variant: compare only the first decoder conv (or the first block).
        if target_layer == "conv":
            feat_in = self.get_first_conv(input)
            with torch.no_grad():
                feat_tar = self.get_first_conv(target)
        else:
            feat_in = self.get_first_block(input)
            with torch.no_grad():
                feat_tar = self.get_first_block(target)

        feat_in, feat_tar = cross_normalize(feat_in, feat_tar)

        return F.mse_loss(feat_in, feat_tar, reduction="mean")
