[refactor] create modeling blocks specific to AnimateDiff #8979
Conversation
Looking good. I think the issues with the failing tests need to be addressed.
        transformer_layers_per_block=transformer_layers_per_block[i],
        temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i],
    )
    # TODO(aryan): Can we reduce LOC here by creating a dictionary of common arguments and then passing **kwargs?
@DN6 WDYT about this?
Could you give me an example here?
So, something like this:
```python
init_kwargs = {
    "in_channels": input_channel,
    "prev_output_channel": prev_output_channel,
    "out_channels": output_channel,
    "temb_channels": time_embed_dim,
    "resolution_idx": i,
    "num_layers": reversed_layers_per_block[i] + 1,
    "resnet_eps": norm_eps,
    "resnet_act_fn": act_fn,
    "resnet_groups": norm_num_groups,
    "add_upsample": add_upsample,
    "temporal_num_attention_heads": reversed_motion_num_attention_heads[i],
    "temporal_max_seq_length": motion_max_seq_length,
    "temporal_transformer_layers_per_block": reverse_temporal_transformer_layers_per_block[i],
}

if up_block_type == "CrossAttnUpBlockMotion":
    up_block = CrossAttnUpBlockMotion(
        transformer_layers_per_block=reverse_transformer_layers_per_block[i],
        num_attention_heads=reversed_num_attention_heads[i],
        cross_attention_dim=reversed_cross_attention_dim[i],
        use_linear_projection=use_linear_projection,
        **init_kwargs,
    )
elif up_block_type == "UpBlockMotion":
    up_block = UpBlockMotion(**init_kwargs)
else:
    raise ValueError(
        "Invalid `up_block_type` encountered. Must be one of `CrossAttnUpBlockMotion` or `UpBlockMotion`"
    )
```
Hmm, the large list of arguments is a general problem with the UNet models and their associated blocks.
This approach could work, but I would argue that the argument lists for these blocks could be made much smaller just by setting defaults for parameters that are unlikely to change. Perhaps we revisit this when doing the UNet redesign?
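To illustrate the defaults idea (a hypothetical sketch in plain Python, not the real diffusers block signature): parameters that rarely change, such as `resnet_eps`, `resnet_act_fn`, or `resnet_groups`, get defaults, so call sites only pass what actually varies per block.

```python
# Hypothetical sketch: rarely-changed parameters get defaults, shrinking call sites.
# The parameter names mirror the block arguments discussed above; the class is illustrative.
class UpBlockMotionSketch:
    def __init__(
        self,
        in_channels,
        out_channels,
        temb_channels,
        num_layers=1,
        resnet_eps=1e-5,        # rarely changed -> default
        resnet_act_fn="silu",   # rarely changed -> default
        resnet_groups=32,       # rarely changed -> default
        add_upsample=True,
    ):
        # Store everything so the sketch is inspectable.
        self.config = {
            "in_channels": in_channels,
            "out_channels": out_channels,
            "temb_channels": temb_channels,
            "num_layers": num_layers,
            "resnet_eps": resnet_eps,
            "resnet_act_fn": resnet_act_fn,
            "resnet_groups": resnet_groups,
            "add_upsample": add_upsample,
        }

# Call sites shrink to the arguments that actually vary:
block = UpBlockMotionSketch(in_channels=320, out_channels=640, temb_channels=1280)
```

The trade-off is that defaults hide configuration from readers of the call site, which is why deferring this to the broader UNet redesign is reasonable.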
Yes, sure, let's roll with what we have right now for the arguments. We can take this up later and look for suggestions.
@@ -947,924 +854,6 @@ def forward(
        return hidden_states


class DownBlockMotion(nn.Module):
There are some projects that might import these blocks directly.
https://github.com/search?q=%22from+diffusers.models.unets.unet_3d_blocks+import%22+language:Python&type=code
It would be good to create something similar to

```python
class Transformer2DModelOutput(Transformer2DModelOutput):
```

and `diffusers/src/diffusers/loaders/lora_pipeline.py`, line 1743 in ea1b4ea:

```python
class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
```

i.e. dummy blocks that inherit from the real blocks and raise a deprecation warning.
Ah I see, yes, will do that.
I've added the deprecation blocks. Is this something to be handled in `get_down_block` and `get_up_block` as well?
No that's fine.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class AnimateDiffTransformer3DOutput(BaseOutput):
Since the transformer is now a block, we don't need to define an output class for it.
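To illustrate the nit (a simplified sketch with plain lists standing in for tensors, not the actual model code): internal blocks can return their hidden states directly, while only the top-level model wraps results in an output dataclass for its public API.

```python
from dataclasses import dataclass

# Simplified stand-in for diffusers' BaseOutput-style dataclass.
@dataclass
class ModelOutputSketch:
    sample: list

class TransformerBlockSketch:
    # Internal block: returns hidden states directly, no output class needed.
    def forward(self, hidden_states):
        return [h * 2 for h in hidden_states]

class ModelSketch:
    # Top-level model: the only place that wraps results in an output dataclass.
    def __init__(self):
        self.block = TransformerBlockSketch()

    def forward(self, hidden_states):
        return ModelOutputSketch(sample=self.block.forward(hidden_states))

out = ModelSketch().forward([1, 2, 3])
# out.sample == [2, 4, 6]
```

Keeping output wrappers at the model boundary avoids unwrapping overhead and an extra public class for every internal block.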
Just one minor nit to fix regarding the transformer output. But LGTM otherwise 👍🏽
* animatediff specific transformer model
* make style
* make fix-copies
* move blocks to unet motion model
* make style
* remove dummy object
* fix incorrectly passed param causing test failures
* rename model and output class
* fix sparsectrl imports
* remove todo comments
* remove temporal double self attn param from controlnet sparsectrl
* add deprecated versions of blocks
* apply suggestions from review
* update

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
What does this PR do?
Based on an internal conversation with @DN6. I have left some inline questions as well.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@DN6