[Pipeline] AnimateDiff SDXL #6721
Changes from 45 commits
File changed: unet_motion_model.py
@@ -211,13 +211,18 @@ def __init__(
        norm_num_groups: int = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: int = 1280,
+        transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
+        reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
        use_linear_projection: bool = False,
        num_attention_heads: Union[int, Tuple[int, ...]] = 8,
        motion_max_seq_length: int = 32,
        motion_num_attention_heads: int = 8,
        use_motion_mid_block: int = True,
        encoder_hid_dim: Optional[int] = None,
        encoder_hid_dim_type: Optional[str] = None,
+        addition_embed_type: Optional[str] = None,
+        addition_time_embed_dim: Optional[int] = None,
+        projection_class_embeddings_input_dim: Optional[int] = None,
        time_cond_proj_dim: Optional[int] = None,
    ):
        super().__init__()
@@ -240,6 +245,21 @@ def __init__(
                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
            )

+        if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
+            raise ValueError(
+                f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
+            )
+
+        if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
+            raise ValueError(
+                f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
+            )
+
+        if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
+            for layer_number_per_block in transformer_layers_per_block:
+                if isinstance(layer_number_per_block, list):
+                    raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
+
        # input
        conv_in_kernel = 3
        conv_out_kernel = 3
@@ -260,13 +280,26 @@ def __init__(
        if encoder_hid_dim_type is None:
            self.encoder_hid_proj = None

+        if addition_embed_type == "text_time":
+            self.add_time_proj = Timesteps(addition_time_embed_dim, True, 0)
+            self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
Review comment on `self.add_embedding`: Could you also load `add_embedding` in `from_unet2d`? Something like:
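The snippet that followed "Something like:" is not shown here; the sketch below is a hypothetical reconstruction of the idea, assuming `from_unet2d` keeps copying submodule weights from the source 2D UNet with `load_state_dict`, as it already does for the other blocks.

```python
# Hypothetical sketch, not the reviewer's actual snippet. Inside
# UNetMotionModel.from_unet2d (where `unet` is the source UNet2DConditionModel
# and `model` is the new UNetMotionModel), also copy the SDXL "text_time"
# embedding weights when they exist.
if getattr(unet.config, "addition_embed_type", None) == "text_time":
    model.add_embedding.load_state_dict(unet.add_embedding.state_dict())
    # add_time_proj (Timesteps) is a fixed sinusoidal projection with no
    # learnable parameters, so there is nothing to copy for it.
```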
Reply: If we look at the current unet_2d_condition.py modeling code, the team has refactored these changes out into separate helper functions. Because unet_motion_model.py is mostly a copy of that file, we can adapt those changes here and thereby get all the functionality one would need. In my opinion we can take that up in a future PR (I'm also afraid I won't have time to test things thoroughly if we do it here).
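For context, the kind of helper the reply refers to looks roughly like the sketch below; the helper name and signature are illustrative assumptions, not the actual diffusers API.

```python
# Illustrative only: a config-driven helper in the style of the refactored
# unet_2d_condition.py, instead of building the embedding inline in __init__.
from diffusers.models.embeddings import TimestepEmbedding, Timesteps


def _set_add_embedding(self, addition_embed_type, addition_time_embed_dim,
                       projection_class_embeddings_input_dim, time_embed_dim):
    if addition_embed_type == "text_time":
        self.add_time_proj = Timesteps(addition_time_embed_dim, True, 0)
        self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
    elif addition_embed_type is not None:
        raise ValueError(f"Unsupported addition_embed_type: {addition_embed_type}")
```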
        # class embedding
        self.down_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])

        if isinstance(num_attention_heads, int):
            num_attention_heads = (num_attention_heads,) * len(down_block_types)

+        if isinstance(cross_attention_dim, int):
+            cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
+
+        if isinstance(layers_per_block, int):
+            layers_per_block = [layers_per_block] * len(down_block_types)
+
+        if isinstance(transformer_layers_per_block, int):
+            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+
        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
@@ -276,21 +309,22 @@ def __init__(

            down_block = get_down_block(
                down_block_type,
-                num_layers=layers_per_block,
+                num_layers=layers_per_block[i],
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
-                cross_attention_dim=cross_attention_dim,
+                cross_attention_dim=cross_attention_dim[i],
                num_attention_heads=num_attention_heads[i],
                downsample_padding=downsample_padding,
                use_linear_projection=use_linear_projection,
                dual_cross_attention=False,
                temporal_num_attention_heads=motion_num_attention_heads,
                temporal_max_seq_length=motion_max_seq_length,
+                transformer_layers_per_block=transformer_layers_per_block[i],
            )
            self.down_blocks.append(down_block)
@@ -302,13 +336,14 @@ def __init__(
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                output_scale_factor=mid_block_scale_factor,
-                cross_attention_dim=cross_attention_dim,
+                cross_attention_dim=cross_attention_dim[-1],
                num_attention_heads=num_attention_heads[-1],
                resnet_groups=norm_num_groups,
                dual_cross_attention=False,
                use_linear_projection=use_linear_projection,
                temporal_num_attention_heads=motion_num_attention_heads,
                temporal_max_seq_length=motion_max_seq_length,
+                transformer_layers_per_block=transformer_layers_per_block[-1],
            )

        else:
@@ -318,11 +353,12 @@ def __init__(
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                output_scale_factor=mid_block_scale_factor,
-                cross_attention_dim=cross_attention_dim,
+                cross_attention_dim=cross_attention_dim[-1],
                num_attention_heads=num_attention_heads[-1],
                resnet_groups=norm_num_groups,
                dual_cross_attention=False,
                use_linear_projection=use_linear_projection,
+                transformer_layers_per_block=transformer_layers_per_block[-1],
            )

        # count how many layers upsample the images
@@ -331,6 +367,9 @@ def __init__(
        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        reversed_num_attention_heads = list(reversed(num_attention_heads))
+        reversed_layers_per_block = list(reversed(layers_per_block))
+        reversed_cross_attention_dim = list(reversed(cross_attention_dim))
+        reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))

        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
@@ -349,7 +388,7 @@ def __init__(

            up_block = get_up_block(
                up_block_type,
-                num_layers=layers_per_block + 1,
+                num_layers=reversed_layers_per_block[i] + 1,
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
@@ -358,13 +397,14 @@ def __init__(
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
-                cross_attention_dim=cross_attention_dim,
+                cross_attention_dim=reversed_cross_attention_dim[i],
                num_attention_heads=reversed_num_attention_heads[i],
                dual_cross_attention=False,
                resolution_idx=i,
                use_linear_projection=use_linear_projection,
                temporal_num_attention_heads=motion_num_attention_heads,
                temporal_max_seq_length=motion_max_seq_length,
+                transformer_layers_per_block=reversed_transformer_layers_per_block[i],
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel
@@ -835,6 +875,28 @@ def forward(
        t_emb = t_emb.to(dtype=self.dtype)

        emb = self.time_embedding(t_emb, timestep_cond)
+        aug_emb = None
+
+        if self.config.addition_embed_type == "text_time":
+            if "text_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
+                )
+
+            text_embeds = added_cond_kwargs.get("text_embeds")
+            if "time_ids" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
+                )
+            time_ids = added_cond_kwargs.get("time_ids")
+            time_embeds = self.add_time_proj(time_ids.flatten())
+            time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
+
+            add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
+            add_embeds = add_embeds.to(emb.dtype)
+            aug_emb = self.add_embedding(add_embeds)
+
+        emb = emb if aug_emb is None else emb + aug_emb
        emb = emb.repeat_interleave(repeats=num_frames, dim=0)
        encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
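For reference, the `text_time` branch added to `forward` above is what consumes the SDXL `added_cond_kwargs` (`text_embeds`/`time_ids`) that the new pipeline builds internally. Below is a rough usage sketch, assuming the `AnimateDiffSDXLPipeline` this PR introduces and an SDXL-beta motion adapter checkpoint; the model ids are assumptions for illustration, not taken from the diff.

```python
# Rough usage sketch for the pipeline this PR adds. The checkpoint ids below
# are illustrative assumptions, not taken from the diff itself.
import torch
from diffusers import AnimateDiffSDXLPipeline, MotionAdapter
from diffusers.utils import export_to_gif

adapter = MotionAdapter.from_pretrained(
    "guoyww/animatediff-motion-adapter-sdxl-beta", torch_dtype=torch.float16
)
pipe = AnimateDiffSDXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    motion_adapter=adapter,
    torch_dtype=torch.float16,
).to("cuda")

# The pipeline builds `text_embeds`/`time_ids` and passes them to the UNet via
# `added_cond_kwargs`, exercising the "text_time" branch shown in the diff.
output = pipe(
    prompt="a panda surfing on a wave, highly detailed",
    num_inference_steps=25,
    guidance_scale=8.0,
    num_frames=16,
)
export_to_gif(output.frames[0], "animation.gif")
```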