Skip to content

Commit

Permalink
reduce memory usage in sample image generation
Browse files Browse the repository at this point in the history
  • Loading branch information
kohya-ss committed Oct 24, 2024
1 parent 623017f commit e3c43bd
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions library/sd3_train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,9 +402,6 @@ def sample_images(
except Exception:
pass

org_vae_device = vae.device # will be on cpu
vae.to(distributed_state.device) # distributed_state.device is same as accelerator.device

if distributed_state.num_processes <= 1:
# If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
with torch.no_grad():
Expand Down Expand Up @@ -450,8 +447,6 @@ def sample_images(
if cuda_rng_state is not None:
torch.cuda.set_rng_state(cuda_rng_state)

vae.to(org_vae_device)

clean_memory_on_device(accelerator.device)


Expand Down Expand Up @@ -531,12 +526,19 @@ def sample_image_inference(
neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)

# sample image
latents = do_sample(height, width, seed, cond, neg_cond, mmdit, sample_steps, scale, mmdit.dtype, accelerator.device)
latents = vae.process_out(latents.to(vae.device, dtype=vae.dtype))
clean_memory_on_device(accelerator.device)
with accelerator.autocast():
latents = do_sample(height, width, seed, cond, neg_cond, mmdit, sample_steps, scale, mmdit.dtype, accelerator.device)

# latent to image
with torch.no_grad():
image = vae.decode(latents)
clean_memory_on_device(accelerator.device)
org_vae_device = vae.device # will be on cpu
vae.to(accelerator.device)
latents = vae.process_out(latents.to(vae.device, dtype=vae.dtype))
image = vae.decode(latents)
vae.to(org_vae_device)
clean_memory_on_device(accelerator.device)

image = image.float()
image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0]
decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2)
Expand Down

0 comments on commit e3c43bd

Please sign in to comment.