Skip to content

Commit f269d8f

Browse files
authored
minor fix in diffusion edm schedule (#560)
1 parent c1407df commit f269d8f

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

bayesflow/networks/diffusion_model/schedules/edm_noise_schedule.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def __init__(self, sigma_data: float = 1.0, sigma_min: float = 1e-4, sigma_max:
5353
def get_log_snr(self, t: float | Tensor, training: bool) -> Tensor:
5454
"""Get the log signal-to-noise ratio (lambda) for a given diffusion time."""
5555
if training:
56-
# SNR = -dist.icdf(t_trunc) # negative seems to be wrong in the Kingma paper
56+
# SNR = dist.icdf(1-t) # Kingma paper wrote -F(t) but this seems to be wrong
5757
loc = -2 * self.p_mean
5858
scale = 2 * self.p_std
5959
snr = loc + scale * ops.erfinv(2 * t - 1) * math.sqrt(2)
@@ -67,11 +67,11 @@ def get_log_snr(self, t: float | Tensor, training: bool) -> Tensor:
6767
def get_t_from_log_snr(self, log_snr_t: float | Tensor, training: bool) -> Tensor:
6868
"""Get the diffusion time (t) from the log signal-to-noise ratio (lambda)."""
6969
if training:
70-
# SNR = -dist.icdf(t_trunc) => t = dist.cdf(-snr) # negative seems to be wrong in the Kingma paper
70+
# SNR = dist.icdf(1-t) => t = 1-dist.cdf(snr) # Kingma paper wrote -F(t) but this seems to be wrong
7171
loc = -2 * self.p_mean
7272
scale = 2 * self.p_std
7373
x = log_snr_t
74-
t = 0.5 * (1 + ops.erf((x - loc) / (scale * math.sqrt(2.0))))
74+
t = 1 - 0.5 * (1 + ops.erf((x - loc) / (scale * math.sqrt(2.0))))
7575
else: # sampling
7676
# SNR = -2 * rho * log(sigma_max ** (1/rho) + (1 - t) * (sigma_min ** (1/rho) - sigma_max ** (1/rho)))
7777
# => t = 1 - ((exp(-snr/(2*rho)) - sigma_max ** (1/rho)) / (sigma_min ** (1/rho) - sigma_max ** (1/rho)))

0 commit comments

Comments
 (0)