From 9a4bb89f505a24ebcae1342ab05f3399c42507b3 Mon Sep 17 00:00:00 2001 From: Felipe Date: Mon, 29 Jul 2024 12:15:57 -0300 Subject: [PATCH] . --- ctgan/synthesizers/ctgan.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/ctgan/synthesizers/ctgan.py b/ctgan/synthesizers/ctgan.py index bbcac8f3..2915d442 100644 --- a/ctgan/synthesizers/ctgan.py +++ b/ctgan/synthesizers/ctgan.py @@ -35,19 +35,25 @@ def calc_gradient_penalty(self, real_data, fake_data, device='cpu', pac=10, lamb alpha = torch.rand(real_data.size(0) // pac, 1, 1, device=device) alpha = alpha.repeat(1, pac, real_data.size(1)) alpha = alpha.view(-1, real_data.size(1)) + print('alpha: ', alpha.device) interpolates = alpha * real_data + ((1 - alpha) * fake_data) + print('interpolates: ', interpolates.device) disc_interpolates = self(interpolates) - self.set_device(device) + print('disc_interpolates: ', disc_interpolates.device) + a = torch.ones(disc_interpolates.size(), device=device) + print('a: ', a.device) + gradients = torch.autograd.grad( outputs=disc_interpolates, inputs=interpolates, - grad_outputs=torch.ones(disc_interpolates.size(), device=device), + grad_outputs=a, create_graph=True, retain_graph=True, only_inputs=True, )[0] + print('gradients: ', gradients.device) gradients_view = gradients.view(-1, pac * real_data.size(1)).norm(2, dim=1) - 1 gradient_penalty = ((gradients_view) ** 2).mean() * lambda_