This repository was archived by the owner on Feb 7, 2025. It is now read-only.

Commit a473b5f

virginiafdez authored
Fix bug where save_intermediates is sometimes ignored. (#465)
Co-authored-by: virginiafdez <virginia.fernandez@kcl.ac.uk>
1 parent a9b17d4 commit a473b5f

File tree: 1 file changed (+8 −6 lines)

generative/inferers/inferer.py

Lines changed: 8 additions & 6 deletions
@@ -457,9 +457,10 @@ def sample(
 
         if self.autoencoder_latent_shape is not None:
             latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0)
-            latent_intermediates = [
-                torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates
-            ]
+            if save_intermediates:
+                latent_intermediates = [
+                    torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates
+                ]
 
         decode = autoencoder_model.decode_stage_2_outputs
         if isinstance(autoencoder_model, SPADEAutoencoderKL):
@@ -991,9 +992,10 @@ def sample(
 
         if self.autoencoder_latent_shape is not None:
             latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0)
-            latent_intermediates = [
-                torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates
-            ]
+            if save_intermediates:
+                latent_intermediates = [
+                    torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates
+                ]
 
         decode = autoencoder_model.decode_stage_2_outputs
         if isinstance(autoencoder_model, SPADEAutoencoderKL):
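
For context, the change simply wraps the resizing of latent_intermediates in a save_intermediates guard, since those intermediates are only collected when the caller asked for them. The sketch below shows the same pattern in a minimal, self-contained form; toy_resizer and resize_latents are illustrative stand-ins (plain batch iteration instead of MONAI's decollate_batch), not the library's actual API.

# Minimal sketch of the guarded-resize pattern this commit introduces.
# toy_resizer and resize_latents are hypothetical names for illustration only.
import torch
import torch.nn.functional as F

def toy_resizer(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for self.autoencoder_resizer: resize one latent (C, H, W) to a fixed spatial shape.
    return F.interpolate(x.unsqueeze(0), size=(16, 16), mode="nearest").squeeze(0)

def resize_latents(latent, latent_intermediates=None, save_intermediates=False):
    # The final latent is always resized.
    latent = torch.stack([toy_resizer(i) for i in latent], 0)
    # Before the fix, the comprehension below ran unconditionally, even though
    # intermediates are only collected when save_intermediates is True; now it is guarded.
    if save_intermediates:
        latent_intermediates = [
            torch.stack([toy_resizer(i) for i in step], 0) for step in latent_intermediates
        ]
        return latent, latent_intermediates
    return latent

# Usage: without intermediates only the final latent comes back; with them, both do.
final = resize_latents(torch.randn(2, 3, 8, 8))
final, steps = resize_latents(
    torch.randn(2, 3, 8, 8),
    latent_intermediates=[torch.randn(2, 3, 8, 8) for _ in range(4)],
    save_intermediates=True,
)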
