
[WIP] AnimateDiff SDXL #6195


Closed
wants to merge 20 commits

Conversation

a-r-r-o-w
Member

What does this PR do?

Attempt at integrating https://github.com/guoyww/AnimateDiff/tree/sdxl.

Relevant discussion: #5928 (comment)

Fixes #6158.

Before submitting

Who can review?

@sayakpaul, @patrickvonplaten, @DN6

@a-r-r-o-w
Member Author

@EdoardoBotta Would you be interested in working on this as well?

@a-r-r-o-w
Member Author

a-r-r-o-w commented Dec 18, 2023

Apologies for the slow progress on this 😅 I was having trouble loading the checkpoint files because of key mismatches in the state dict, caused by a few implementation differences between the original AnimateDiff codebase and diffusers. It took me a while to understand, but I seem to have gotten it working. To do so, I've added a MotionAdapterXL class because I didn't want to break any of the existing functionality while testing; we can refactor it out in future commits if required. This allows me to load the state dict easily with the following code:

Code
import torch
from diffusers.models.unet_motion_model import MotionAdapterXL

config = {
    "_class_name": "MotionAdapter",
    "_diffusers_version": "0.22.0.dev0",
    "block_out_channels": [
        320,
        640,
        1280,
    ],
    "motion_activation_fn": "geglu",
    "motion_attention_bias": False,
    "motion_cross_attention_dim": None,
    "motion_layers_per_block": 2,
    "motion_max_seq_length": 24,
    "motion_mid_block_layers_per_block": 1,
    "motion_norm_num_groups": 32,
    "motion_num_attention_heads": 8,
    "use_motion_mid_block": False,
    "temporal_position_encoding": True,
    "temporal_position_encoding_max_len": 32,
}

adapter = MotionAdapterXL.from_config(config, torch_dtype=torch.float16)

ckpt_path = "mm_sdxl_v10_beta.ckpt"
ckpt_state_dict = torch.load(ckpt_path, mmap=True, map_location="cpu")
ckpt_keys = list(ckpt_state_dict.keys())

for key in ckpt_keys:
    old_key = key[:]
    # drop the "temporal_transformer." prefix used by the original codebase
    key = key.replace("temporal_transformer.", "")
    if 'attention_blocks' in key:
        # attention_blocks.0 / attention_blocks.1 -> attn1 / attn2 (0-indexed to 1-indexed)
        loc = key.find('attention_blocks') + len('attention_blocks.')
        key = key[:loc] + str(ord(key[loc]) - ord('0') + 1) + key[loc + 1:]
        key = key.replace('attention_blocks.', 'attn')
    if 'ff_norm' in key:
        # feed-forward norm becomes the third entry of the norms list ...
        key = key.replace('ff_norm', 'norms.2')
    if 'norms' in key:
        # ... and norms.0 / norms.1 / norms.2 -> norm1 / norm2 / norm3
        loc = key.find('norms') + len('norms.')
        key = key[:loc] + str(ord(key[loc]) - ord('0') + 1) + key[loc + 1:]
        key = key.replace('norms.', 'norm')
    if 'pos_encoder' in key:
        # positional encoding module is named pos_embed in diffusers
        key = key.replace('pos_encoder', 'pos_embed')
    ckpt_state_dict[key] = ckpt_state_dict.pop(old_key)

adapter.load_state_dict(ckpt_state_dict)

@sayakpaul I'm not sure how to convert this checkpoint into a diffusers-acceptable format, or whether that's already supported. For the time being, what I have above works, but let me know if there's a better way to do this. Thanks!
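
In case it helps, here is a rough sketch of what a standalone conversion step could look like, with the remapping loop above factored into a helper. It assumes MotionAdapterXL (as added in this branch) inherits the usual ModelMixin save/load machinery the way MotionAdapter does; the function names are placeholders, not existing APIs.

Code
import torch
from diffusers.models.unet_motion_model import MotionAdapterXL


def remap_motion_module_keys(state_dict):
    # Same transformations as the loop above: strip the "temporal_transformer."
    # prefix and rename attention / norm / positional-encoding modules to the
    # diffusers naming scheme.
    new_state_dict = {}
    for old_key, value in state_dict.items():
        key = old_key.replace("temporal_transformer.", "")
        if "attention_blocks" in key:
            loc = key.find("attention_blocks") + len("attention_blocks.")
            key = key[:loc] + str(int(key[loc]) + 1) + key[loc + 1:]
            key = key.replace("attention_blocks.", "attn")
        if "ff_norm" in key:
            key = key.replace("ff_norm", "norms.2")
        if "norms" in key:
            loc = key.find("norms") + len("norms.")
            key = key[:loc] + str(int(key[loc]) + 1) + key[loc + 1:]
            key = key.replace("norms.", "norm")
        if "pos_encoder" in key:
            key = key.replace("pos_encoder", "pos_embed")
        new_state_dict[key] = value
    return new_state_dict


def convert_and_save(ckpt_path, output_path, config):
    adapter = MotionAdapterXL.from_config(config)
    state_dict = torch.load(ckpt_path, map_location="cpu")
    adapter.load_state_dict(remap_motion_module_keys(state_dict))
    # Writes config.json plus weights in the standard diffusers layout, so the
    # adapter can later be reloaded with MotionAdapterXL.from_pretrained(output_path)
    adapter.save_pretrained(output_path)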

@sayakpaul
Member

Cc @DN6 here.

@a-r-r-o-w
Member Author

WIP Colab notebook. It doesn't work yet because UNetMotionModel errors out while loading from the UNet state dict of the SDXL model. I'm trying to figure out how to fix it.

@sayakpaul
Member

@a-r-r-o-w thanks for your contributions thus far. From #6195 (comment), I thought we were able to load the state dict appropriately, or isn't that the case anymore?

@a-r-r-o-w
Member Author

a-r-r-o-w commented Dec 18, 2023

@a-r-r-o-w thanks for your contributions thus far. From #6195 (comment), I thought we were able to load the state dict appropriately, or isn't that the case anymore?

We are indeed able to load the state dict, but only for the motion adapter. It currently fails when loading the state dict of the SDXL UNet. Attaching some logs from the notebook.

Logs
The config attributes {'center_input_sample': False, 'flip_sin_to_cos': True, 'freq_shift': 0, 'mid_block_type': 'UNetMidBlockCrossAttnMotion', 'only_cross_attention': False, 'transformer_layers_per_block': [1, 2, 10], 'attention_head_dim': [5, 10, 20], 'dual_cross_attention': False, 'class_embed_type': None, 'addition_embed_type': 'text_time', 'addition_time_embed_dim': 256, 'num_class_embeds': None, 'upcast_attention': None, 'resnet_time_scale_shift': 'default', 'resnet_skip_time_act': False, 'resnet_out_scale_factor': 1.0, 'time_embedding_type': 'positional', 'time_embedding_dim': None, 'time_embedding_act_fn': None, 'timestep_post_act': None, 'time_cond_proj_dim': None, 'conv_in_kernel': 3, 'conv_out_kernel': 3, 'projection_class_embeddings_input_dim': 2816, 'class_embeddings_concat': False, 'mid_block_only_cross_attention': None, 'cross_attention_norm': None, 'addition_embed_type_num_heads': 64} were passed to UNetMotionModelXL, but are not expected and will be ignored. Please verify your config.json configuration file.
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-20-e5d1262a4c3a> in <cell line: 1>()
----> 1 pipe = AnimateDiffXLPipeline(
      2     vae=vae,
      3     text_encoder=text_encoder,
      4     text_encoder_2=text_encoder_2,
      5     tokenizer=tokenizer,

2 frames
<ipython-input-19-3940d285b690> in __init__(self, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2, unet, motion_adapter, scheduler, feature_extractor, image_encoder)
    143     ):
    144         super().__init__()
--> 145         unet = UNetMotionModel.from_unet2d(unet, motion_adapter)
    146 
    147         self.register_modules(

<ipython-input-14-b230f02e66e7> in from_unet2d(cls, unet, motion_adapter, load_weights)
    553             model.down_blocks[i].resnets.load_state_dict(down_block.resnets.state_dict())
    554             if hasattr(model.down_blocks[i], "attentions"):
--> 555                 model.down_blocks[i].attentions.load_state_dict(down_block.attentions.state_dict())
    556             if model.down_blocks[i].downsamplers:
    557                 model.down_blocks[i].downsamplers.load_state_dict(down_block.downsamplers.state_dict())

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict, assign)
   2150 
   2151         if len(error_msgs) > 0:
-> 2152             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   2153                                self.__class__.__name__, "\n\t".join(error_msgs)))
   2154         return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for ModuleList:
	Unexpected key(s) in state_dict: "0.transformer_blocks.1.norm1.weight", "0.transformer_blocks.1.norm1.bias", "0.transformer_blocks.1.attn1.to_q.weight", "0.transformer_blocks.1.attn1.to_k.weight", "0.transformer_blocks.1.attn1.to_v.weight", "0.transformer_blocks.1.attn1.to_out.0.weight", "0.transformer_blocks.1.attn1.to_out.0.bias", "0.transformer_blocks.1.norm2.weight", "0.transformer_blocks.1.norm2.bias", "0.transformer_blocks.1.attn2.to_q.weight", "0.transformer_blocks.1.attn2.to_k.weight", "0.transformer_blocks.1.attn2.to_v.weight", "0.transformer_blocks.1.attn2.to_out.0.weight", "0.transformer_blocks.1.attn2.to_out.0.bias", "0.transformer_blocks.1.norm3.weight", "0.transformer_blocks.1.norm3.bias", "0.transformer_blocks.1.ff.net.0.proj.weight", "0.transformer_blocks.1.ff.net.0.proj.bias", "0.transformer_blocks.1.ff.net.2.weight", "0.transformer_blocks.1.ff.net.2.bias", "1.transformer_blocks.1.norm1.weight", "1.transformer_blocks.1.norm1.bias", "1.transformer_blocks.1.attn1.to_q.weight", "1.transformer_blocks.1.attn1.to_k.weight", "1.transformer_blocks.1.attn1.to_v.weight", "1.transformer_blocks.1.attn1.to_out.0.weight", "1.transformer_blocks.1.attn1.to_out.0.bias", "1.transformer_blocks.1.norm2.weight", "1.transformer_blocks.1.norm2.bias", "1.transformer_blocks.1.attn2.to_q.weight", "1.transformer_blocks.1.attn2.to_k.weight", "1.transformer_blocks.1.attn2.to_v.weight", "1.transformer_blocks.1.attn2.to_out.0.weight", "1.transformer_blocks.1.attn2.to_out.0.bias", "1.tr...
Relevant Code From UNetMotionModel
# ... from_unet2d()
        if has_motion_adapter:
            config["motion_num_attention_heads"] = motion_adapter.config["motion_num_attention_heads"]
            config["motion_max_seq_length"] = motion_adapter.config["motion_max_seq_length"]
            config["use_motion_mid_block"] = motion_adapter.config["use_motion_mid_block"]

        # Need this for backwards compatibility with UNet2DConditionModel checkpoints
        if not config.get("num_attention_heads"):
            config["num_attention_heads"] = config["attention_head_dim"]

        model = cls.from_config(config)

        if not load_weights:
            return model

        model.conv_in.load_state_dict(unet.conv_in.state_dict())
        model.time_proj.load_state_dict(unet.time_proj.state_dict())
        model.time_embedding.load_state_dict(unet.time_embedding.state_dict())

        for i, down_block in enumerate(unet.down_blocks):
            model.down_blocks[i].resnets.load_state_dict(down_block.resnets.state_dict())
            if hasattr(model.down_blocks[i], "attentions"):
                # fails here (see the traceback above)
                model.down_blocks[i].attentions.load_state_dict(down_block.attentions.state_dict())
            if model.down_blocks[i].downsamplers:
                model.down_blocks[i].downsamplers.load_state_dict(down_block.downsamplers.state_dict())

        for i, up_block in enumerate(unet.up_blocks):
            model.up_blocks[i].resnets.load_state_dict(up_block.resnets.state_dict())
            if hasattr(model.up_blocks[i], "attentions"):
                model.up_blocks[i].attentions.load_state_dict(up_block.attentions.state_dict())
            if model.up_blocks[i].upsamplers:
                model.up_blocks[i].upsamplers.load_state_dict(up_block.upsamplers.state_dict())

        model.mid_block.resnets.load_state_dict(unet.mid_block.resnets.state_dict())
        model.mid_block.attentions.load_state_dict(unet.mid_block.attentions.state_dict())
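
For what it's worth, the config warning above shows transformer_layers_per_block ([1, 2, 10]) among the attributes that get dropped, and the unexpected keys are all transformer_blocks.1.* and deeper, which is what you'd see if the attention modules were built with a single transformer layer each. A rough, untested sketch of forwarding that value inside from_unet2d (same spot as the snippet above; the motion down/up blocks would also need to accept it):

# ... inside from_unet2d(), before `model = cls.from_config(config)`
# Untested sketch: keep the SDXL per-block transformer depth instead of letting it
# fall back to the default of 1, so the copied attentions end up with the same
# number of transformer_blocks as the SDXL UNet whose weights they load.
if "transformer_layers_per_block" in unet.config:
    config["transformer_layers_per_block"] = unet.config["transformer_layers_per_block"]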

@sayakpaul
Member

I see. Did we add any state dict conversion code for the motion module?

@a-r-r-o-w
Member Author

Nope. I assumed the UNet loading would work seamlessly for SDXL, but I think we'll have to add a few changes to the UNetMotionModel class as well.

@a-r-r-o-w
Member Author

Made a little progress and have updated the Colab notebook linked above. So far, we can load the motion adapter correctly, load the UNetMotionModel state dict (albeit very hackily; I still have to clean it up), and get as far as the call to pipe.__call__. I'm now hitting a new set of errors during inference, related to attention dimensions not matching, which I believe is due to incorrect UNet initialization.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@DN6
Collaborator

DN6 commented Dec 19, 2023

Hi @a-r-r-o-w, I know this is still WIP, but I left a few comments that should simplify the code a bit.

@a-r-r-o-w
Member Author

After applying changes from the review, the motion adapter can be successfully loaded with this:

Code
import torch
from diffusers.models.unet_motion_model import MotionAdapter

config = {
    "_class_name": "MotionAdapter",
    "_diffusers_version": "0.22.0.dev0",
    "block_out_channels": [
        320,
        640,
        1280,
    ],
    "motion_activation_fn": "geglu",
    "motion_attention_bias": False,
    "motion_cross_attention_dim": None,
    "motion_layers_per_block": 2,
    "motion_max_seq_length": 32,
    "motion_mid_block_layers_per_block": 1,
    "motion_norm_num_groups": 32,
    "motion_num_attention_heads": 8,
    "use_motion_mid_block": False,
}
adapter = MotionAdapter.from_config(config, torch_dtype=torch.float16)

ckpt_path = "mm_sdxl_v10_beta.ckpt"
ckpt_state_dict = torch.load(ckpt_path, map_location="cpu")
ckpt_keys = list(ckpt_state_dict.keys())

for key in ckpt_keys:
    old_key = key[:]
    key = key.replace("temporal_transformer.", "")
    if 'attention_blocks' in key:
        loc = key.find('attention_blocks') + len('attention_blocks.')
        key = key[:loc] + str(ord(key[loc]) - ord('0') + 1) + key[loc + 1:]
        key = key.replace('attention_blocks.', 'attn')
    if 'ff_norm' in key:
        key = key.replace('ff_norm', 'norms.2')
    if 'norms' in key:
        loc = key.find('norms') + len('norms.')
        key = key[:loc] + str(ord(key[loc]) - ord('0') + 1) + key[loc + 1:]
        key = key.replace('norms.', 'norm')
    if 'pos_encoder' in key:
        key = key.replace('pos_encoder', 'pos_embed')
        replace = ''
        if 'attn1' in key:
            replace = 'attn1.'
        else:
            replace = 'attn2.'
        key = key.replace(replace, '')
    ckpt_state_dict[key] = ckpt_state_dict.pop(old_key)

adapter.load_state_dict(ckpt_state_dict)

@a-r-r-o-w
Member Author

a-r-r-o-w commented Dec 21, 2023

Apologies for the delay here... Between university and work, I haven't been able to put in a lot of time yet.

@DN6 I'm having trouble getting the inference running. I believe I've got the entire model architecture mapped out correctly and the code for loading the state dicts seems to be working well. However, when executing the inference loop, the batch size in the attention query seems to be incorrect and doesn't even properly divide the input size. Could you help me out?

Code
import torch
from diffusers import AnimateDiffPipeline, AutoencoderKL, DDIMScheduler, EulerDiscreteScheduler
from diffusers.models.unet_motion_model import MotionAdapter
from diffusers.models import UNet2DConditionModel, UNet3DConditionModel, UNetMotionModel
from diffusers.pipelines.animatediff.pipeline_animatediff_xl import AnimateDiffXLPipeline
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection

config = {
    "_class_name": "MotionAdapter",
    "_diffusers_version": "0.22.0.dev0",
    "block_out_channels": [
        320,
        640,
        1280,
    ],
    "motion_activation_fn": "geglu",
    "motion_attention_bias": False,
    "motion_cross_attention_dim": None,
    "motion_layers_per_block": 2,
    "motion_max_seq_length": 32,
    "motion_mid_block_layers_per_block": 1,
    "motion_norm_num_groups": 32,
    "motion_num_attention_heads": 8,
    "use_motion_mid_block": False,
    "temporal_position_encoding": True,
    "temporal_position_encoding_max_len": 32,
}

adapter = MotionAdapter.from_config(config, torch_dtype=torch.float16)

ckpt_path = "mm_sdxl_v10_beta.ckpt"
ckpt_state_dict = torch.load(ckpt_path, map_location="cpu")
ckpt_keys = list(ckpt_state_dict.keys())

for key in ckpt_keys:
    old_key = key[:]
    key = key.replace("temporal_transformer.", "")
    if 'attention_blocks' in key:
        loc = key.find('attention_blocks') + len('attention_blocks.')
        key = key[:loc] + str(ord(key[loc]) - ord('0') + 1) + key[loc + 1:]
        key = key.replace('attention_blocks.', 'attn')
    if 'ff_norm' in key:
        key = key.replace('ff_norm', 'norms.2')
    if 'norms' in key:
        loc = key.find('norms') + len('norms.')
        key = key[:loc] + str(ord(key[loc]) - ord('0') + 1) + key[loc + 1:]
        key = key.replace('norms.', 'norm')
    if 'pos_encoder' in key:
        key = key.replace('pos_encoder', 'pos_embed')
        replace = ''
        if 'attn1' in key:
            replace = 'attn1.'
        else:
            replace = 'attn2.'
        key = key.replace(replace, '')
    ckpt_state_dict[key] = ckpt_state_dict.pop(old_key)

adapter.load_state_dict(ckpt_state_dict)

# model_id = "a-r-r-o-w/dreamshaper-xl-turbo"
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", variant="fp16", torch_dtype=torch.float16).to("cuda")
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", variant="fp16", torch_dtype=torch.float16).to("cuda")
tokenizer_2 = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer_2")
text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(model_id, subfolder="text_encoder_2", variant="fp16", torch_dtype=torch.float16).to("cuda")
scheduler = EulerDiscreteScheduler(timestep_spacing='leading', steps_offset=1, num_train_timesteps=1000, beta_start=0.00085, beta_end=0.020, beta_schedule="scaled_linear")
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", variant="fp16", torch_dtype=torch.float16)

unet_motion_config = unet.config#.copy()
unet_motion_config["_class_name"] = UNetMotionModel.__name__
unet_motion_config["down_block_types"] = [
    "DownBlockMotion",
    "CrossAttnDownBlockMotion",
    "CrossAttnDownBlockMotion",
]
unet_motion_config["up_block_types"] = [
    "CrossAttnUpBlockMotion",
    "CrossAttnUpBlockMotion",
    "UpBlockMotion",
]
unet_motion_config["mid_block_type"] = "UNetMidBlockCrossAttnMotion"

pipe = AnimateDiffXLPipeline.from_pretrained(
    model_id,
    # torch_dtype=torch.float16,
    # variant="fp16",
    # motion_adapter=adapter,
    vae=vae,
    text_encoder=text_encoder,
    text_encoder_2=text_encoder_2,
    tokenizer=tokenizer,
    tokenizer_2=tokenizer_2,
    unet=unet,
    motion_adapter=adapter,
    scheduler=scheduler,
    torch_dtype=torch.float16,
).to("cuda")

result = pipe(
    prompt="",
    negative_prompt="",
    num_inference_steps=2,
    width=512,
    height=512,
    num_frames=32,
)
Error log
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[17], line 1
----> 1 result = pipe(
      2     prompt="",
      3     negative_prompt="",
      4     num_inference_steps=2,
      5     width=512,
      6     height=512,
      7     num_frames=32,
      8 )

File /usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/diffusers/pipelines/animatediff/pipeline_animatediff_xl.py:944, in AnimateDiffXLPipeline.__call__(self, prompt, prompt_2, num_frames, height, width, num_inference_steps, denoising_end, guidance_scale, negative_prompt, negative_prompt_2, num_videos_per_prompt, eta, generator, latents, prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, ip_adapter_image, output_type, return_dict, callback, callback_steps, cross_attention_kwargs, clip_skip, original_size, crops_coords_top_left, target_size)
    941     ts = ts.repeat(2)
    943 # predict the noise residual
--> 944 noise_pred = self.unet(
    945     latent_model_input,
    946     ts,
    947     encoder_hidden_states=prompt_embeds,
    948     cross_attention_kwargs=cross_attention_kwargs,
    949     added_cond_kwargs=added_cond_kwargs,
    950 ).sample
    952 # perform guidance
    953 if do_classifier_free_guidance:

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /usr/local/lib/python3.10/dist-packages/diffusers/models/unet_motion_model.py:854, in UNetMotionModel.forward(self, sample, timestep, encoder_hidden_states, timestep_cond, attention_mask, cross_attention_kwargs, added_cond_kwargs, down_block_additional_residuals, mid_block_additional_residual, return_dict)
    852 for downsample_block in self.down_blocks:
    853     if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
--> 854         sample, res_samples = downsample_block(
    855             hidden_states=sample,
    856             temb=emb,
    857             encoder_hidden_states=encoder_hidden_states,
    858             attention_mask=attention_mask,
    859             num_frames=num_frames,
    860             cross_attention_kwargs=cross_attention_kwargs,
    861         )
    862     else:
    863         sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames)

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /usr/local/lib/python3.10/dist-packages/diffusers/models/unet_3d_blocks.py:1218, in CrossAttnDownBlockMotion.forward(self, hidden_states, temb, encoder_hidden_states, attention_mask, num_frames, encoder_attention_mask, cross_attention_kwargs, additional_residuals)
   1216 else:
   1217     hidden_states = resnet(hidden_states, temb, scale=lora_scale)
-> 1218     hidden_states = attn(
   1219         hidden_states,
   1220         encoder_hidden_states=encoder_hidden_states,
   1221         cross_attention_kwargs=cross_attention_kwargs,
   1222         attention_mask=attention_mask,
   1223         encoder_attention_mask=encoder_attention_mask,
   1224         return_dict=False,
   1225     )[0]
   1226     hidden_states = motion_module(
   1227         hidden_states,
   1228         num_frames=num_frames,
   1229     )[0]
   1231 # apply additional residuals to the output of the last pair of resnet and attention blocks

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /usr/local/lib/python3.10/dist-packages/diffusers/models/transformer_2d.py:392, in Transformer2DModel.forward(self, hidden_states, encoder_hidden_states, timestep, added_cond_kwargs, class_labels, cross_attention_kwargs, attention_mask, encoder_attention_mask, return_dict)
    380         hidden_states = torch.utils.checkpoint.checkpoint(
    381             create_custom_forward(block),
    382             hidden_states,
   (...)
    389             **ckpt_kwargs,
    390         )
    391     else:
--> 392         hidden_states = block(
    393             hidden_states,
    394             attention_mask=attention_mask,
    395             encoder_hidden_states=encoder_hidden_states,
    396             encoder_attention_mask=encoder_attention_mask,
    397             timestep=timestep,
    398             cross_attention_kwargs=cross_attention_kwargs,
    399             class_labels=class_labels,
    400         )
    402 # 3. Output
    403 if self.is_input_continuous:

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /usr/local/lib/python3.10/dist-packages/diffusers/models/attention.py:323, in BasicTransformerBlock.forward(self, hidden_states, attention_mask, encoder_hidden_states, encoder_attention_mask, timestep, cross_attention_kwargs, class_labels)
    320     if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
    321         norm_hidden_states = self.pos_embed(norm_hidden_states)
--> 323     attn_output = self.attn2(
    324         norm_hidden_states,
    325         encoder_hidden_states=encoder_hidden_states,
    326         attention_mask=encoder_attention_mask,
    327         **cross_attention_kwargs,
    328     )
    329     hidden_states = attn_output + hidden_states
    331 # 4. Feed-forward

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /usr/local/lib/python3.10/dist-packages/diffusers/models/attention_processor.py:527, in Attention.forward(self, hidden_states, encoder_hidden_states, attention_mask, **cross_attention_kwargs)
    508 r"""
    509 The forward method of the `Attention` class.
    510 
   (...)
    522     `torch.Tensor`: The output of the attention layer.
    523 """
    524 # The `Attention` class can call different attention processors / attention functions
    525 # here we simply pass along all tensors to the selected processor class
    526 # For standard processors that are defined here, `**cross_attention_kwargs` is empty
--> 527 return self.processor(
    528     self,
    529     hidden_states,
    530     encoder_hidden_states=encoder_hidden_states,
    531     attention_mask=attention_mask,
    532     **cross_attention_kwargs,
    533 )

File /usr/local/lib/python3.10/dist-packages/diffusers/models/attention_processor.py:1252, in AttnProcessor2_0.__call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, temb, scale)
   1249 inner_dim = key.shape[-1]
   1250 head_dim = inner_dim // attn.heads
-> 1252 query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
   1254 key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
   1255 value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

RuntimeError: shape '[96, -1, 10, 64]' is invalid for input of size 41943040

I have tried carefully verifying that I'm handling all the necessary parameters correctly, and the entire model structure from AnimateDiff matches the one here. Here's the diff (although there appear to be many differences, diffusers uses a shared pos_embed and the like instead of one per attention layer, and that accounts for most of them).
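
In case it's useful for reproducing that comparison, a small generic helper like the one below can produce such a structure diff; the two model objects in the usage comment are placeholders for the original AnimateDiff-SDXL UNet and the diffusers UNetMotionModel, not names from either codebase.

Code
import difflib

import torch


def dump_structure(model: torch.nn.Module):
    # One "<param name> <shape>" line per parameter, sorted for a stable diff
    return sorted(f"{name} {tuple(p.shape)}" for name, p in model.named_parameters())


def diff_structures(reference: torch.nn.Module, candidate: torch.nn.Module) -> str:
    # Empty output means parameter names and shapes match exactly
    return "\n".join(
        difflib.unified_diff(
            dump_structure(reference),
            dump_structure(candidate),
            fromfile="animatediff_sdxl",
            tofile="diffusers",
            lineterm="",
        )
    )


# Usage (placeholder objects):
# print(diff_structures(original_animatediff_unet, pipe.unet))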

@DN6
Collaborator

DN6 commented Dec 22, 2023

Hi @a-r-r-o-w, I'll take a look.

@a-r-r-o-w
Member Author

I'll spend some more time here soon. After reviewing a few PRs that added support for new pipelines and required editing src/diffusers/models, I understand things much better. I think I should reimplement my changes from scratch to find any bugs I introduced along the way that are causing the pipeline not to work.

@a-r-r-o-w
Member Author

a-r-r-o-w commented Jan 26, 2024

@DN6 Seems like there's a bug with the conversion script for motion modules: fp16 files should be half the file size of fp32 files, but that does not appear to be the case (you can take a look at the official checkpoints here and observe the same). Adding the below fix here:

    adapter.to(dtype=torch.float16)
    adapter.save_pretrained(args.output_path, variant="fp16", torch_dtype=torch.float16)

save_pretrained does not seem to make use of torch_dtype, which is what we had relied on before.
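
As a quick sanity check after the cast, assuming save_pretrained with variant="fp16" writes the usual diffusion_pytorch_model.fp16.safetensors (adjust the path and filename if your output differs):

Code
import os

import torch
from safetensors.torch import load_file

output_path = "converted_motion_adapter"  # same path passed to save_pretrained above
weights_file = os.path.join(output_path, "diffusion_pytorch_model.fp16.safetensors")

# All tensors should be fp16, and the file roughly half the size of an fp32 export
state_dict = load_file(weights_file)
assert all(t.dtype == torch.float16 for t in state_dict.values())
print(f"{os.path.getsize(weights_file) / 1e6:.1f} MB")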

@a-r-r-o-w a-r-r-o-w mentioned this pull request Jan 26, 2024
@a-r-r-o-w a-r-r-o-w closed this Jan 26, 2024

Successfully merging this pull request may close these issues.

Feature Request: add SDXL support to animatediff pipeline