Skip to content

Commit

Permalink
allow for disabling memory efficient way of keeping only one unet in gpu when sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Mar 8, 2023
1 parent 989d4df commit 726c11a
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 8 deletions.
8 changes: 6 additions & 2 deletions imagen_pytorch/elucidated_imagen.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,8 +284,11 @@ def one_unet_in_gpu(self, unet_number = None, unet = None):
if exists(unet_number):
unet = self.unets[unet_number - 1]

cpu = torch.device('cpu')

devices = [module_device(unet) for unet in self.unets]
self.unets.cpu()

self.unets.to(cpu)
unet.to(self.device)

yield
Expand Down Expand Up @@ -568,6 +571,7 @@ def sample(
return_all_unet_outputs = False,
return_pil_images = False,
use_tqdm = True,
use_one_unet_in_gpu = True,
device = None,
):
device = default(device, self.device)
Expand Down Expand Up @@ -649,7 +653,7 @@ def sample(

assert not isinstance(unet, NullUnet), 'cannot sample from null unet'

context = self.one_unet_in_gpu(unet = unet) if is_cuda else nullcontext()
context = self.one_unet_in_gpu(unet = unet) if is_cuda and use_one_unet_in_gpu else nullcontext()

with context:
lowres_cond_img = lowres_noise_times = None
Expand Down
10 changes: 7 additions & 3 deletions imagen_pytorch/imagen_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2012,8 +2012,11 @@ def one_unet_in_gpu(self, unet_number = None, unet = None):
if exists(unet_number):
unet = self.unets[unet_number - 1]

cpu = torch.device('cpu')

devices = [module_device(unet) for unet in self.unets]
self.unets.cpu()

self.unets.to(cpu)
unet.to(self.device)

yield
Expand Down Expand Up @@ -2305,7 +2308,8 @@ def sample(
return_all_unet_outputs = False,
return_pil_images = False,
device = None,
use_tqdm = True
use_tqdm = True,
use_one_unet_in_gpu = True
):
device = default(device, self.device)
self.reset_unets_all_one_device(device = device)
Expand Down Expand Up @@ -2389,7 +2393,7 @@ def sample(

assert not isinstance(unet, NullUnet), 'one cannot sample from null / placeholder unets'

context = self.one_unet_in_gpu(unet = unet) if is_cuda else nullcontext()
context = self.one_unet_in_gpu(unet = unet) if is_cuda and use_one_unet_in_gpu else nullcontext()

with context:
# video kwargs
Expand Down
3 changes: 1 addition & 2 deletions imagen_pytorch/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,8 +341,7 @@ def __init__(
optimizer = Lion(
unet.parameters(),
lr = unet_lr,
betas = (beta1, beta2),
use_triton = True
betas = (beta1, beta2)
)
else:
optimizer = Adam(
Expand Down
2 changes: 1 addition & 1 deletion imagen_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.22.2'
__version__ = '1.22.4'

0 comments on commit 726c11a

Please sign in to comment.