diff --git a/lightning_examples/basic-gan/gan.py b/lightning_examples/basic-gan/gan.py index a12b6cd0c..42a85bc3f 100644 --- a/lightning_examples/basic-gan/gan.py +++ b/lightning_examples/basic-gan/gan.py @@ -88,7 +88,7 @@ def block(in_feat, out_feat, normalize=True): layers = [nn.Linear(in_feat, out_feat)] if normalize: layers.append(nn.BatchNorm1d(out_feat, 0.8)) - layers.append(nn.LeakyReLU(0.2, inplace=True)) + layers.append(nn.LeakyReLU(0.01, inplace=True)) return layers self.model = nn.Sequential( @@ -193,7 +193,7 @@ def training_step(self, batch): # log sampled images sample_imgs = self.generated_imgs[:6] grid = torchvision.utils.make_grid(sample_imgs) - self.logger.experiment.add_image("generated_images", grid, 0) + self.logger.experiment.add_image("train/generated_images", grid, self.current_epoch) # ground truth result (ie: all fake) # put on GPU because we created this tensor inside training_loop @@ -201,7 +201,7 @@ def training_step(self, batch): valid = valid.type_as(imgs) # adversarial loss is binary cross-entropy - g_loss = self.adversarial_loss(self.discriminator(self(z)), valid) + g_loss = self.adversarial_loss(self.discriminator(self.generated_imgs), valid) self.log("g_loss", g_loss, prog_bar=True) self.manual_backward(g_loss) optimizer_g.step() @@ -222,7 +222,7 @@ def training_step(self, batch): fake = torch.zeros(imgs.size(0), 1) fake = fake.type_as(imgs) - fake_loss = self.adversarial_loss(self.discriminator(self(z).detach()), fake) + fake_loss = self.adversarial_loss(self.discriminator(self.generated_imgs.detach()), fake) # discriminator loss is the average of these d_loss = (real_loss + fake_loss) / 2 @@ -232,6 +232,9 @@ def training_step(self, batch): optimizer_d.zero_grad() self.untoggle_optimizer(optimizer_d) + def validation_step(self, batch, batch_idx): + pass + def configure_optimizers(self): lr = self.hparams.lr b1 = self.hparams.b1 @@ -247,7 +250,7 @@ def on_validation_epoch_end(self): # log sampled images sample_imgs = self(z) grid = torchvision.utils.make_grid(sample_imgs) - self.logger.experiment.add_image("generated_images", grid, self.current_epoch) + self.logger.experiment.add_image("validation/generated_images", grid, self.current_epoch) # %% @@ -263,4 +266,4 @@ def on_validation_epoch_end(self): # %% # Start tensorboard. # %load_ext tensorboard -# %tensorboard --logdir lightning_logs/ +# %tensorboard --logdir lightning_logs/ --samples_per_plugin=images=60