diff --git a/sagan/loss.py b/sagan/loss.py index 63f2067..7c8d5c4 100644 --- a/sagan/loss.py +++ b/sagan/loss.py @@ -46,7 +46,7 @@ def forward(self, fake_logits, mode, real_logits=None): def _generator_loss(self, fake_logits): return - fake_logits.mean() - def __discriminator_loss(self, real_logits, fake_logits): + def _discriminator_loss(self, real_logits, fake_logits): return - real_logits.mean() + fake_logits.mean() def get_interpolates(self, reals, fakes):