 # limitations under the License.


+import math
 from typing import Callable, List, Optional, Tuple, Union

 import torch
@@ -66,7 +67,7 @@ def sample(
         intermediate_steps: Optional[int] = 100,
         conditioning: Optional[torch.Tensor] = None,
         verbose: Optional[bool] = True,
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
         """
         Args:
             input_noise: random noise, of the same shape as the desired sample.
@@ -101,6 +102,168 @@ def sample(
         else:
             return image

+    @torch.no_grad()
+    def get_likelihood(
+        self,
+        inputs: torch.Tensor,
+        diffusion_model: Callable[..., torch.Tensor],
+        scheduler: Optional[Callable[..., torch.Tensor]] = None,
+        save_intermediates: Optional[bool] = False,
+        conditioning: Optional[torch.Tensor] = None,
+        original_input_range: Optional[Tuple] = (0, 255),
+        scaled_input_range: Optional[Tuple] = (0, 1),
+        verbose: Optional[bool] = True,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """
+        Computes the likelihood of an input.
+
+        Args:
+            inputs: input images, NxCxHxW[xD]
+            diffusion_model: model to compute likelihood from
+            scheduler: diffusion scheduler. If none is provided, the class attribute scheduler will be used.
+            save_intermediates: save the intermediate spatial KL maps
+            conditioning: conditioning for network input.
+            original_input_range: the [min,max] intensity range of the input data before any scaling was applied.
+            scaled_input_range: the [min,max] intensity range of the input data after scaling.
+            verbose: if true, prints the progress bar of the sampling process.
+        """
+
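+        # the likelihood is accumulated over all timesteps of the DDPM variational bound: for t > 0,
+        # the KL between the true posterior q(x_{t-1}|x_t, x_0) and the model's p(x_{t-1}|x_t); for
+        # t == 0, the discretized decoder term -log p(x_0|x_1) (Ho et al., https://arxiv.org/abs/2006.11239)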
+        if not scheduler:
+            scheduler = self.scheduler
+        if scheduler._get_name() != "DDPMScheduler":
+            raise NotImplementedError(
+                "Likelihood computation is only compatible with DDPMScheduler,"
+                f" you are using {scheduler._get_name()}"
+            )
+        if verbose and has_tqdm:
+            progress_bar = tqdm(scheduler.timesteps)
+        else:
+            progress_bar = iter(scheduler.timesteps)
+        intermediates = []
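+        # a single noise draw is reused for every timestep, so add_noise yields
+        # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise for each t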
+        noise = torch.randn_like(inputs).to(inputs.device)
+        total_kl = torch.zeros(inputs.shape[0]).to(inputs.device)
+        for t in progress_bar:
+            timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long()
+            noisy_image = scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
+            model_output = diffusion_model(x=noisy_image, timesteps=timesteps, context=conditioning)
+            # get the model's predicted mean, and its variance if one is predicted
+            if model_output.shape[1] == inputs.shape[1] * 2 and scheduler.variance_type in ["learned", "learned_range"]:
+                model_output, predicted_variance = torch.split(model_output, inputs.shape[1], dim=1)
+            else:
+                predicted_variance = None
+
+            # 1. compute alphas, betas
+            alpha_prod_t = scheduler.alphas_cumprod[t]
+            alpha_prod_t_prev = scheduler.alphas_cumprod[t - 1] if t > 0 else scheduler.one
+            beta_prod_t = 1 - alpha_prod_t
+            beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+            # 2. compute predicted original sample from predicted noise, also called
+            # "predicted x_0" in formula (15) of https://arxiv.org/pdf/2006.11239.pdf
+            if scheduler.prediction_type == "epsilon":
+                pred_original_sample = (noisy_image - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+            elif scheduler.prediction_type == "sample":
+                pred_original_sample = model_output
+            elif scheduler.prediction_type == "v_prediction":
+                pred_original_sample = (alpha_prod_t**0.5) * noisy_image - (beta_prod_t**0.5) * model_output
+
+            # 3. clip "predicted x_0"
+            if scheduler.clip_sample:
+                pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
+
+            # 4. compute the coefficients of pred_original_sample x_0 and current sample x_t
+            # see formula (7) of https://arxiv.org/pdf/2006.11239.pdf
+            pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * scheduler.betas[t]) / beta_prod_t
+            current_sample_coeff = scheduler.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t
+
+            # 5. compute the predicted mean µ_t of the previous sample
+            # see formula (7) of https://arxiv.org/pdf/2006.11239.pdf
+            predicted_mean = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * noisy_image
+
+            # get the true posterior mean and variance
+            posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image)
+            posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance)
+
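+            # work with log-variances for numerical stability; if no variance was learned, fall back
+            # to the scheduler's fixed posterior variance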
+            log_posterior_variance = torch.log(posterior_variance)
+            log_predicted_variance = (
+                torch.log(predicted_variance) if predicted_variance is not None else log_posterior_variance
+            )
+
+            if t == 0:
+                # compute -log p(x_0|x_1)
+                kl = -self._get_decoder_log_likelihood(
+                    inputs=inputs,
+                    means=predicted_mean,
+                    log_scales=0.5 * log_predicted_variance,
+                    original_input_range=original_input_range,
+                    scaled_input_range=scaled_input_range,
+                )
+            else:
+                # compute the KL divergence between the two normals
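+                # KL(q || p) for diagonal Gaussians q = N(posterior_mean, s_q^2), p = N(predicted_mean, s_p^2),
+                # written in log-variance form: 0.5 * (log(s_p^2 / s_q^2) + (s_q^2 + (µ_q - µ_p)^2) / s_p^2 - 1)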
+                kl = 0.5 * (
+                    -1.0
+                    + log_predicted_variance
+                    - log_posterior_variance
+                    + torch.exp(log_posterior_variance - log_predicted_variance)
+                    + ((posterior_mean - predicted_mean) ** 2) * torch.exp(-log_predicted_variance)
+                )
+            total_kl += kl.view(kl.shape[0], -1).mean(dim=1)
+            if save_intermediates:
+                intermediates.append(kl.cpu())
+
+        if save_intermediates:
+            return total_kl, intermediates
+        else:
+            return total_kl
+
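+    # Illustrative usage (a minimal sketch, not part of this class; `model` and `images` are assumed
+    # to be a trained diffusion model and a batch of appropriately scaled inputs):
+    #
+    #     inferer = DiffusionInferer(scheduler=DDPMScheduler())
+    #     total_kl, kl_maps = inferer.get_likelihood(
+    #         inputs=images, diffusion_model=model, save_intermediates=True, verbose=False
+    #     )
+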
+    def _approx_standard_normal_cdf(self, x):
+        """
+        A fast approximation of the cumulative distribution function of the
+        standard normal. Code adapted from https://github.com/openai/improved-diffusion.
+        """
+
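+        # tanh-based approximation of Phi(x) = 0.5 * (1 + erf(x / sqrt(2))), avoiding a call to torch.erf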
+        return 0.5 * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
+
+    def _get_decoder_log_likelihood(
+        self,
+        inputs: torch.Tensor,
+        means: torch.Tensor,
+        log_scales: torch.Tensor,
+        original_input_range: Optional[Tuple] = (0, 255),
+        scaled_input_range: Optional[Tuple] = (0, 1),
+    ) -> torch.Tensor:
+        """
+        Compute the log-likelihood of a Gaussian distribution discretizing to a
+        given image. Code adapted from https://github.com/openai/improved-diffusion.
+
+        Args:
+            inputs: the target images. It is assumed that these were uint8 values,
+                rescaled to the range [-1, 1].
+            means: the Gaussian mean Tensor.
+            log_scales: the Gaussian log stddev Tensor.
+            original_input_range: the [min,max] intensity range of the input data before any scaling was applied.
+            scaled_input_range: the [min,max] intensity range of the input data after scaling.
+        """
+        assert inputs.shape == means.shape
+        bin_width = (scaled_input_range[1] - scaled_input_range[0]) / (
+            original_input_range[1] - original_input_range[0]
+        )
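+        # each original intensity value occupies one bin of this width in the scaled space; the
+        # likelihood of a pixel is the Gaussian mass inside its bin, CDF(x + w/2) - CDF(x - w/2),
+        # with half-open bins at the two extremes of the range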
+        centered_x = inputs - means
+        inv_stdv = torch.exp(-log_scales)
+        plus_in = inv_stdv * (centered_x + bin_width / 2)
+        cdf_plus = self._approx_standard_normal_cdf(plus_in)
+        min_in = inv_stdv * (centered_x - bin_width / 2)
+        cdf_min = self._approx_standard_normal_cdf(min_in)
+        log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12))
+        log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12))
+        cdf_delta = cdf_plus - cdf_min
+        log_probs = torch.where(
+            inputs < -0.999,
+            log_cdf_plus,
+            torch.where(inputs > 0.999, log_one_minus_cdf_min, torch.log(cdf_delta.clamp(min=1e-12))),
+        )
+        assert log_probs.shape == inputs.shape
+        return log_probs
+

 class LatentDiffusionInferer(DiffusionInferer):
     """
@@ -201,3 +364,59 @@ def sample(

         else:
             return image
+
+    @torch.no_grad()
+    def get_likelihood(
+        self,
+        inputs: torch.Tensor,
+        autoencoder_model: Callable[..., torch.Tensor],
+        diffusion_model: Callable[..., torch.Tensor],
+        scheduler: Optional[Callable[..., torch.Tensor]] = None,
+        save_intermediates: Optional[bool] = False,
+        conditioning: Optional[torch.Tensor] = None,
+        original_input_range: Optional[Tuple] = (0, 255),
+        scaled_input_range: Optional[Tuple] = (0, 1),
+        verbose: Optional[bool] = True,
+        resample_latent_likelihoods: Optional[bool] = False,
+        resample_interpolation_mode: Optional[str] = "bilinear",
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """
+        Computes the likelihood of the latent representation of the input.
+
+        Args:
+            inputs: input images, NxCxHxW[xD]
+            autoencoder_model: first stage model.
+            diffusion_model: model to compute likelihood from
+            scheduler: diffusion scheduler. If none is provided, the class attribute scheduler will be used.
+            save_intermediates: save the intermediate spatial KL maps
+            conditioning: conditioning for network input.
+            original_input_range: the [min,max] intensity range of the input data before any scaling was applied.
+            scaled_input_range: the [min,max] intensity range of the input data after scaling.
+            verbose: if true, prints the progress bar of the sampling process.
+            resample_latent_likelihoods: if true, resamples the intermediate likelihood maps to have the same
+                spatial dimensions as the input images.
+            resample_interpolation_mode: when resample_latent_likelihoods is True, the interpolation mode to
+                use, either 'nearest' or 'bilinear'.
+        """
+
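+        # encode into the first-stage latent space and apply the same scale_factor the diffusion
+        # model was trained with, so the likelihood is computed over correctly scaled latents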
+        latents = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor
+        outputs = super().get_likelihood(
+            inputs=latents,
+            diffusion_model=diffusion_model,
+            scheduler=scheduler,
+            save_intermediates=save_intermediates,
+            conditioning=conditioning,
+            verbose=verbose,
+        )
+        if save_intermediates and resample_latent_likelihoods:
+            intermediates = outputs[1]
+            from torchvision.transforms import InterpolationMode, Resize
+
+            interpolation_modes = {"nearest": InterpolationMode.NEAREST, "bilinear": InterpolationMode.BILINEAR}
+            if resample_interpolation_mode not in interpolation_modes:
+                raise ValueError(
+                    f"resample_interpolation_mode should be either 'nearest' or 'bilinear', not {resample_interpolation_mode}"
+                )
+            resizer = Resize(size=inputs.shape[2:], interpolation=interpolation_modes[resample_interpolation_mode])
+            intermediates = [resizer(x) for x in intermediates]
+            outputs = (outputs[0], intermediates)
+        return outputs
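+
+    # Illustrative usage (a sketch; `autoencoder`, `unet`, and `images` are hypothetical names for a
+    # trained first-stage model, a latent diffusion model, and a batch of inputs):
+    #
+    #     inferer = LatentDiffusionInferer(scheduler=DDPMScheduler(), scale_factor=1.0)
+    #     nll = inferer.get_likelihood(inputs=images, autoencoder_model=autoencoder, diffusion_model=unet)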