Skip to content
Merged
17 changes: 10 additions & 7 deletions monai/inferers/inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1202,15 +1202,16 @@ def sample( # type: ignore[override]

if self.autoencoder_latent_shape is not None:
latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0)
latent_intermediates = [
torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates
]
if save_intermediates:
latent_intermediates = [
torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0)
for l in latent_intermediates
]

decode = autoencoder_model.decode_stage_2_outputs
if isinstance(autoencoder_model, SPADEAutoencoderKL):
decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg)
image = decode(latent / self.scale_factor)

if save_intermediates:
intermediates = []
for latent_intermediate in latent_intermediates:
Expand Down Expand Up @@ -1727,9 +1728,11 @@ def sample( # type: ignore[override]

if self.autoencoder_latent_shape is not None:
latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0)
latent_intermediates = [
torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates
]
if save_intermediates:
latent_intermediates = [
torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0)
for l in latent_intermediates
]

decode = autoencoder_model.decode_stage_2_outputs
if isinstance(autoencoder_model, SPADEAutoencoderKL):
Expand Down
82 changes: 80 additions & 2 deletions tests/inferers/test_controlnet_inferers.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,7 +722,7 @@ def test_prediction_shape(

@parameterized.expand(LATENT_CNDM_TEST_CASES)
@skipUnless(has_einops, "Requires einops")
def test_sample_shape(
def test_pred_shape(
self,
ae_model_type,
autoencoder_params,
Expand Down Expand Up @@ -1165,7 +1165,7 @@ def test_sample_shape_conditioned_concat(

@parameterized.expand(LATENT_CNDM_TEST_CASES_DIFF_SHAPES)
@skipUnless(has_einops, "Requires einops")
def test_sample_shape_different_latents(
def test_shape_different_latents(
self,
ae_model_type,
autoencoder_params,
Expand Down Expand Up @@ -1242,6 +1242,84 @@ def test_sample_shape_different_latents(
)
self.assertEqual(prediction.shape, latent_shape)

@parameterized.expand(LATENT_CNDM_TEST_CASES_DIFF_SHAPES)
@skipUnless(has_einops, "Requires einops")
def test_sample_shape_different_latents(
self,
ae_model_type,
autoencoder_params,
dm_model_type,
stage_2_params,
controlnet_params,
input_shape,
latent_shape,
):
stage_1 = None

if ae_model_type == "AutoencoderKL":
stage_1 = AutoencoderKL(**autoencoder_params)
if ae_model_type == "VQVAE":
stage_1 = VQVAE(**autoencoder_params)
if ae_model_type == "SPADEAutoencoderKL":
stage_1 = SPADEAutoencoderKL(**autoencoder_params)
if dm_model_type == "SPADEDiffusionModelUNet":
stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
else:
stage_2 = DiffusionModelUNet(**stage_2_params)
controlnet = ControlNet(**controlnet_params)

device = "cuda:0" if torch.cuda.is_available() else "cpu"
stage_1.to(device)
stage_2.to(device)
controlnet.to(device)
stage_1.eval()
stage_2.eval()
controlnet.eval()

noise = torch.randn(latent_shape).to(device)
mask = torch.randn(input_shape).to(device)
scheduler = DDPMScheduler(num_train_timesteps=10)
# We infer the VAE shape
if ae_model_type == "VQVAE":
autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]))) for i in input_shape[2:]]
else:
autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]) - 1)) for i in input_shape[2:]]

inferer = ControlNetLatentDiffusionInferer(
scheduler=scheduler,
scale_factor=1.0,
ldm_latent_shape=list(latent_shape[2:]),
autoencoder_latent_shape=autoencoder_latent_shape,
)
scheduler.set_timesteps(num_inference_steps=10)

if dm_model_type == "SPADEDiffusionModelUNet" or ae_model_type == "SPADEAutoencoderKL":
input_shape_seg = list(input_shape)
if "label_nc" in stage_2_params.keys():
input_shape_seg[1] = stage_2_params["label_nc"]
else:
input_shape_seg[1] = autoencoder_params["label_nc"]
input_seg = torch.randn(input_shape_seg).to(device)
prediction, _ = inferer.sample(
autoencoder_model=stage_1,
diffusion_model=stage_2,
controlnet=controlnet,
cn_cond=mask,
input_noise=noise,
seg=input_seg,
save_intermediates=True,
)
else:
prediction = inferer.sample(
autoencoder_model=stage_1,
diffusion_model=stage_2,
input_noise=noise,
controlnet=controlnet,
cn_cond=mask,
save_intermediates=False,
)
self.assertEqual(prediction.shape, input_shape)

@skipUnless(has_einops, "Requires einops")
def test_incompatible_spade_setup(self):
stage_1 = SPADEAutoencoderKL(
Expand Down
62 changes: 61 additions & 1 deletion tests/inferers/test_latent_diffusion_inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,7 +714,7 @@ def test_sample_shape_conditioned_concat(

@parameterized.expand(TEST_CASES_DIFF_SHAPES)
@skipUnless(has_einops, "Requires einops")
def test_sample_shape_different_latents(
def test_shape_different_latents(
self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape
):
stage_1 = None
Expand Down Expand Up @@ -772,6 +772,66 @@ def test_sample_shape_different_latents(
)
self.assertEqual(prediction.shape, latent_shape)

@parameterized.expand(TEST_CASES_DIFF_SHAPES)
@skipUnless(has_einops, "Requires einops")
def test_sample_shape_different_latents(
self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape
):
stage_1 = None

if ae_model_type == "AutoencoderKL":
stage_1 = AutoencoderKL(**autoencoder_params)
if ae_model_type == "VQVAE":
stage_1 = VQVAE(**autoencoder_params)
if ae_model_type == "SPADEAutoencoderKL":
stage_1 = SPADEAutoencoderKL(**autoencoder_params)
if dm_model_type == "SPADEDiffusionModelUNet":
stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
else:
stage_2 = DiffusionModelUNet(**stage_2_params)

device = "cuda:0" if torch.cuda.is_available() else "cpu"
stage_1.to(device)
stage_2.to(device)
stage_1.eval()
stage_2.eval()

noise = torch.randn(latent_shape).to(device)
scheduler = DDPMScheduler(num_train_timesteps=10)
# We infer the VAE shape
if ae_model_type == "VQVAE":
autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]))) for i in input_shape[2:]]
else:
autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]) - 1)) for i in input_shape[2:]]

inferer = LatentDiffusionInferer(
scheduler=scheduler,
scale_factor=1.0,
ldm_latent_shape=list(latent_shape[2:]),
autoencoder_latent_shape=autoencoder_latent_shape,
)
scheduler.set_timesteps(num_inference_steps=10)

if dm_model_type == "SPADEDiffusionModelUNet" or ae_model_type == "SPADEAutoencoderKL":
input_shape_seg = list(input_shape)
if "label_nc" in stage_2_params.keys():
input_shape_seg[1] = stage_2_params["label_nc"]
else:
input_shape_seg[1] = autoencoder_params["label_nc"]
input_seg = torch.randn(input_shape_seg).to(device)
prediction, _ = inferer.sample(
autoencoder_model=stage_1,
diffusion_model=stage_2,
input_noise=noise,
save_intermediates=True,
seg=input_seg,
)
else:
prediction = inferer.sample(
autoencoder_model=stage_1, diffusion_model=stage_2, input_noise=noise, save_intermediates=False
)
self.assertEqual(prediction.shape, input_shape)

@skipUnless(has_einops, "Requires einops")
def test_incompatible_spade_setup(self):
stage_1 = SPADEAutoencoderKL(
Expand Down
Loading