diff --git a/ctgan/synthesizers/ctgan.py b/ctgan/synthesizers/ctgan.py index 5fdbc269..bbcac8f3 100644 --- a/ctgan/synthesizers/ctgan.py +++ b/ctgan/synthesizers/ctgan.py @@ -39,7 +39,7 @@ def calc_gradient_penalty(self, real_data, fake_data, device='cpu', pac=10, lamb interpolates = alpha * real_data + ((1 - alpha) * fake_data) disc_interpolates = self(interpolates) - + self.set_device(device) gradients = torch.autograd.grad( outputs=disc_interpolates, inputs=interpolates,