Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor ControlNetMaisi #8005

Merged
33 changes: 15 additions & 18 deletions monai/apps/generation/maisi/networks/controlnet_maisi.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,15 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Sequence, cast
from typing import Sequence

import torch

from monai.utils import optional_import
from monai.networks.nets.controlnet import ControlNet
from monai.networks.nets.diffusion_model_unet import get_timestep_embedding

ControlNet, has_controlnet = optional_import("generative.networks.nets.controlnet", name="ControlNet")
get_timestep_embedding, has_get_timestep_embedding = optional_import(
"generative.networks.nets.diffusion_model_unet", name="get_timestep_embedding"
)

if TYPE_CHECKING:
from generative.networks.nets.controlnet import ControlNet as ControlNetType
else:
ControlNetType = cast(type, ControlNet)


class ControlNetMaisi(ControlNetType):
class ControlNetMaisi(ControlNet):
"""
Control network for diffusion models based on Zhang and Agrawala "Adding Conditional Control to Text-to-Image
Diffusion Models" (https://arxiv.org/abs/2302.05543)
Expand All @@ -49,10 +40,12 @@ class ControlNetMaisi(ControlNetType):
num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds`
classes.
upcast_attention: if True, upcast attention operations to full precision.
use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
conditioning_embedding_in_channels: number of input channels for the conditioning embedding.
conditioning_embedding_num_channels: number of channels for the blocks in the conditioning embedding.
use_checkpointing: if True, use activation checkpointing to save memory.
include_fc: whether to include the final linear layer. Default to False.
use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
"""

def __init__(
Expand All @@ -71,10 +64,12 @@ def __init__(
cross_attention_dim: int | None = None,
num_class_embeds: int | None = None,
upcast_attention: bool = False,
use_flash_attention: bool = False,
conditioning_embedding_in_channels: int = 1,
conditioning_embedding_num_channels: Sequence[int] | None = (16, 32, 96, 256),
conditioning_embedding_num_channels: Sequence[int] = (16, 32, 96, 256),
use_checkpointing: bool = True,
include_fc: bool = False,
use_combined_linear: bool = False,
use_flash_attention: bool = False,
) -> None:
super().__init__(
spatial_dims,
Expand All @@ -91,9 +86,11 @@ def __init__(
cross_attention_dim,
num_class_embeds,
upcast_attention,
use_flash_attention,
conditioning_embedding_in_channels,
conditioning_embedding_num_channels,
include_fc,
use_combined_linear,
use_flash_attention,
)
self.use_checkpointing = use_checkpointing

Expand All @@ -105,7 +102,7 @@ def forward(
conditioning_scale: float = 1.0,
context: torch.Tensor | None = None,
class_labels: torch.Tensor | None = None,
) -> tuple[Sequence[torch.Tensor], torch.Tensor]:
) -> tuple[list[torch.Tensor], torch.Tensor]:
emb = self._prepare_time_and_class_embedding(x, timesteps, class_labels)
h = self._apply_initial_convolution(x)
if self.use_checkpointing:
Expand Down
21 changes: 11 additions & 10 deletions tests/test_controlnet_maisi.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,12 @@
import torch
from parameterized import parameterized

from monai.apps.generation.maisi.networks.controlnet_maisi import ControlNetMaisi
from monai.networks import eval_mode
from monai.utils import optional_import
from tests.utils import SkipIfBeforePyTorchVersion

_, has_generative = optional_import("generative")

if has_generative:
from monai.apps.generation.maisi.networks.controlnet_maisi import ControlNetMaisi
_, has_einops = optional_import("einops")

TEST_CASES = [
[
Expand Down Expand Up @@ -103,16 +101,17 @@
TEST_CASES_ERROR = [
[
{"spatial_dims": 2, "in_channels": 1, "with_conditioning": True, "cross_attention_dim": None},
"ControlNet expects dimension of the cross-attention conditioning "
"(cross_attention_dim) when using with_conditioning.",
"DiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) "
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved
"to be specified when with_conditioning=True.",
],
[
{"spatial_dims": 2, "in_channels": 1, "with_conditioning": False, "cross_attention_dim": 2},
"ControlNet expects with_conditioning=True when specifying the cross_attention_dim.",
"DiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim.",
],
[
{"spatial_dims": 2, "in_channels": 1, "num_channels": (8, 16), "norm_num_groups": 16},
"ControlNet expects all num_channels being multiple of norm_num_groups",
f"DiffusionModelUNet expects all channels to be a multiple of norm_num_groups, but got"
f" channels={(8, 16)} and norm_num_groups={16}",
],
[
{
Expand All @@ -122,16 +121,17 @@
"attention_levels": (True,),
"norm_num_groups": 8,
},
"ControlNet expects num_channels being same size of attention_levels",
f"DiffusionModelUNet expects channels to have the same length as attention_levels, but got "
f"channels={(8, 16)} and attention_levels={(True,)}",
],
]


@SkipIfBeforePyTorchVersion((2, 0))
@skipUnless(has_generative, "monai-generative required")
class TestControlNet(unittest.TestCase):

@parameterized.expand(TEST_CASES)
@skipUnless(has_einops, "Requires einops")
def test_shape_unconditioned_models(self, input_param, expected_num_down_blocks_residuals, expected_shape):
net = ControlNetMaisi(**input_param)
with eval_mode(net):
Expand All @@ -145,6 +145,7 @@ def test_shape_unconditioned_models(self, input_param, expected_num_down_blocks_
self.assertEqual(result[1].shape, expected_shape)

@parameterized.expand(TEST_CASES_CONDITIONAL)
@skipUnless(has_einops, "Requires einops")
def test_shape_conditioned_models(self, input_param, expected_num_down_blocks_residuals, expected_shape):
net = ControlNetMaisi(**input_param)
with eval_mode(net):
Expand Down
Loading