Skip to content
This repository was archived by the owner on Feb 7, 2025. It is now read-only.

Commit 9892933

Browse files
authored
Adds likelihood computation (#122)
* Fixes return type in sample * Adds method to compute posterior mean * Initial code for computing likelihood * Fixes bug in get_mean * Calculates mean/var from epsilon * Fixes bug in predicting input from noise * Adds decoder log-likelihood * Adds log-likelihood calculation for latent diffusion model * Fixes return type in sample * Adds method to compute posterior mean * Initial code for computing likelihood * Fixes bug in get_mean * Calculates mean/var from epsilon * Fixes bug in predicting input from noise * Adds decoder log-likelihood * Adds log-likelihood calculation for latent diffusion model * Adds tests * Adds latent tests * Pass input scalings to decoder calc * Fix arg and docstring * Fixes return type in sample * Adds method to compute posterior mean * Initial code for computing likelihood * Fixes bug in get_mean * Calculates mean/var from epsilon * Fixes bug in predicting input from noise * Adds decoder log-likelihood * Adds log-likelihood calculation for latent diffusion model * Adds tests * Adds latent tests * Adds method to compute posterior mean * Initial code for computing likelihood * Fixes bug in get_mean * Calculates mean/var from epsilon * Fixes bug in predicting input from noise * Adds decoder log-likelihood * Pass input scalings to decoder calc * Fix arg and docstring * Include v-prediction and use scheduler prediction_type attribute * Adds decorators for no_grad * Adds option to resample latent likelihoods spatially * Updates docstring
1 parent a693619 commit 9892933

File tree

4 files changed

+335
-4
lines changed

4 files changed

+335
-4
lines changed

generative/inferers/inferer.py

Lines changed: 220 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
# limitations under the License.
1111

1212

13+
import math
1314
from typing import Callable, List, Optional, Tuple, Union
1415

1516
import torch
@@ -66,7 +67,7 @@ def sample(
6667
intermediate_steps: Optional[int] = 100,
6768
conditioning: Optional[torch.Tensor] = None,
6869
verbose: Optional[bool] = True,
69-
) -> torch.Tensor:
70+
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
7071
"""
7172
Args:
7273
input_noise: random noise, of the same shape as the desired sample.
@@ -101,6 +102,168 @@ def sample(
101102
else:
102103
return image
103104

105+
@torch.no_grad()
def get_likelihood(
    self,
    inputs: torch.Tensor,
    diffusion_model: Callable[..., torch.Tensor],
    scheduler: Optional[Callable[..., torch.Tensor]] = None,
    save_intermediates: Optional[bool] = False,
    conditioning: Optional[torch.Tensor] = None,
    original_input_range: Optional[Tuple] = (0, 255),
    scaled_input_range: Optional[Tuple] = (0, 1),
    verbose: Optional[bool] = True,
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
    """
    Computes the likelihoods for an input.

    For every timestep ``t`` in ``scheduler.timesteps`` the inputs are noised to ``t``, the
    model is evaluated on the noisy inputs, and the KL divergence between the posterior
    q(x_{t-1}|x_t, x_0) and the model's predicted distribution is computed element-wise.
    At ``t == 0`` the negative discretised decoder log-likelihood is used instead. Each
    element-wise map is averaged over all non-batch dimensions and accumulated over timesteps.

    Args:
        inputs: input images, NxCxHxW[xD]
        diffusion_model: model to compute likelihood from
        scheduler: diffusion scheduler. If none provided will use the class attribute scheduler.
        save_intermediates: save the intermediate spatial KL maps
        conditioning: Conditioning for network input.
        original_input_range: the [min,max] intensity range of the input data before any scaling was applied.
        scaled_input_range: the [min,max] intensity range of the input data after scaling.
        verbose: if true, shows a progress bar over the timesteps (when tqdm is available).

    Returns:
        A tensor of shape (N,) with the accumulated value per input. If ``save_intermediates``
        is true, a tuple of that tensor and a list holding one element-wise map (moved to CPU)
        per timestep.

    Raises:
        NotImplementedError: if the scheduler is not a DDPMScheduler.
        ValueError: if ``scheduler.prediction_type`` is not one of "epsilon", "sample" or
            "v_prediction".
    """
    if scheduler is None:
        scheduler = self.scheduler
    if scheduler._get_name() != "DDPMScheduler":
        raise NotImplementedError(
            f"Likelihood computation is only compatible with DDPMScheduler,"
            f" you are using {scheduler._get_name()}"
        )
    if verbose and has_tqdm:
        progress_bar = tqdm(scheduler.timesteps)
    else:
        progress_bar = iter(scheduler.timesteps)

    intermediates = []
    # a single noise sample is drawn once and reused for every timestep
    noise = torch.randn_like(inputs)
    total_kl = torch.zeros(inputs.shape[0], device=inputs.device)
    for t in progress_bar:
        timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long()
        # noise with the scheduler that is actually in use, so that a scheduler passed as an
        # argument is honoured here as well as in the coefficient calculations below
        noisy_image = scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
        model_output = diffusion_model(x=noisy_image, timesteps=timesteps, context=conditioning)

        # get the model's predicted mean, and variance if it is predicted
        if model_output.shape[1] == inputs.shape[1] * 2 and scheduler.variance_type in ["learned", "learned_range"]:
            model_output, predicted_variance = torch.split(model_output, inputs.shape[1], dim=1)
        else:
            predicted_variance = None

        # 1. compute alphas, betas
        alpha_prod_t = scheduler.alphas_cumprod[t]
        alpha_prod_t_prev = scheduler.alphas_cumprod[t - 1] if t > 0 else scheduler.one
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        # 2. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
        if scheduler.prediction_type == "epsilon":
            pred_original_sample = (noisy_image - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
        elif scheduler.prediction_type == "sample":
            pred_original_sample = model_output
        elif scheduler.prediction_type == "v_prediction":
            pred_original_sample = (alpha_prod_t**0.5) * noisy_image - (beta_prod_t**0.5) * model_output
        else:
            # fail with a clear message instead of an UnboundLocalError further down
            raise ValueError(
                f"Unsupported prediction_type '{scheduler.prediction_type}', expected one of"
                " 'epsilon', 'sample' or 'v_prediction'"
            )

        # 3. Clip "predicted x_0"
        if scheduler.clip_sample:
            pred_original_sample = torch.clamp(pred_original_sample, -1, 1)

        # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * scheduler.betas[t]) / beta_prod_t
        current_sample_coeff = scheduler.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t

        # 5. Compute predicted previous sample µ_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        predicted_mean = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * noisy_image

        # get the posterior mean and variance
        posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image)
        # NOTE(review): the predicted variance is forwarded to _get_variance here; confirm that
        # this still yields the true posterior variance for the learned variance types.
        posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance)

        log_posterior_variance = torch.log(posterior_variance)
        # compare against None explicitly: the truth value of a multi-element tensor is
        # ambiguous and raises a RuntimeError
        if predicted_variance is not None:
            log_predicted_variance = torch.log(predicted_variance)
        else:
            log_predicted_variance = log_posterior_variance

        if t == 0:
            # compute -log p(x_0|x_1)
            kl = -self._get_decoder_log_likelihood(
                inputs=inputs,
                means=predicted_mean,
                log_scales=0.5 * log_predicted_variance,
                original_input_range=original_input_range,
                scaled_input_range=scaled_input_range,
            )
        else:
            # compute kl between two normals
            kl = 0.5 * (
                -1.0
                + log_predicted_variance
                - log_posterior_variance
                + torch.exp(log_posterior_variance - log_predicted_variance)
                + ((posterior_mean - predicted_mean) ** 2) * torch.exp(-log_predicted_variance)
            )
        # average over every non-batch dimension, then accumulate across timesteps
        total_kl += kl.view(kl.shape[0], -1).mean(dim=1)
        if save_intermediates:
            intermediates.append(kl.cpu())

    if save_intermediates:
        return total_kl, intermediates
    return total_kl
215+
216+
def _approx_standard_normal_cdf(self, x):
    """
    A fast approximation of the cumulative distribution function of the
    standard normal. Code adapted from https://github.com/openai/improved-diffusion.

    Args:
        x: tensor of points at which the CDF is evaluated.

    Returns:
        A tensor with the same shape as ``x`` holding the approximate CDF values.
    """
    # A Python scalar avoids allocating a tensor and moving it to the device on every call,
    # and leaves the dtype of ``x`` untouched.
    scale = math.sqrt(2.0 / math.pi)
    return 0.5 * (1.0 + torch.tanh(scale * (x + 0.044715 * torch.pow(x, 3))))
225+
226+
def _get_decoder_log_likelihood(
    self,
    inputs: torch.Tensor,
    means: torch.Tensor,
    log_scales: torch.Tensor,
    original_input_range: Optional[Tuple] = (0, 255),
    scaled_input_range: Optional[Tuple] = (0, 1),
) -> torch.Tensor:
    """
    Compute the log-likelihood of a Gaussian distribution discretizing to a
    given image. Code adapted from https://github.com/openai/improved-diffusion.

    Args:
        inputs: the target images. It is assumed that this was uint8 values,
            rescaled to the range [-1, 1].
        means: the Gaussian mean Tensor.
        log_scales: the Gaussian log stddev Tensor.
        original_input_range: the [min,max] intensity range of the input data before any scaling was applied.
        scaled_input_range: the [min,max] intensity range of the input data after scaling.

    Returns:
        A tensor with the same shape as ``inputs`` holding the element-wise log-probabilities.

    Raises:
        ValueError: if ``inputs`` and ``means`` do not have the same shape.
    """
    if inputs.shape != means.shape:
        raise ValueError(f"inputs and means must have the same shape, got {inputs.shape} and {means.shape}")
    # width of one original intensity level expressed in the scaled intensity range
    bin_width = (scaled_input_range[1] - scaled_input_range[0]) / (
        original_input_range[1] - original_input_range[0]
    )
    centered_x = inputs - means
    inv_stdv = torch.exp(-log_scales)
    # CDF evaluated at the upper and lower edge of the bin centred on each input value
    plus_in = inv_stdv * (centered_x + bin_width / 2)
    cdf_plus = self._approx_standard_normal_cdf(plus_in)
    min_in = inv_stdv * (centered_x - bin_width / 2)
    cdf_min = self._approx_standard_normal_cdf(min_in)
    # clamp before the log so that a vanishing probability mass does not produce -inf
    log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12))
    log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12))
    cdf_delta = cdf_plus - cdf_min
    # Values below -0.999 use the mass left of the upper bin edge, values above 0.999 the
    # mass right of the lower bin edge, and everything else the mass inside the bin.
    # NOTE(review): these thresholds are fixed for data scaled to [-1, 1] and do not follow
    # scaled_input_range; with the default (0, 1) the lower branch looks unreachable - confirm.
    log_probs = torch.where(
        inputs < -0.999,
        log_cdf_plus,
        torch.where(inputs > 0.999, log_one_minus_cdf_min, torch.log(cdf_delta.clamp(min=1e-12))),
    )
    # internal sanity check: broadcasting must not have changed the output shape
    assert log_probs.shape == inputs.shape
    return log_probs
266+
104267

105268
class LatentDiffusionInferer(DiffusionInferer):
106269
"""
@@ -201,3 +364,59 @@ def sample(
201364

202365
else:
203366
return image
367+
368+
@torch.no_grad()
def get_likelihood(
    self,
    inputs: torch.Tensor,
    autoencoder_model: Callable[..., torch.Tensor],
    diffusion_model: Callable[..., torch.Tensor],
    scheduler: Optional[Callable[..., torch.Tensor]] = None,
    save_intermediates: Optional[bool] = False,
    conditioning: Optional[torch.Tensor] = None,
    original_input_range: Optional[Tuple] = (0, 255),
    scaled_input_range: Optional[Tuple] = (0, 1),
    verbose: Optional[bool] = True,
    resample_latent_likelihoods: Optional[bool] = False,
    resample_interpolation_mode: Optional[str] = "bilinear",
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
    """
    Computes the likelihoods of the latent representations of the input.

    The inputs are encoded with ``autoencoder_model.encode_stage_2_inputs``, multiplied by
    ``self.scale_factor`` and handed to the parent class' ``get_likelihood``.

    Args:
        inputs: input images, NxCxHxW[xD]
        autoencoder_model: first stage model.
        diffusion_model: model to compute likelihood from
        scheduler: diffusion scheduler. If none provided will use the class attribute scheduler
        save_intermediates: save the intermediate spatial KL maps
        conditioning: Conditioning for network input.
        original_input_range: the [min,max] intensity range of the input data before any scaling was applied.
        scaled_input_range: the [min,max] intensity range of the input data after scaling.
        verbose: if true, shows a progress bar over the timesteps (when tqdm is available).
        resample_latent_likelihoods: if true, resamples the intermediate likelihood maps to have the same spatial
            dimension as the input images. Only has an effect together with ``save_intermediates``.
        resample_interpolation_mode: if use resample_latent_likelihoods, select interpolation 'nearest' or 'bilinear'

    Returns:
        The output of the parent ``get_likelihood``; when intermediates are saved and resampling is
        requested, the intermediate maps are resized to ``inputs.shape[2:]``.

    Raises:
        ValueError: if resampling is requested with an interpolation mode other than 'nearest' or 'bilinear'.
    """
    # integer interpolation codes handed to torchvision's Resize
    interpolation_modes = {"nearest": 0, "bilinear": 2}
    resample = save_intermediates and resample_latent_likelihoods
    # validate up front so a bad argument fails before the full diffusion pass is run
    if resample and resample_interpolation_mode not in interpolation_modes:
        raise ValueError(
            f"resample_interpolation mode should be either nearest or bilinear, not {resample_interpolation_mode}"
        )

    latents = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor
    outputs = super().get_likelihood(
        inputs=latents,
        diffusion_model=diffusion_model,
        scheduler=scheduler,
        save_intermediates=save_intermediates,
        conditioning=conditioning,
        # forward the intensity ranges so that user-provided values are no longer silently ignored
        original_input_range=original_input_range,
        scaled_input_range=scaled_input_range,
        verbose=verbose,
    )
    if resample:
        from torchvision.transforms import Resize

        # NOTE(review): presumably only valid for 2D images - verify behaviour for NxCxHxWxD inputs.
        resizer = Resize(size=inputs.shape[2:], interpolation=interpolation_modes[resample_interpolation_mode])
        outputs = (outputs[0], [resizer(x) for x in outputs[1]])
    return outputs

generative/networks/schedulers/ddpm.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def __init__(
9090
self.clip_sample = clip_sample
9191
self.variance_type = variance_type
9292

93-
# setable values
93+
# settable values
9494
self.num_inference_steps = None
9595
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
9696

@@ -109,9 +109,34 @@ def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, to
109109
].copy()
110110
self.timesteps = torch.from_numpy(timesteps).to(device)
111111

112+
def _get_mean(self, timestep: int, x_0: torch.Tensor, x_t: torch.Tensor) -> torch.Tensor:
    """
    Compute the mean of the posterior q(x_{t-1}|x_t, x_0) at timestep t,
    see formulas (5-7) from https://arxiv.org/pdf/2006.11239.pdf.

    Args:
        timestep: current timestep.
        x_0: the noise-free input.
        x_t: the input noised to timestep t.

    Returns:
        The posterior mean, a weighted sum of ``x_0`` and ``x_t``.
    """
    cumulative_alpha = self.alphas_cumprod[timestep]
    # at the first timestep there is no previous cumulative product, so fall back to one
    if timestep > 0:
        cumulative_alpha_prev = self.alphas_cumprod[timestep - 1]
    else:
        cumulative_alpha_prev = self.one

    # both weights share the denominator (1 - alpha_bar_t)
    weight_x_0 = cumulative_alpha_prev.sqrt() * self.betas[timestep] / (1 - cumulative_alpha)
    weight_x_t = self.alphas[timestep].sqrt() * (1 - cumulative_alpha_prev) / (1 - cumulative_alpha)

    return weight_x_0 * x_0 + weight_x_t * x_t
136+
112137
def _get_variance(self, timestep: int, predicted_variance: Optional[torch.Tensor] = None) -> torch.Tensor:
113138
"""
114-
Compute the variance.
139+
Compute the variance of the posterior at timestep t.
115140
116141
Args:
117142
timestep: current timestep.
@@ -127,7 +152,6 @@ def _get_variance(self, timestep: int, predicted_variance: Optional[torch.Tensor
127152
# and sample from it to get previous sample
128153
# x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
129154
variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[timestep]
130-
131155
# hacks - were probably added for training stability
132156
if self.variance_type == "fixed_small":
133157
variance = torch.clamp(variance, min=1e-20)

tests/test_diffusion_inferer.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,37 @@ def test_sampler_conditioned(self, model_params, input_shape):
141141
)
142142
self.assertEqual(len(intermediates), 10)
143143

144+
@parameterized.expand(TEST_CASES)
def test_get_likelihood(self, model_params, input_shape):
    """Check the shapes returned by ``DiffusionInferer.get_likelihood`` with intermediates saved."""
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    model = DiffusionModelUNet(**model_params)
    model.to(device)
    model.eval()
    images = torch.randn(input_shape).to(device)

    scheduler = DDPMScheduler(num_train_timesteps=10)
    inferer = DiffusionInferer(scheduler=scheduler)
    scheduler.set_timesteps(num_inference_steps=10)

    result = inferer.get_likelihood(
        inputs=images, diffusion_model=model, scheduler=scheduler, save_intermediates=True
    )
    likelihood, intermediates = result
    # each intermediate map is element-wise, the likelihood has one entry per batch item
    self.assertEqual(intermediates[0].shape, images.shape)
    self.assertEqual(likelihood.shape[0], images.shape[0])
161+
162+
def test_normal_cdf(self):
    """Compare the tanh-based CDF approximation of the inferer against scipy's normal CDF."""
    from scipy.stats import norm

    inferer = DiffusionInferer(scheduler=DDPMScheduler(num_train_timesteps=10))

    grid = torch.linspace(-10, 10, 20)
    approximation = inferer._approx_standard_normal_cdf(grid)
    reference = norm.cdf(grid)
    torch.testing.assert_allclose(approximation, reference, atol=1e-3, rtol=1e-5)
174+
144175

145176
if __name__ == "__main__":
146177
unittest.main()

0 commit comments

Comments
 (0)