This repository was archived by the owner on Feb 7, 2025. It is now read-only.

Commit 56089b5

Add tutorial about Diff-SCM (#306)
* initial commit: anomaly detection with gradient guidance (a sketch of the guidance idea follows below)
* BraTS 2D healthy/unhealthy loader
* first reversed loop for DDIM is implemented; classifier guidance is on the way
* anomaly detection tutorial is complete; the training still needs to be checked
* cleaned up the classification network for gradient guidance
* cleaning up
* DDIM clean-up
* Add tutorial
* remove load_2d
* take out encoder for now
* Fixing Walter's comments, changes in data transform
* DDIM and UNet changes

---------

Co-authored-by: Julia <julia.wolleb@unibas.ch>
1 parent fd74813 commit 56089b5
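The gradient/classifier guidance mentioned in the commit message follows the classifier-guidance idea of Dhariwal & Nichol (2021), which the added code further down also references (Appendix F of https://arxiv.org/pdf/2105.05233.pdf). Below is a minimal sketch of that step, not code from this commit: the helper name guided_epsilon, the classifier call signature, and guidance_scale are illustrative assumptions.

import torch


def guided_epsilon(model_output, sample, timestep, classifier, target_class, guidance_scale, alpha_prod_t):
    # Illustrative assumption: classifier(x, t) returns class logits for the noisy sample x_t.
    # Classifier guidance shifts the predicted noise by the gradient of the log-probability
    # of the target class with respect to x_t.
    with torch.enable_grad():
        x = sample.detach().requires_grad_(True)
        log_probs = torch.log_softmax(classifier(x, timestep), dim=-1)
        selected = log_probs[torch.arange(len(x), device=x.device), target_class].sum()
        grad = torch.autograd.grad(selected, x)[0]
    # eps_hat = eps - s * sqrt(1 - alpha_bar_t) * grad_x log p(y | x_t)
    return model_output - guidance_scale * (1 - alpha_prod_t) ** 0.5 * grad

During guided sampling, the adjusted noise would stand in for model_output in the scheduler's step/reversed_step calls.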

File tree

3 files changed, +1424 −0 lines

generative/networks/schedulers/ddim.py

Lines changed: 60 additions & 0 deletions
@@ -225,6 +225,66 @@ def step(
         return pred_prev_sample, pred_original_sample
 
+    def reversed_step(
+        self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Predict the sample at the next timestep by reversing the SDE. Core function to propagate the diffusion
+        process from the learned model outputs (most often the predicted noise).
+
+        Args:
+            model_output: direct output from learned diffusion model.
+            timestep: current discrete timestep in the diffusion chain.
+            sample: current instance of sample being created by diffusion process.
+
+        Returns:
+            pred_prev_sample: Predicted previous sample
+            pred_original_sample: Predicted original sample
+        """
+        # See Appendix F at https://arxiv.org/pdf/2105.05233.pdf, or Equation (6) in https://arxiv.org/pdf/2203.04306.pdf
+
+        # Notation (<variable name> -> <name in paper>)
+        # - model_output -> e_theta(x_t, t)
+        # - pred_original_sample -> f_theta(x_t, t) or x_0
+        # - std_dev_t -> sigma_t
+        # - eta -> η
+        # - pred_sample_direction -> "direction pointing to x_t"
+        # - pred_post_sample -> "x_t+1"
+
+        # 1. get previous step value (=t+1)
+        prev_timestep = timestep + self.num_train_timesteps // self.num_inference_steps
+
+        # 2. compute alphas, betas at timestep t+1
+        alpha_prod_t = self.alphas_cumprod[timestep]
+        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+
+        beta_prod_t = 1 - alpha_prod_t
+
+        # 3. compute predicted original sample from predicted noise, also called
+        # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+
+        if self.prediction_type == "epsilon":
+            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+            pred_epsilon = model_output
+        elif self.prediction_type == "sample":
+            pred_original_sample = model_output
+            pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
+        elif self.prediction_type == "v_prediction":
+            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
+            pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+
+        # 4. Clip "predicted x_0"
+        if self.clip_sample:
+            pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
+
+        # 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+        pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * pred_epsilon
+
+        # 6. compute x_t+1 without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+        pred_post_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
+
+        return pred_post_sample, pred_original_sample
+
     def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
         """
         Add noise to the original samples.
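For orientation, here is a minimal sketch of how the new reversed_step could be paired with the existing step to encode an image towards noise and decode it back, i.e. the DDIM encode/decode loop that the Diff-SCM tutorial builds on. The DDIMScheduler import and construction, the set_timesteps argument, and the model(latent, timesteps=...) call signature are assumptions about the surrounding MONAI Generative API rather than code from this commit.

import torch

from generative.networks.schedulers import DDIMScheduler

scheduler = DDIMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(num_inference_steps=100)


@torch.no_grad()
def encode_then_decode(model, image):
    device = image.device
    latent = image
    # Encode: deterministic DDIM inversion, stepping x_t -> x_{t+1} towards noise.
    for t in scheduler.timesteps.flip(0):  # ascending timesteps
        t = int(t)
        noise_pred = model(latent, timesteps=torch.as_tensor([t], device=device))
        latent, _ = scheduler.reversed_step(noise_pred, t, latent)
    # Decode: the usual DDIM sampling loop, stepping x_{t+1} -> x_t back to image space.
    for t in scheduler.timesteps:  # descending timesteps
        t = int(t)
        noise_pred = model(latent, timesteps=torch.as_tensor([t], device=device))
        latent, _ = scheduler.step(noise_pred, t, latent)
    return latent

In the guided variant used for anomaly detection, noise_pred in the decoding loop would first be adjusted with the classifier gradient (as sketched after the commit message) before calling step.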
