16 changes: 10 additions & 6 deletions monai/inferers/inferer.py
@@ -1202,9 +1202,11 @@ def sample(  # type: ignore[override]
 
         if self.autoencoder_latent_shape is not None:
             latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0)
-            latent_intermediates = [
-                torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates
-            ]
+            if save_intermediates:
+                latent_intermediates = [
+                    torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0)
+                    for l in latent_intermediates
+                ]
 
         decode = autoencoder_model.decode_stage_2_outputs
         if isinstance(autoencoder_model, SPADEAutoencoderKL):
@@ -1727,9 +1729,11 @@ def sample(  # type: ignore[override]
 
         if self.autoencoder_latent_shape is not None:
             latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0)
-            latent_intermediates = [
-                torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates
-            ]
+            if save_intermediates:
+                latent_intermediates = [
+                    torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0)
+                    for l in latent_intermediates
+                ]
 
         decode = autoencoder_model.decode_stage_2_outputs
         if isinstance(autoencoder_model, SPADEAutoencoderKL):
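Both hunks guard the resizing of latent_intermediates behind save_intermediates: the diffusion sampling step only returns intermediate latents when that flag is set, so with save_intermediates=False the name is never bound and resizing it would fail. A minimal usage sketch of both paths, assuming a LatentDiffusionInferer configured with an autoencoder_latent_shape (the names inferer, autoencoder, unet, scheduler and noise are illustrative, not from this PR):

# Sampling without intermediates; before this fix, this path could touch an
# undefined latent_intermediates whenever latent resizing was enabled.
images = inferer.sample(
    input_noise=noise,
    autoencoder_model=autoencoder,
    diffusion_model=unet,
    scheduler=scheduler,
    save_intermediates=False,
)

# Sampling with intermediates; these are resized and decoded as before.
images, intermediates = inferer.sample(
    input_noise=noise,
    autoencoder_model=autoencoder,
    diffusion_model=unet,
    scheduler=scheduler,
    save_intermediates=True,
)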