
[refactor] create modeling blocks specific to AnimateDiff #8979


Merged
merged 19 commits into main from animatediff/refactor-modeling on Aug 3, 2024

Conversation

a-r-r-o-w (Member) commented on Jul 25, 2024

What does this PR do?

Based on an internal conversation with @DN6. I have left some inline questions as well.

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@DN6

a-r-r-o-w requested a review from DN6 on July 25, 2024 at 22:54
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

a-r-r-o-w mentioned this pull request on Jul 28, 2024
@DN6 (Collaborator) left a comment

Looking good. I think the issues with the failing tests need to be addressed.

transformer_layers_per_block=transformer_layers_per_block[i],
temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i],
)
# TODO(aryan): Can we reduce LOC here by creating a dictionary of common arguments and then passing **kwargs?
Member Author

@DN6 WDYT about this?

Collaborator
Could you give me an example here?

Member Author
So, something like this:

# Arguments shared by both CrossAttnUpBlockMotion and UpBlockMotion.
init_kwargs = {
    "in_channels": input_channel,
    "prev_output_channel": prev_output_channel,
    "out_channels": output_channel,
    "temb_channels": time_embed_dim,
    "resolution_idx": i,
    "num_layers": reversed_layers_per_block[i] + 1,
    "resnet_eps": norm_eps,
    "resnet_act_fn": act_fn,
    "resnet_groups": norm_num_groups,
    "add_upsample": add_upsample,
    "temporal_num_attention_heads": reversed_motion_num_attention_heads[i],
    "temporal_max_seq_length": motion_max_seq_length,
    "temporal_transformer_layers_per_block": reverse_temporal_transformer_layers_per_block[i],
}

# Only the cross-attention variant needs the attention-specific arguments.
if up_block_type == "CrossAttnUpBlockMotion":
    up_block = CrossAttnUpBlockMotion(
        transformer_layers_per_block=reverse_transformer_layers_per_block[i],
        num_attention_heads=reversed_num_attention_heads[i],
        cross_attention_dim=reversed_cross_attention_dim[i],
        use_linear_projection=use_linear_projection,
        **init_kwargs,
    )
elif up_block_type == "UpBlockMotion":
    up_block = UpBlockMotion(**init_kwargs)
else:
    raise ValueError(
        "Invalid `up_block_type` encountered. Must be one of `CrossAttnUpBlockMotion` or `UpBlockMotion`"
    )

Collaborator
Hmm, so the large list of arguments is a problem with the UNet models and associated blocks.

This approach could work, but I would argue that the arguments to these blocks could be made much smaller by setting defaults for the parameters that are unlikely to change. Perhaps we revisit this when doing the UNet redesign?
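
(As a rough, hypothetical illustration of that idea, and not a signature from this PR: parameters that rarely change could carry defaults, so call sites only spell out what differs per block.)

from typing import Optional

from torch import nn


class UpBlockMotion(nn.Module):
    # Hypothetical signature: defaults for the rarely-changed resnet/norm/temporal
    # settings would shrink the constructor calls considerably.
    def __init__(
        self,
        in_channels: int,
        prev_output_channel: int,
        out_channels: int,
        temb_channels: int,
        resolution_idx: Optional[int] = None,
        num_layers: int = 1,
        resnet_eps: float = 1e-5,
        resnet_act_fn: str = "silu",
        resnet_groups: int = 32,
        add_upsample: bool = True,
        temporal_num_attention_heads: int = 8,
        temporal_max_seq_length: int = 32,
        temporal_transformer_layers_per_block: int = 1,
    ) -> None:
        super().__init__()
        # ... resnets, motion modules, and the upsampler would be built here ...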

Member Author
Yes, sure, let's roll with what we have right now for the arguments. We can take this up later and look for suggestions.

@@ -947,924 +854,6 @@ def forward(
return hidden_states


class DownBlockMotion(nn.Module):
Collaborator
There are some projects that might import these blocks directly.
https://github.com/search?q=%22from+diffusers.models.unets.unet_3d_blocks+import%22+language:Python&type=code

It would be good to create something similar to:

class Transformer2DModelOutput(Transformer2DModelOutput):

class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):

that is, dummy blocks that inherit from the real blocks and raise a deprecation warning.
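
For illustration, a minimal sketch of such a shim as it might look inside unet_3d_blocks.py, assuming the real block now lives in unet_motion_model and using the deprecate helper from diffusers.utils (the import path and removal version here are placeholders):

from ...utils import deprecate
from .unet_motion_model import DownBlockMotion as _DownBlockMotion


class DownBlockMotion(_DownBlockMotion):
    def __init__(self, *args, **kwargs):
        deprecation_message = (
            "Importing `DownBlockMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated. "
            "Please import it from `diffusers.models.unets.unet_motion_model` instead."
        )
        # Warn once at construction time, then behave exactly like the relocated block.
        deprecate("DownBlockMotion", "1.0.0", deprecation_message)
        super().__init__(*args, **kwargs)

The same pattern would be repeated for each block that used to be importable from unet_3d_blocks.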

Member Author
Ah, I see. Yes, will do that.

Member Author
I've added the deprecation blocks. Is this something to be handled in get_down_block and get_up_block as well?

Collaborator
No that's fine.

a-r-r-o-w requested a review from DN6 on August 2, 2024 at 09:11


logger = logging.get_logger(__name__) # pylint: disable=invalid-name


@dataclass
class AnimateDiffTransformer3DOutput(BaseOutput):
Collaborator
Since the transformer is now a block, we don't need to define an output class for it.
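
For illustration, a hedged sketch of what dropping the output class implies; this skeleton is hypothetical and much smaller than the real AnimateDiffTransformer3D, but it shows the return convention for an internal block:

import torch
from torch import nn


class AnimateDiffTransformer3D(nn.Module):
    """Hypothetical skeleton illustrating only the return convention."""

    def __init__(self, channels: int = 320):
        super().__init__()
        self.proj = nn.Linear(channels, channels)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # ... the temporal attention layers would run here in the real block ...
        hidden_states = self.proj(hidden_states)
        # An internal block returns the tensor directly; no output dataclass
        # and no `return_dict` plumbing to maintain.
        return hidden_states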

@DN6 (Collaborator) left a comment
Just one minor nit to fix regarding the transformer output. But LGTM otherwise 👍🏽

DN6 merged commit fbe29c6 into main on Aug 3, 2024
18 checks passed
a-r-r-o-w deleted the animatediff/refactor-modeling branch on August 3, 2024 at 07:35
sayakpaul pushed a commit that referenced this pull request on Dec 23, 2024
* animatediff specific transformer model

* make style

* make fix-copies

* move blocks to unet motion model

* make style

* remove dummy object

* fix incorrectly passed param causing test failures

* rename model and output class

* fix sparsectrl imports

* remove todo comments

* remove temporal double self attn param from controlnet sparsectrl

* add deprecated versions of blocks

* apply suggestions from review

* update

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>