diff --git a/imagen_pytorch/elucidated_imagen.py b/imagen_pytorch/elucidated_imagen.py index 1a5c73c..47d4363 100644 --- a/imagen_pytorch/elucidated_imagen.py +++ b/imagen_pytorch/elucidated_imagen.py @@ -284,8 +284,11 @@ def one_unet_in_gpu(self, unet_number = None, unet = None): if exists(unet_number): unet = self.unets[unet_number - 1] + cpu = torch.device('cpu') + devices = [module_device(unet) for unet in self.unets] - self.unets.cpu() + + self.unets.to(cpu) unet.to(self.device) yield @@ -568,6 +571,7 @@ def sample( return_all_unet_outputs = False, return_pil_images = False, use_tqdm = True, + use_one_unet_in_gpu = True, device = None, ): device = default(device, self.device) @@ -649,7 +653,7 @@ def sample( assert not isinstance(unet, NullUnet), 'cannot sample from null unet' - context = self.one_unet_in_gpu(unet = unet) if is_cuda else nullcontext() + context = self.one_unet_in_gpu(unet = unet) if is_cuda and use_one_unet_in_gpu else nullcontext() with context: lowres_cond_img = lowres_noise_times = None diff --git a/imagen_pytorch/imagen_pytorch.py b/imagen_pytorch/imagen_pytorch.py index 3ae28d2..b3fa919 100644 --- a/imagen_pytorch/imagen_pytorch.py +++ b/imagen_pytorch/imagen_pytorch.py @@ -2012,8 +2012,11 @@ def one_unet_in_gpu(self, unet_number = None, unet = None): if exists(unet_number): unet = self.unets[unet_number - 1] + cpu = torch.device('cpu') + devices = [module_device(unet) for unet in self.unets] - self.unets.cpu() + + self.unets.to(cpu) unet.to(self.device) yield @@ -2305,7 +2308,8 @@ def sample( return_all_unet_outputs = False, return_pil_images = False, device = None, - use_tqdm = True + use_tqdm = True, + use_one_unet_in_gpu = True ): device = default(device, self.device) self.reset_unets_all_one_device(device = device) @@ -2389,7 +2393,7 @@ def sample( assert not isinstance(unet, NullUnet), 'one cannot sample from null / placeholder unets' - context = self.one_unet_in_gpu(unet = unet) if is_cuda else nullcontext() + context = self.one_unet_in_gpu(unet = unet) if is_cuda and use_one_unet_in_gpu else nullcontext() with context: # video kwargs diff --git a/imagen_pytorch/trainer.py b/imagen_pytorch/trainer.py index 9d9c3d0..0597330 100644 --- a/imagen_pytorch/trainer.py +++ b/imagen_pytorch/trainer.py @@ -341,8 +341,7 @@ def __init__( optimizer = Lion( unet.parameters(), lr = unet_lr, - betas = (beta1, beta2), - use_triton = True + betas = (beta1, beta2) ) else: optimizer = Adam( diff --git a/imagen_pytorch/version.py b/imagen_pytorch/version.py index 53bfe2e..c38eab6 100644 --- a/imagen_pytorch/version.py +++ b/imagen_pytorch/version.py @@ -1 +1 @@ -__version__ = '1.22.2' +__version__ = '1.22.4'