Skip to content

Commit 6e5b374

Browse files
a-r-r-o-wDN6
authored andcommitted
[refactor] create modeling blocks specific to AnimateDiff (#8979)
* animatediff specific transformer model * make style * make fix-copies * move blocks to unet motion model * make style * remove dummy object * fix incorrectly passed param causing test failures * rename model and output class * fix sparsectrl imports * remove todo comments * remove temporal double self attn param from controlnet sparsectrl * add deprecated versions of blocks * apply suggestions from review * update --------- Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
1 parent c6ac793 commit 6e5b374

File tree

3 files changed

+1215
-1076
lines changed

3 files changed

+1215
-1076
lines changed

src/diffusers/models/controlnet_sparsectrl.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,7 @@
3232
from .modeling_utils import ModelMixin
3333
from .unets.unet_2d_blocks import UNetMidBlock2DCrossAttn
3434
from .unets.unet_2d_condition import UNet2DConditionModel
35-
from .unets.unet_3d_blocks import (
36-
CrossAttnDownBlockMotion,
37-
DownBlockMotion,
38-
)
35+
from .unets.unet_motion_model import CrossAttnDownBlockMotion, DownBlockMotion
3936

4037

4138
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -317,7 +314,6 @@ def __init__(
317314
temporal_num_attention_heads=motion_num_attention_heads[i],
318315
temporal_max_seq_length=motion_max_seq_length,
319316
temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i],
320-
temporal_double_self_attention=False,
321317
)
322318
elif down_block_type == "DownBlockMotion":
323319
down_block = DownBlockMotion(
@@ -334,7 +330,6 @@ def __init__(
334330
add_downsample=not is_final_block,
335331
temporal_num_attention_heads=motion_num_attention_heads[i],
336332
temporal_max_seq_length=motion_max_seq_length,
337-
temporal_double_self_attention=False,
338333
temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i],
339334
)
340335
else:

0 commit comments

Comments
 (0)