Skip to content

Commit 3952707

Browse files
committed
🐛 fix: add lambda coff to prevent overwhelming
1 parent 516867a commit 3952707

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

megdiffusion/diffusion/gaussion.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -295,14 +295,16 @@ def _mse_loss(x_start, x_t, t):
295295
elif self.model_mean_type == "EPSILON":
296296
target = noise
297297

298+
# Note don't use F.nn.square_loss() here
298299
return mean_flat((target - model_output) ** 2)
299300

300301
if self.loss_type == "VLB":
301302
loss = _vlb_loss(x_start, x_t, t)
302303
elif self.loss_type == "SIMPLE":
303304
loss = _mse_loss(x_start, x_t, t)
304-
elif self.loss_type == "HYBRID":
305-
loss = _vlb_loss(x_start, x_t, t) + _mse_loss(x_start, x_t, t)
305+
elif self.loss_type == "HYBRID": # IDDPM Eq. (16)
306+
# set lambda = 0.001 to prevent L_{vlb} from overwhelming L_{simple}.
307+
loss = 0.001 * _vlb_loss(x_start, x_t, t) + _mse_loss(x_start, x_t, t)
306308
else:
307309
raise NotImplementedError(self.loss_type)
308310

0 commit comments

Comments
 (0)