[WIP] AnimateDiff SDXL #6195

Closed. Wants to merge 20 commits.
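For context, a minimal sketch of how the extended model is meant to be assembled once this lands. Only `UNet2DConditionModel`, `MotionAdapter`, and `UNetMotionModel.from_unet2d` are existing diffusers APIs; the motion-adapter checkpoint id is a placeholder, since an SDXL-sized adapter is assumed here.

```python
# Sketch only: build a motion UNet from an SDXL 2D UNet.
# "path/to/sdxl-motion-adapter" is a placeholder, not a real checkpoint.
from diffusers import MotionAdapter, UNet2DConditionModel, UNetMotionModel

unet2d = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
)
adapter = MotionAdapter.from_pretrained("path/to/sdxl-motion-adapter")
unet = UNetMotionModel.from_unet2d(unet2d, motion_adapter=adapter)
```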
Changes from all commits
1 change: 1 addition & 0 deletions src/diffusers/models/unet_2d_condition.py
@@ -273,6 +273,7 @@ def __init__(
raise ValueError(
f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
)

if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
for layer_number_per_block in transformer_layers_per_block:
if isinstance(layer_number_per_block, list):
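For reference, the shapes this validation is about: a nested per-layer spec makes the UNet asymmetrical, so the up-block ordering can no longer be inferred by simply reversing the down-block one. The values below are illustrative only, not taken from this PR.

```python
# Illustrative values only.
transformer_layers_per_block = [[1, 2], [2, 4], [4, 10]]          # nested -> asymmetrical UNet
reverse_transformer_layers_per_block = [[10, 4], [4, 2], [2, 1]]  # must then be given explicitly
```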
2 changes: 2 additions & 0 deletions src/diffusers/models/unet_3d_blocks.py
@@ -118,6 +118,7 @@ def get_down_block(
raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockMotion")
return CrossAttnDownBlockMotion(
num_layers=num_layers,
transformer_layers_per_block=transformer_layers_per_block,
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
@@ -252,6 +253,7 @@ def get_up_block(
raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockMotion")
return CrossAttnUpBlockMotion(
num_layers=num_layers,
transformer_layers_per_block=transformer_layers_per_block,
in_channels=in_channels,
out_channels=out_channels,
prev_output_channel=prev_output_channel,
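These two changes simply thread `transformer_layers_per_block` through to the motion down/up blocks. This matters for SDXL because its UNet uses a different transformer depth per block (roughly `(1, 2, 10)`, quoting the SDXL base config from memory), so a single int is not enough:

```python
# Per-block depths in the SDXL style (values quoted from memory; check the UNet's config.json).
transformer_layers_per_block = (1, 2, 10)
# get_down_block / get_up_block now receive the entry for their own block,
# i.e. transformer_layers_per_block[i] for the i-th block.
```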
106 changes: 93 additions & 13 deletions src/diffusers/models/unet_motion_model.py
@@ -58,6 +58,7 @@ def __init__(
activation_fn: str = "geglu",
norm_num_groups: int = 32,
max_seq_length: int = 32,
positional_embeddings: Optional[str] = None,
):
super().__init__()
self.motion_modules = nn.ModuleList([])
@@ -72,7 +73,7 @@ def __init__(
attention_bias=attention_bias,
num_attention_heads=num_attention_heads,
attention_head_dim=in_channels // num_attention_heads,
positional_embeddings="sinusoidal",
positional_embeddings=positional_embeddings,
num_positional_embeddings=max_seq_length,
)
)
@@ -94,7 +95,7 @@ def __init__(

Args:
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
The tuple of output channels for each UNet block.
motion_layers_per_block (`int`, *optional*, defaults to 2):
The number of motion layers per UNet block.
motion_mid_block_layers_per_block (`int`, *optional*, defaults to 1):
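As a quick reference for the adapter arguments documented above, a minimal construction sketch using the listed defaults (an SDXL UNet would need `block_out_channels` matching its own block layout):

```python
# Uses the SD 1.5-style defaults documented above.
from diffusers import MotionAdapter

adapter = MotionAdapter(
    block_out_channels=(320, 640, 1280, 1280),
    motion_layers_per_block=2,
    motion_mid_block_layers_per_block=1,
)
```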
@@ -125,6 +126,7 @@ def __init__(
num_attention_heads=motion_num_attention_heads,
max_seq_length=motion_max_seq_length,
layers_per_block=motion_layers_per_block,
positional_embeddings="sinusoidal",
)
)

@@ -138,6 +140,7 @@ def __init__(
num_attention_heads=motion_num_attention_heads,
layers_per_block=motion_mid_block_layers_per_block,
max_seq_length=motion_max_seq_length,
positional_embeddings="sinusoidal",
)
else:
self.mid_block = None
@@ -156,6 +159,7 @@ def __init__(
num_attention_heads=motion_num_attention_heads,
max_seq_length=motion_max_seq_length,
layers_per_block=motion_layers_per_block + 1,
positional_embeddings="sinusoidal",
)
)

@@ -183,6 +187,8 @@ def __init__(
sample_size: Optional[int] = None,
in_channels: int = 4,
out_channels: int = 4,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlockMotion",
"CrossAttnDownBlockMotion",
@@ -203,13 +209,21 @@ def __init__(
norm_num_groups: int = 32,
norm_eps: float = 1e-5,
cross_attention_dim: int = 1280,
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
use_linear_projection: bool = False,
num_attention_heads: Union[int, Tuple[int, ...]] = 8,
motion_max_seq_length: int = 32,
motion_num_attention_heads: int = 8,
use_motion_mid_block: bool = True,
encoder_hid_dim: Optional[int] = None,
encoder_hid_dim_type: Optional[str] = None,
addition_embed_type: Optional[str] = None,
addition_time_embed_dim: Optional[int] = None,
conv_in_kernel: int = 3,
conv_out_kernel: int = 3,
projection_class_embeddings_input_dim: Optional[int] = None,
addition_embed_type_num_heads: int = 64,
):
super().__init__()
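For orientation, the new constructor arguments exist to carry the SDXL UNet settings over; the SDXL base values look roughly like this (quoted from memory, not taken from this PR):

```python
# Approximate SDXL base UNet settings the new arguments are meant to accept
# (double-check against the checkpoint's config.json).
sdxl_like_kwargs = dict(
    addition_embed_type="text_time",
    addition_time_embed_dim=256,
    projection_class_embeddings_input_dim=2816,  # 1280 pooled text embeds + 6 * 256 time-id embeds
    cross_attention_dim=2048,
    transformer_layers_per_block=(1, 2, 10),
    use_linear_projection=True,
)
```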

@@ -231,9 +245,22 @@ def __init__(
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
)

if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
)

if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
)

if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
for layer_number_per_block in transformer_layers_per_block:
if isinstance(layer_number_per_block, list):
raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")

# input
conv_in_kernel = 3
conv_out_kernel = 3
conv_in_padding = (conv_in_kernel - 1) // 2
self.conv_in = nn.Conv2d(
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
@@ -253,13 +280,26 @@ def __init__(
if encoder_hid_dim_type is None:
self.encoder_hid_proj = None

if addition_embed_type == "text_time":
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)

# class embedding
self.down_blocks = nn.ModuleList([])
self.up_blocks = nn.ModuleList([])

if isinstance(num_attention_heads, int):
num_attention_heads = (num_attention_heads,) * len(down_block_types)

if isinstance(cross_attention_dim, int):
cross_attention_dim = (cross_attention_dim,) * len(down_block_types)

if isinstance(layers_per_block, int):
layers_per_block = [layers_per_block] * len(down_block_types)

if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)

# down
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
@@ -269,21 +309,22 @@ def __init__(

down_block = get_down_block(
down_block_type,
num_layers=layers_per_block,
num_layers=layers_per_block[i],
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
add_downsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
cross_attention_dim=cross_attention_dim[i],
num_attention_heads=num_attention_heads[i],
downsample_padding=downsample_padding,
use_linear_projection=use_linear_projection,
dual_cross_attention=False,
temporal_num_attention_heads=motion_num_attention_heads,
temporal_max_seq_length=motion_max_seq_length,
transformer_layers_per_block=transformer_layers_per_block[i],
)
self.down_blocks.append(down_block)

@@ -292,13 +333,15 @@ def __init__(
self.mid_block = UNetMidBlockCrossAttnMotion(
in_channels=block_out_channels[-1],
temb_channels=time_embed_dim,
transformer_layers_per_block=transformer_layers_per_block[-1],
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
cross_attention_dim=cross_attention_dim,
cross_attention_dim=cross_attention_dim[-1],
num_attention_heads=num_attention_heads[-1],
resnet_groups=norm_num_groups,
dual_cross_attention=False,
use_linear_projection=use_linear_projection,
temporal_num_attention_heads=motion_num_attention_heads,
temporal_max_seq_length=motion_max_seq_length,
)
@@ -307,13 +350,15 @@ def __init__(
self.mid_block = UNetMidBlock2DCrossAttn(
in_channels=block_out_channels[-1],
temb_channels=time_embed_dim,
transformer_layers_per_block=transformer_layers_per_block[-1],
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
cross_attention_dim=cross_attention_dim,
cross_attention_dim=cross_attention_dim[-1],
num_attention_heads=num_attention_heads[-1],
resnet_groups=norm_num_groups,
dual_cross_attention=False,
use_linear_projection=use_linear_projection,
)

# count how many layers upsample the images
@@ -322,6 +367,9 @@ def __init__(
# up
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_num_attention_heads = list(reversed(num_attention_heads))
reversed_layers_per_block = list(reversed(layers_per_block))
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))

output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
@@ -340,7 +388,7 @@ def __init__(

up_block = get_up_block(
up_block_type,
num_layers=layers_per_block + 1,
num_layers=reversed_layers_per_block[i] + 1,
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
@@ -349,13 +397,14 @@ def __init__(
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
cross_attention_dim=reversed_cross_attention_dim[i],
num_attention_heads=reversed_num_attention_heads[i],
dual_cross_attention=False,
resolution_idx=i,
use_linear_projection=use_linear_projection,
temporal_num_attention_heads=motion_num_attention_heads,
temporal_max_seq_length=motion_max_seq_length,
transformer_layers_per_block=reversed_transformer_layers_per_block[i],
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
@@ -411,7 +460,7 @@ def from_unet2d(
config["use_motion_mid_block"] = motion_adapter.config["use_motion_mid_block"]

# Need this for backwards compatibility with UNet2DConditionModel checkpoints
if not config.get("num_attention_heads"):
if config.get("attention_head_dim", None):
config["num_attention_heads"] = config["attention_head_dim"]

model = cls.from_config(config)
@@ -438,7 +487,16 @@ def from_unet2d(
model.up_blocks[i].upsamplers.load_state_dict(up_block.upsamplers.state_dict())

model.mid_block.resnets.load_state_dict(unet.mid_block.resnets.state_dict())
model.mid_block.attentions.load_state_dict(unet.mid_block.attentions.state_dict())
# model.mid_block.attentions.load_state_dict(unet.mid_block.attentions.state_dict())

# TODO(aryan): Fix size mismatch
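# Copy whichever attention weights exist in both state dicts, reshaping the
# 2D UNet tensors to the shapes the motion UNet expects.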
have = {}
for x in model.mid_block.attentions.state_dict().keys():
if x in unet.mid_block.attentions.state_dict().keys():
have[x] = unet.mid_block.attentions.state_dict()[x].reshape(
model.mid_block.attentions.state_dict()[x].shape
)
model.mid_block.attentions.load_state_dict(have)

if unet.conv_norm_out is not None:
model.conv_norm_out.load_state_dict(unet.conv_norm_out.state_dict())
@@ -772,12 +830,34 @@ def forward(
t_emb = t_emb.to(dtype=self.dtype)

emb = self.time_embedding(t_emb, timestep_cond)
aug_emb = None

if self.config.addition_embed_type == "text_time":
if "text_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
)

text_embeds = added_cond_kwargs.get("text_embeds")
if "time_ids" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
)
time_ids = added_cond_kwargs.get("time_ids")
time_embeds = self.add_time_proj(time_ids.flatten())
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))

add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
add_embeds = add_embeds.to(emb.dtype)
aug_emb = self.add_embedding(add_embeds)

emb = emb if aug_emb is None else emb + aug_emb
emb = emb.repeat_interleave(repeats=num_frames, dim=0)

if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
if "image_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
)
image_embeds = added_cond_kwargs.get("image_embeds")
image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype)
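To make the new `text_time` branch concrete, a sketch of the conditioning it consumes. The shapes follow the usual SDXL convention (pooled text embeddings plus six micro-conditioning time ids) and are assumptions for illustration, not values taken from this PR.

```python
import torch

# Pooled text embeddings plus the six SDXL micro-conditioning ids
# (original size, crop top-left, target size), for a batch of 2.
added_cond_kwargs = {
    "text_embeds": torch.randn(2, 1280),
    "time_ids": torch.tensor([[1024, 1024, 0, 0, 1024, 1024]] * 2, dtype=torch.float32),
}
# forward() projects time_ids with add_time_proj, concatenates the result with
# text_embeds, runs add_embedding, and adds that onto the time embedding before
# it is repeated over num_frames.
```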