Skip to content

Commit 749693b

Browse files
virginiafdez (Virginia Fernandez) and KumoLiu
authored and committed
Inferer modification - save_intermediates clashes with latent shape adjustment in latent diffusion inferers (Project-MONAI#8343)
Fixes Project-MONAI#8334 ### Description There was an if save_intermediates missing in the code that was trying to run crop of the latent spaces on the sample function of the Latent Diffusion Inferers (normal one and ControlNet one) even when intermediates aren't created. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). --------- Signed-off-by: Virginia Fernandez <virginia.fernandez@kcl.ac.uk> Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: Virginia Fernandez <virginia.fernandez@kcl.ac.uk> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Can-Zhao <volcanofly@gmail.com>
1 parent e8b500b commit 749693b

File tree

3 files changed

+151
-10
lines changed

3 files changed

+151
-10
lines changed

monai/inferers/inferer.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1213,15 +1213,16 @@ def sample( # type: ignore[override]
12131213

12141214
if self.autoencoder_latent_shape is not None:
12151215
latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0)
1216-
latent_intermediates = [
1217-
torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates
1218-
]
1216+
if save_intermediates:
1217+
latent_intermediates = [
1218+
torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0)
1219+
for l in latent_intermediates
1220+
]
12191221

12201222
decode = autoencoder_model.decode_stage_2_outputs
12211223
if isinstance(autoencoder_model, SPADEAutoencoderKL):
12221224
decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg)
12231225
image = decode(latent / self.scale_factor)
1224-
12251226
if save_intermediates:
12261227
intermediates = []
12271228
for latent_intermediate in latent_intermediates:
@@ -1738,9 +1739,11 @@ def sample( # type: ignore[override]
17381739

17391740
if self.autoencoder_latent_shape is not None:
17401741
latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0)
1741-
latent_intermediates = [
1742-
torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates
1743-
]
1742+
if save_intermediates:
1743+
latent_intermediates = [
1744+
torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0)
1745+
for l in latent_intermediates
1746+
]
17441747

17451748
decode = autoencoder_model.decode_stage_2_outputs
17461749
if isinstance(autoencoder_model, SPADEAutoencoderKL):

tests/inferers/test_controlnet_inferers.py

Lines changed: 80 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -722,7 +722,7 @@ def test_prediction_shape(
722722

723723
@parameterized.expand(LATENT_CNDM_TEST_CASES)
724724
@skipUnless(has_einops, "Requires einops")
725-
def test_sample_shape(
725+
def test_pred_shape(
726726
self,
727727
ae_model_type,
728728
autoencoder_params,
@@ -1165,7 +1165,7 @@ def test_sample_shape_conditioned_concat(
11651165

11661166
@parameterized.expand(LATENT_CNDM_TEST_CASES_DIFF_SHAPES)
11671167
@skipUnless(has_einops, "Requires einops")
1168-
def test_sample_shape_different_latents(
1168+
def test_shape_different_latents(
11691169
self,
11701170
ae_model_type,
11711171
autoencoder_params,
@@ -1242,6 +1242,84 @@ def test_sample_shape_different_latents(
12421242
)
12431243
self.assertEqual(prediction.shape, latent_shape)
12441244

1245+
@parameterized.expand(LATENT_CNDM_TEST_CASES_DIFF_SHAPES)
1246+
@skipUnless(has_einops, "Requires einops")
1247+
def test_sample_shape_different_latents(
1248+
self,
1249+
ae_model_type,
1250+
autoencoder_params,
1251+
dm_model_type,
1252+
stage_2_params,
1253+
controlnet_params,
1254+
input_shape,
1255+
latent_shape,
1256+
):
1257+
stage_1 = None
1258+
1259+
if ae_model_type == "AutoencoderKL":
1260+
stage_1 = AutoencoderKL(**autoencoder_params)
1261+
if ae_model_type == "VQVAE":
1262+
stage_1 = VQVAE(**autoencoder_params)
1263+
if ae_model_type == "SPADEAutoencoderKL":
1264+
stage_1 = SPADEAutoencoderKL(**autoencoder_params)
1265+
if dm_model_type == "SPADEDiffusionModelUNet":
1266+
stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
1267+
else:
1268+
stage_2 = DiffusionModelUNet(**stage_2_params)
1269+
controlnet = ControlNet(**controlnet_params)
1270+
1271+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
1272+
stage_1.to(device)
1273+
stage_2.to(device)
1274+
controlnet.to(device)
1275+
stage_1.eval()
1276+
stage_2.eval()
1277+
controlnet.eval()
1278+
1279+
noise = torch.randn(latent_shape).to(device)
1280+
mask = torch.randn(input_shape).to(device)
1281+
scheduler = DDPMScheduler(num_train_timesteps=10)
1282+
# We infer the VAE shape
1283+
if ae_model_type == "VQVAE":
1284+
autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]))) for i in input_shape[2:]]
1285+
else:
1286+
autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]) - 1)) for i in input_shape[2:]]
1287+
1288+
inferer = ControlNetLatentDiffusionInferer(
1289+
scheduler=scheduler,
1290+
scale_factor=1.0,
1291+
ldm_latent_shape=list(latent_shape[2:]),
1292+
autoencoder_latent_shape=autoencoder_latent_shape,
1293+
)
1294+
scheduler.set_timesteps(num_inference_steps=10)
1295+
1296+
if dm_model_type == "SPADEDiffusionModelUNet" or ae_model_type == "SPADEAutoencoderKL":
1297+
input_shape_seg = list(input_shape)
1298+
if "label_nc" in stage_2_params.keys():
1299+
input_shape_seg[1] = stage_2_params["label_nc"]
1300+
else:
1301+
input_shape_seg[1] = autoencoder_params["label_nc"]
1302+
input_seg = torch.randn(input_shape_seg).to(device)
1303+
prediction, _ = inferer.sample(
1304+
autoencoder_model=stage_1,
1305+
diffusion_model=stage_2,
1306+
controlnet=controlnet,
1307+
cn_cond=mask,
1308+
input_noise=noise,
1309+
seg=input_seg,
1310+
save_intermediates=True,
1311+
)
1312+
else:
1313+
prediction = inferer.sample(
1314+
autoencoder_model=stage_1,
1315+
diffusion_model=stage_2,
1316+
input_noise=noise,
1317+
controlnet=controlnet,
1318+
cn_cond=mask,
1319+
save_intermediates=False,
1320+
)
1321+
self.assertEqual(prediction.shape, input_shape)
1322+
12451323
@skipUnless(has_einops, "Requires einops")
12461324
def test_incompatible_spade_setup(self):
12471325
stage_1 = SPADEAutoencoderKL(

tests/inferers/test_latent_diffusion_inferer.py

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -714,7 +714,7 @@ def test_sample_shape_conditioned_concat(
714714

715715
@parameterized.expand(TEST_CASES_DIFF_SHAPES)
716716
@skipUnless(has_einops, "Requires einops")
717-
def test_sample_shape_different_latents(
717+
def test_shape_different_latents(
718718
self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape
719719
):
720720
stage_1 = None
@@ -772,6 +772,66 @@ def test_sample_shape_different_latents(
772772
)
773773
self.assertEqual(prediction.shape, latent_shape)
774774

775+
@parameterized.expand(TEST_CASES_DIFF_SHAPES)
776+
@skipUnless(has_einops, "Requires einops")
777+
def test_sample_shape_different_latents(
778+
self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape
779+
):
780+
stage_1 = None
781+
782+
if ae_model_type == "AutoencoderKL":
783+
stage_1 = AutoencoderKL(**autoencoder_params)
784+
if ae_model_type == "VQVAE":
785+
stage_1 = VQVAE(**autoencoder_params)
786+
if ae_model_type == "SPADEAutoencoderKL":
787+
stage_1 = SPADEAutoencoderKL(**autoencoder_params)
788+
if dm_model_type == "SPADEDiffusionModelUNet":
789+
stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
790+
else:
791+
stage_2 = DiffusionModelUNet(**stage_2_params)
792+
793+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
794+
stage_1.to(device)
795+
stage_2.to(device)
796+
stage_1.eval()
797+
stage_2.eval()
798+
799+
noise = torch.randn(latent_shape).to(device)
800+
scheduler = DDPMScheduler(num_train_timesteps=10)
801+
# We infer the VAE shape
802+
if ae_model_type == "VQVAE":
803+
autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]))) for i in input_shape[2:]]
804+
else:
805+
autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]) - 1)) for i in input_shape[2:]]
806+
807+
inferer = LatentDiffusionInferer(
808+
scheduler=scheduler,
809+
scale_factor=1.0,
810+
ldm_latent_shape=list(latent_shape[2:]),
811+
autoencoder_latent_shape=autoencoder_latent_shape,
812+
)
813+
scheduler.set_timesteps(num_inference_steps=10)
814+
815+
if dm_model_type == "SPADEDiffusionModelUNet" or ae_model_type == "SPADEAutoencoderKL":
816+
input_shape_seg = list(input_shape)
817+
if "label_nc" in stage_2_params.keys():
818+
input_shape_seg[1] = stage_2_params["label_nc"]
819+
else:
820+
input_shape_seg[1] = autoencoder_params["label_nc"]
821+
input_seg = torch.randn(input_shape_seg).to(device)
822+
prediction, _ = inferer.sample(
823+
autoencoder_model=stage_1,
824+
diffusion_model=stage_2,
825+
input_noise=noise,
826+
save_intermediates=True,
827+
seg=input_seg,
828+
)
829+
else:
830+
prediction = inferer.sample(
831+
autoencoder_model=stage_1, diffusion_model=stage_2, input_noise=noise, save_intermediates=False
832+
)
833+
self.assertEqual(prediction.shape, input_shape)
834+
775835
@skipUnless(has_einops, "Requires einops")
776836
def test_incompatible_spade_setup(self):
777837
stage_1 = SPADEAutoencoderKL(

0 commit comments

Comments (0)