# Copyright 2025 Berrada et al.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


def normalize_tensor(in_feat, eps=1e-10):
    """Scale each spatial location of `in_feat` to unit L2 norm over the channel dim."""
    norm_factor = torch.sqrt(torch.sum(in_feat**2, dim=1, keepdim=True))
    return in_feat / (norm_factor + eps)


def cross_normalize(input, target, eps=1e-10):
    """Scale both tensors by the channelwise L2 norm of `target`, keeping them comparable."""
    norm_factor = torch.sqrt(torch.sum(target**2, dim=1, keepdim=True))
    return input / (norm_factor + eps), target / (norm_factor + eps)
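

# Illustrative usage of the helpers above (shapes are assumptions of this sketch):
# both normalize over the channel dimension (dim=1) of a (B, C, H, W) feature map.
#
#   feat = torch.randn(2, 64, 32, 32)
#   unit = normalize_tensor(feat)
#   # torch.sum(unit**2, dim=1) is ~1.0 at every (b, h, w) location
#
#   x_n, y_n = cross_normalize(feat, feat.clone())
#   # both outputs are divided by the *target's* norm, so the relative scale of
#   # input vs. target is preserved (here x_n == y_n, and both equal `unit`)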


def remove_outliers(feat, down_f=1, opening=5, closing=3, m=100, quant=0.02):
    """Mask out extreme activations in `feat`; returns the bool mask and the masked feature."""
    opening = int(np.ceil(opening / down_f))
    closing = int(np.ceil(closing / down_f))
    if opening == 2:
        opening = 3
    if closing == 2:
        closing = 1

    # Approximate the lower/upper quantiles with kthvalue instead of torch.quantile.
    feat_flat = feat.flatten(-2, -1)
    k1 = max(int(feat_flat.shape[-1] * quant), 1)  # kthvalue requires k >= 1
    k2 = max(int(feat_flat.shape[-1] * (1 - quant)), 1)
    q1 = feat_flat.kthvalue(k1, dim=-1).values[..., None, None]
    q2 = feat_flat.kthvalue(k2, dim=-1).values[..., None, None]

    # NOTE: the `m` argument is never used; the margin is always two standard deviations.
    m = 2 * feat_flat.std(-1)[..., None, None].detach()
    mask = (q1 - m < feat) * (feat < q2 + m)

    # Morphological cleanup: max-pooling dilates the inlier mask (closing), and the
    # negated max-pool then erodes it (opening).
    mask = nn.MaxPool2d(kernel_size=closing, stride=1, padding=(closing - 1) // 2)(mask.float())  # closing
    mask = (-nn.MaxPool2d(kernel_size=opening, stride=1, padding=(opening - 1) // 2)(-mask)).bool()  # opening
    feat = feat * mask
    return mask, feat
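

# Illustrative usage (shapes and values are assumptions of this sketch): zero out
# a contiguous patch of extreme activations in a (B, C, H, W) feature map.
#
#   feat = torch.randn(4, 512, 32, 32)
#   feat[0, 0, :4, :4] = 1e4                  # contiguous high-activation patch
#   mask, clean = remove_outliers(feat)
#   # `mask` is a same-shaped bool tensor that is False on (and just around) the
#   # patch, and `clean == feat * mask`. Isolated single-pixel spikes are smoothed
#   # away by the closing step rather than masked.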


class LatentPerceptualLoss(nn.Module):
    def __init__(
        self,
        vae,
        loss_type="mse",
        grad_ckpt=True,
        pow_law=False,
        norm_type="default",
        num_mid_blocks=4,
        feature_type="feature",
        remove_outliers=True,
    ):
        super().__init__()
        self.vae = vae
        self.decoder = self.vae.decoder
        # Keep the scaling factors as tensors on the same device as the VAE.
        device = next(self.vae.parameters()).device

        # Read the scaling factors from the VAE config, falling back to the
        # identity transform (scale 1, shift 0) when they are absent or None.
        scale_factor = getattr(self.vae.config, "scaling_factor", None)
        shift_factor = getattr(self.vae.config, "shift_factor", None)

        self.scale = torch.tensor(1.0 if scale_factor is None else scale_factor, device=device)
        self.shift = torch.tensor(0.0 if shift_factor is None else shift_factor, device=device)

        self.gradient_checkpointing = grad_ckpt
        self.pow_law = pow_law
        self.norm_type = norm_type.lower()
        self.outlier_mask = remove_outliers  # bool flag; shadows the module-level helper in this scope
        self.last_feature_stats = []  # per-level statistics from the last get_loss call, for logging

        assert feature_type in ["feature", "image"]
        self.feature_type = feature_type

        assert self.norm_type in ["default", "shared", "batch"]
        assert 0 <= num_mid_blocks <= 4
        self.n_blocks = num_mid_blocks

        assert loss_type in ["mse", "l1"]
        if loss_type == "mse":
            self.loss_fn = nn.MSELoss(reduction="none")
        elif loss_type == "l1":
            self.loss_fn = nn.L1Loss(reduction="none")

    def get_features(self, z, latent_embeds=None, disable_grads=False):
        """Collect intermediate decoder features: the mid-block output followed by
        the outputs of the first `n_blocks` up blocks."""
        with torch.set_grad_enabled(not disable_grads):
            if self.gradient_checkpointing and not disable_grads:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                features = []
                upscale_dtype = next(iter(self.decoder.up_blocks.parameters())).dtype
                sample = z
                sample = self.decoder.conv_in(sample)

                # middle
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.decoder.mid_block),
                    sample,
                    latent_embeds,
                    use_reentrant=False,
                )
                sample = sample.to(upscale_dtype)
                features.append(sample)

                # up
                for up_block in self.decoder.up_blocks[: self.n_blocks]:
                    sample = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(up_block),
                        sample,
                        latent_embeds,
                        use_reentrant=False,
                    )
                    features.append(sample)
                return features
            else:
                features = []
                upscale_dtype = next(iter(self.decoder.up_blocks.parameters())).dtype
                sample = z
                sample = self.decoder.conv_in(sample)

                # middle
                sample = self.decoder.mid_block(sample, latent_embeds)
                sample = sample.to(upscale_dtype)
                features.append(sample)

                # up
                for up_block in self.decoder.up_blocks[: self.n_blocks]:
                    sample = up_block(sample, latent_embeds)
                    features.append(sample)
                return features
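
    # Minimal sketch (illustrative; `vae` and the latent shape are assumptions):
    #
    #   lpl = LatentPerceptualLoss(vae)
    #   z = torch.randn(1, vae.config.latent_channels, 64, 64)
    #   feats = lpl.get_features(z)   # 1 + num_mid_blocks tensors, from the
    #                                 # mid block and each selected up block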

    def get_loss(self, input, target, get_hist=False):
        if self.feature_type == "feature":
            # Map scaled latents back to the raw latent space the decoder expects.
            inp_f = self.get_features(self.shift + input / self.scale)
            tar_f = self.get_features(self.shift + target / self.scale, disable_grads=True)
            losses = []
            self.last_feature_stats = []  # reset feature stats

            for i, (x, y) in enumerate(zip(inp_f, tar_f, strict=False)):
                my = torch.ones_like(y).bool()
                outlier_ratio = 0.0

                if self.outlier_mask:
                    with torch.no_grad():
                        if i == 2:
                            my, y = remove_outliers(y, down_f=2)
                            outlier_ratio = 1.0 - my.float().mean().item()
                        elif i in [3, 4, 5]:
                            my, y = remove_outliers(y, down_f=1)
                            outlier_ratio = 1.0 - my.float().mean().item()

                # Record target-feature statistics before normalization, for logging.
                with torch.no_grad():
                    stats = {
                        "mean": y.mean().item(),
                        "std": y.std().item(),
                        "outlier_ratio": outlier_ratio,
                    }
                    self.last_feature_stats.append(stats)

                # Normalize the feature tensors.
                if self.norm_type == "default":
                    x = normalize_tensor(x)
                    y = normalize_tensor(y)
                elif self.norm_type == "shared":
                    x, y = cross_normalize(x, y, eps=1e-6)

                term_loss = self.loss_fn(x, y) * my
                # Masked spatial mean, optionally down-weighted per level by a power law.
                loss_f = 2 ** (-min(i, 3)) if self.pow_law else 1.0
                term_loss = term_loss.sum((2, 3)) * loss_f / my.sum((2, 3))
                losses.append(term_loss.mean((1,)))

            if get_hist:
                return losses
            return sum(losses) / len(inp_f)
        elif self.feature_type == "image":
            # Decode both latents with the same un-scaling as the feature path.
            inp_f = self.vae.decode(self.shift + input / self.scale).sample
            tar_f = self.vae.decode(self.shift + target / self.scale).sample
            return F.mse_loss(inp_f, tar_f)

    def get_first_conv(self, z):
        return self.decoder.conv_in(z)

    def get_first_block(self, z):
        sample = self.decoder.conv_in(z)
        sample = self.decoder.mid_block(sample)
        for resnet in self.decoder.up_blocks[0].resnets:
            sample = resnet(sample, None)
        return sample

    def get_first_layer(self, input, target, target_layer="conv"):
        if target_layer == "conv":
            feat_in = self.get_first_conv(input)
            with torch.no_grad():
                feat_tar = self.get_first_conv(target)
        else:
            feat_in = self.get_first_block(input)
            with torch.no_grad():
                feat_tar = self.get_first_block(target)

        feat_in, feat_tar = cross_normalize(feat_in, feat_tar)
        return F.mse_loss(feat_in, feat_tar, reduction="mean")
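

if __name__ == "__main__":
    # Illustrative smoke test; the `diffusers` dependency and the checkpoint name
    # are assumptions of this sketch, not requirements of the class above.
    from diffusers import AutoencoderKL

    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
    vae.requires_grad_(False)
    lpl = LatentPerceptualLoss(vae, loss_type="mse", norm_type="shared", pow_law=True)

    # Scaled latents, e.g. model predictions and targets during diffusion training.
    pred = torch.randn(1, 4, 32, 32, requires_grad=True)
    target = torch.randn(1, 4, 32, 32)

    loss = lpl.get_loss(pred, target).mean()  # per-sample losses -> scalar
    loss.backward()
    print(loss.item())
    print(lpl.last_feature_stats)  # per-level mean/std/outlier_ratio for logging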