[WIP] AnimateDiff SDXL #6195

Conversation
@EdoardoBotta Would you be interested in working on this as well? |
Apologies for the slow progress on this 😅 I was having trouble loading the checkpoint files due to key mismatches in the state dict, caused by a few implementation differences between the original AnimateDiff codebase and diffusers, which took me a while to understand, but it seems like I've gotten it working. To do so, I've added a `MotionAdapterXL` class:

```python
import torch

from diffusers.models.unet_motion_model import MotionAdapterXL

config = {
    "_class_name": "MotionAdapter",
    "_diffusers_version": "0.22.0.dev0",
    "block_out_channels": [
        320,
        640,
        1280,
    ],
    "motion_activation_fn": "geglu",
    "motion_attention_bias": False,
    "motion_cross_attention_dim": None,
    "motion_layers_per_block": 2,
    "motion_max_seq_length": 24,
    "motion_mid_block_layers_per_block": 1,
    "motion_norm_num_groups": 32,
    "motion_num_attention_heads": 8,
    "use_motion_mid_block": False,
    "temporal_position_encoding": True,
    "temporal_position_encoding_max_len": 32,
}

adapter = MotionAdapterXL.from_config(config, torch_dtype=torch.float16)

ckpt_path = "mm_sdxl_v10_beta.ckpt"
ckpt_state_dict = torch.load(ckpt_path, mmap=True, map_location="cpu")
ckpt_keys = list(ckpt_state_dict.keys())

# Remap the original AnimateDiff key names to the diffusers naming scheme.
for key in ckpt_keys:
    old_key = key[:]
    key = key.replace("temporal_transformer.", "")

    # attention_blocks.0 / attention_blocks.1 -> attn1 / attn2
    if 'attention_blocks' in key:
        loc = key.find('attention_blocks') + len('attention_blocks.')
        key = key[:loc] + str(ord(key[loc]) - ord('0') + 1) + key[loc + 1:]
        key = key.replace('attention_blocks.', 'attn')

    # ff_norm -> norms.2, then norms.0/1/2 -> norm1/norm2/norm3
    if 'ff_norm' in key:
        key = key.replace('ff_norm', 'norms.2')
    if 'norms' in key:
        loc = key.find('norms') + len('norms.')
        key = key[:loc] + str(ord(key[loc]) - ord('0') + 1) + key[loc + 1:]
        key = key.replace('norms.', 'norm')

    if 'pos_encoder' in key:
        key = key.replace('pos_encoder', 'pos_embed')

    ckpt_state_dict[key] = ckpt_state_dict.pop(old_key)

adapter.load_state_dict(ckpt_state_dict)
```

@sayakpaul I'm not sure how to convert this checkpoint into a diffusers-acceptable format, or whether that's already supported. For the time being, what I have above works, but let me know what would be a better way to do this. Thanks |
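For reference, a minimal sketch of how the converted weights could be persisted in diffusers format once the conversion above succeeds. It assumes `MotionAdapterXL` exposes the usual `ModelMixin` save/load API like the existing `MotionAdapter`; the output directory name is illustrative.

```python
# Hedged sketch: persist the converted adapter so it can be reloaded later
# without re-running the key remapping. Assumes MotionAdapterXL inherits
# save_pretrained/from_pretrained from ModelMixin; the path is illustrative.
adapter.to(dtype=torch.float16)
adapter.save_pretrained("animatediff-motion-adapter-sdxl-beta")

# Later it could be reloaded with:
# adapter = MotionAdapterXL.from_pretrained(
#     "animatediff-motion-adapter-sdxl-beta", torch_dtype=torch.float16
# )
```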
Cc @DN6 here. |
WIP Colab notebook. It doesn't work yet because UNetMotionModel errors out while loading the state dict from the SDXL UNet. Trying to figure out how to fix it. |
@a-r-r-o-w thanks for your contributions thus far. From #6195 (comment), I thought we were able to load the state dict appropriately, or isn't that the case anymore? |
We are indeed able to load the state dict, but only for the motion adapter. Currently, it fails while loading the state dict of the SDXL UNet. Attaching some logs from the notebook.

Relevant code from `UNetMotionModel`:

```python
# ... from_unet2d()
if has_motion_adapter:
    config["motion_num_attention_heads"] = motion_adapter.config["motion_num_attention_heads"]
    config["motion_max_seq_length"] = motion_adapter.config["motion_max_seq_length"]
    config["use_motion_mid_block"] = motion_adapter.config["use_motion_mid_block"]

# Need this for backwards compatibility with UNet2DConditionModel checkpoints
if not config.get("num_attention_heads"):
    config["num_attention_heads"] = config["attention_head_dim"]

model = cls.from_config(config)

if not load_weights:
    return model

model.conv_in.load_state_dict(unet.conv_in.state_dict())
model.time_proj.load_state_dict(unet.time_proj.state_dict())
model.time_embedding.load_state_dict(unet.time_embedding.state_dict())

for i, down_block in enumerate(unet.down_blocks):
    # fails here
    model.down_blocks[i].resnets.load_state_dict(down_block.resnets.state_dict())
    if hasattr(model.down_blocks[i], "attentions"):
        model.down_blocks[i].attentions.load_state_dict(down_block.attentions.state_dict())
    if model.down_blocks[i].downsamplers:
        model.down_blocks[i].downsamplers.load_state_dict(down_block.downsamplers.state_dict())

for i, up_block in enumerate(unet.up_blocks):
    model.up_blocks[i].resnets.load_state_dict(up_block.resnets.state_dict())
    if hasattr(model.up_blocks[i], "attentions"):
        model.up_blocks[i].attentions.load_state_dict(up_block.attentions.state_dict())
    if model.up_blocks[i].upsamplers:
        model.up_blocks[i].upsamplers.load_state_dict(up_block.upsamplers.state_dict())

model.mid_block.resnets.load_state_dict(unet.mid_block.resnets.state_dict())
model.mid_block.attentions.load_state_dict(unet.mid_block.attentions.state_dict())
```
|
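As an editorial aside, one way to see exactly where the SDXL UNet diverges from what `UNetMotionModel` constructs is to diff the per-block state dicts before calling `load_state_dict`. This is only a debugging sketch, not part of the PR; `unet` and `model` refer to the objects in the snippet above, and `diff_state_dicts` is a hypothetical helper.

```python
# Debugging sketch: compare sub-module state dicts of the SDXL UNet and the
# freshly constructed motion UNet to locate mismatched keys or shapes.
def diff_state_dicts(name, src_module, dst_module):
    src, dst = src_module.state_dict(), dst_module.state_dict()
    only_src = sorted(set(src) - set(dst))
    only_dst = sorted(set(dst) - set(src))
    shape_mismatch = {
        k: (tuple(src[k].shape), tuple(dst[k].shape))
        for k in src.keys() & dst.keys()
        if src[k].shape != dst[k].shape
    }
    if only_src or only_dst or shape_mismatch:
        print(name, "| only in unet:", only_src, "| only in motion model:", only_dst, "| shape mismatch:", shape_mismatch)

for i, down_block in enumerate(unet.down_blocks):
    diff_state_dicts(f"down_blocks.{i}.resnets", down_block.resnets, model.down_blocks[i].resnets)
```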
I see. Did we add any state dict conversion code for the motion module? |
Nope, I assumed the unet loading would work for SDXL seamlessly, but I think we'll have to add a few changes to the UNetMotionModel class as well. |
Made a little progress and have updated the Colab notebook linked above. So far, we can load the motion adapter correctly, load the UNetMotionModel state dict (albeit very hackily; I still have to clean it up), and manage to make the call to |
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
Hi @a-r-r-o-w I know this is still WIP. But I left a few comments that should simplify the code a bit. |
After applying the changes from the review, the motion adapter can be successfully loaded with this:

```python
import torch

from diffusers.models.unet_motion_model import MotionAdapter

config = {
    "_class_name": "MotionAdapter",
    "_diffusers_version": "0.22.0.dev0",
    "block_out_channels": [
        320,
        640,
        1280,
    ],
    "motion_activation_fn": "geglu",
    "motion_attention_bias": False,
    "motion_cross_attention_dim": None,
    "motion_layers_per_block": 2,
    "motion_max_seq_length": 32,
    "motion_mid_block_layers_per_block": 1,
    "motion_norm_num_groups": 32,
    "motion_num_attention_heads": 8,
    "use_motion_mid_block": False,
}

adapter = MotionAdapter.from_config(config, torch_dtype=torch.float16)

ckpt_path = "mm_sdxl_v10_beta.ckpt"
ckpt_state_dict = torch.load(ckpt_path, map_location="cpu")
ckpt_keys = list(ckpt_state_dict.keys())

for key in ckpt_keys:
    old_key = key[:]
    key = key.replace("temporal_transformer.", "")

    if 'attention_blocks' in key:
        loc = key.find('attention_blocks') + len('attention_blocks.')
        key = key[:loc] + str(ord(key[loc]) - ord('0') + 1) + key[loc + 1:]
        key = key.replace('attention_blocks.', 'attn')

    if 'ff_norm' in key:
        key = key.replace('ff_norm', 'norms.2')

    if 'norms' in key:
        loc = key.find('norms') + len('norms.')
        key = key[:loc] + str(ord(key[loc]) - ord('0') + 1) + key[loc + 1:]
        key = key.replace('norms.', 'norm')

    if 'pos_encoder' in key:
        key = key.replace('pos_encoder', 'pos_embed')

    replace = ''
    if 'attn1' in key:
        replace = 'attn1.'
    else:
        replace = 'attn2.'
    key = key.replace(replace, '')

    ckpt_state_dict[key] = ckpt_state_dict.pop(old_key)

adapter.load_state_dict(ckpt_state_dict)
```
|
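A small check that might help while iterating on the key remapping (an editorial suggestion, not from the thread): load non-strictly first and inspect the leftover keys before relying on the default strict load.

```python
# Optional sanity check while iterating on the remapping:
# nn.Module.load_state_dict returns the keys that did not line up,
# which makes silent mismatches visible before switching back to strict loading.
missing, unexpected = adapter.load_state_dict(ckpt_state_dict, strict=False)
print("missing keys:", missing)
print("unexpected keys:", unexpected)
```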
Apologies for the delay here... Between university and work, I haven't been able to put in a lot of time yet. @DN6 I'm having trouble getting the inference running. I believe I've got the entire model architecture mapped out correctly, and the code for loading the state dicts seems to be working well. However, when executing the inference loop, the batch size in the attention query seems to be incorrect and doesn't even properly divide the input size. Could you help me out?

```python
import torch

from diffusers import AnimateDiffPipeline, AutoencoderKL, DDIMScheduler, EulerDiscreteScheduler
from diffusers.models.unet_motion_model import MotionAdapter
from diffusers.models import UNet2DConditionModel, UNet3DConditionModel, UNetMotionModel
from diffusers.pipelines.animatediff.pipeline_animatediff_xl import AnimateDiffXLPipeline
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection

config = {
    "_class_name": "MotionAdapter",
    "_diffusers_version": "0.22.0.dev0",
    "block_out_channels": [
        320,
        640,
        1280,
    ],
    "motion_activation_fn": "geglu",
    "motion_attention_bias": False,
    "motion_cross_attention_dim": None,
    "motion_layers_per_block": 2,
    "motion_max_seq_length": 32,
    "motion_mid_block_layers_per_block": 1,
    "motion_norm_num_groups": 32,
    "motion_num_attention_heads": 8,
    "use_motion_mid_block": False,
    "temporal_position_encoding": True,
    "temporal_position_encoding_max_len": 32,
}

adapter = MotionAdapter.from_config(config, torch_dtype=torch.float16)

ckpt_path = "mm_sdxl_v10_beta.ckpt"
ckpt_state_dict = torch.load(ckpt_path, map_location="cpu")
ckpt_keys = list(ckpt_state_dict.keys())

for key in ckpt_keys:
    old_key = key[:]
    key = key.replace("temporal_transformer.", "")

    if 'attention_blocks' in key:
        loc = key.find('attention_blocks') + len('attention_blocks.')
        key = key[:loc] + str(ord(key[loc]) - ord('0') + 1) + key[loc + 1:]
        key = key.replace('attention_blocks.', 'attn')

    if 'ff_norm' in key:
        key = key.replace('ff_norm', 'norms.2')

    if 'norms' in key:
        loc = key.find('norms') + len('norms.')
        key = key[:loc] + str(ord(key[loc]) - ord('0') + 1) + key[loc + 1:]
        key = key.replace('norms.', 'norm')

    if 'pos_encoder' in key:
        key = key.replace('pos_encoder', 'pos_embed')

    replace = ''
    if 'attn1' in key:
        replace = 'attn1.'
    else:
        replace = 'attn2.'
    key = key.replace(replace, '')

    ckpt_state_dict[key] = ckpt_state_dict.pop(old_key)

adapter.load_state_dict(ckpt_state_dict)

# model_id = "a-r-r-o-w/dreamshaper-xl-turbo"
model_id = "stabilityai/stable-diffusion-xl-base-1.0"

tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", variant="fp16", torch_dtype=torch.float16).to("cuda")
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", variant="fp16", torch_dtype=torch.float16).to("cuda")
tokenizer_2 = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer_2")
text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(model_id, subfolder="text_encoder_2", variant="fp16", torch_dtype=torch.float16).to("cuda")
scheduler = EulerDiscreteScheduler(timestep_spacing='leading', steps_offset=1, num_train_timesteps=1000, beta_start=0.00085, beta_end=0.020, beta_schedule="scaled_linear")
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", variant="fp16", torch_dtype=torch.float16)

unet_motion_config = unet.config  # .copy()
unet_motion_config["_class_name"] = UNetMotionModel.__name__
unet_motion_config["down_block_types"] = [
    "DownBlockMotion",
    "CrossAttnDownBlockMotion",
    "CrossAttnDownBlockMotion",
]
unet_motion_config["up_block_types"] = [
    "CrossAttnUpBlockMotion",
    "CrossAttnUpBlockMotion",
    "UpBlockMotion",
]
unet_motion_config["mid_block_type"] = "UNetMidBlockCrossAttnMotion"

pipe = AnimateDiffXLPipeline.from_pretrained(
    model_id,
    # torch_dtype=torch.float16,
    # variant="fp16",
    # motion_adapter=adapter,
    vae=vae,
    text_encoder=text_encoder,
    text_encoder_2=text_encoder_2,
    tokenizer=tokenizer,
    tokenizer_2=tokenizer_2,
    unet=unet,
    motion_adapter=adapter,
    scheduler=scheduler,
    torch_dtype=torch.float16,
).to("cuda")

result = pipe(
    prompt="",
    negative_prompt="",
    num_inference_steps=2,
    width=512,
    height=512,
    num_frames=32,
)
```

Error log:

```
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[17], line 1
----> 1 result = pipe(
2 prompt="",
3 negative_prompt="",
4 num_inference_steps=2,
5 width=512,
6 height=512,
7 num_frames=32,
8 )
File /usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
112 @functools.wraps(func)
113 def decorate_context(*args, **kwargs):
114 with ctx_factory():
--> 115 return func(*args, **kwargs)
File /usr/local/lib/python3.10/dist-packages/diffusers/pipelines/animatediff/pipeline_animatediff_xl.py:944, in AnimateDiffXLPipeline.__call__(self, prompt, prompt_2, num_frames, height, width, num_inference_steps, denoising_end, guidance_scale, negative_prompt, negative_prompt_2, num_videos_per_prompt, eta, generator, latents, prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, ip_adapter_image, output_type, return_dict, callback, callback_steps, cross_attention_kwargs, clip_skip, original_size, crops_coords_top_left, target_size)
941 ts = ts.repeat(2)
943 # predict the noise residual
--> 944 noise_pred = self.unet(
945 latent_model_input,
946 ts,
947 encoder_hidden_states=prompt_embeds,
948 cross_attention_kwargs=cross_attention_kwargs,
949 added_cond_kwargs=added_cond_kwargs,
950 ).sample
952 # perform guidance
953 if do_classifier_free_guidance:
File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
File /usr/local/lib/python3.10/dist-packages/diffusers/models/unet_motion_model.py:854, in UNetMotionModel.forward(self, sample, timestep, encoder_hidden_states, timestep_cond, attention_mask, cross_attention_kwargs, added_cond_kwargs, down_block_additional_residuals, mid_block_additional_residual, return_dict)
852 for downsample_block in self.down_blocks:
853 if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
--> 854 sample, res_samples = downsample_block(
855 hidden_states=sample,
856 temb=emb,
857 encoder_hidden_states=encoder_hidden_states,
858 attention_mask=attention_mask,
859 num_frames=num_frames,
860 cross_attention_kwargs=cross_attention_kwargs,
861 )
862 else:
863 sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames)
File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
File /usr/local/lib/python3.10/dist-packages/diffusers/models/unet_3d_blocks.py:1218, in CrossAttnDownBlockMotion.forward(self, hidden_states, temb, encoder_hidden_states, attention_mask, num_frames, encoder_attention_mask, cross_attention_kwargs, additional_residuals)
1216 else:
1217 hidden_states = resnet(hidden_states, temb, scale=lora_scale)
-> 1218 hidden_states = attn(
1219 hidden_states,
1220 encoder_hidden_states=encoder_hidden_states,
1221 cross_attention_kwargs=cross_attention_kwargs,
1222 attention_mask=attention_mask,
1223 encoder_attention_mask=encoder_attention_mask,
1224 return_dict=False,
1225 )[0]
1226 hidden_states = motion_module(
1227 hidden_states,
1228 num_frames=num_frames,
1229 )[0]
1231 # apply additional residuals to the output of the last pair of resnet and attention blocks
File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
File /usr/local/lib/python3.10/dist-packages/diffusers/models/transformer_2d.py:392, in Transformer2DModel.forward(self, hidden_states, encoder_hidden_states, timestep, added_cond_kwargs, class_labels, cross_attention_kwargs, attention_mask, encoder_attention_mask, return_dict)
380 hidden_states = torch.utils.checkpoint.checkpoint(
381 create_custom_forward(block),
382 hidden_states,
(...)
389 **ckpt_kwargs,
390 )
391 else:
--> 392 hidden_states = block(
393 hidden_states,
394 attention_mask=attention_mask,
395 encoder_hidden_states=encoder_hidden_states,
396 encoder_attention_mask=encoder_attention_mask,
397 timestep=timestep,
398 cross_attention_kwargs=cross_attention_kwargs,
399 class_labels=class_labels,
400 )
402 # 3. Output
403 if self.is_input_continuous:
File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
File /usr/local/lib/python3.10/dist-packages/diffusers/models/attention.py:323, in BasicTransformerBlock.forward(self, hidden_states, attention_mask, encoder_hidden_states, encoder_attention_mask, timestep, cross_attention_kwargs, class_labels)
320 if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
321 norm_hidden_states = self.pos_embed(norm_hidden_states)
--> 323 attn_output = self.attn2(
324 norm_hidden_states,
325 encoder_hidden_states=encoder_hidden_states,
326 attention_mask=encoder_attention_mask,
327 **cross_attention_kwargs,
328 )
329 hidden_states = attn_output + hidden_states
331 # 4. Feed-forward
File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
File /usr/local/lib/python3.10/dist-packages/diffusers/models/attention_processor.py:527, in Attention.forward(self, hidden_states, encoder_hidden_states, attention_mask, **cross_attention_kwargs)
508 r"""
509 The forward method of the `Attention` class.
510
(...)
522 `torch.Tensor`: The output of the attention layer.
523 """
524 # The `Attention` class can call different attention processors / attention functions
525 # here we simply pass along all tensors to the selected processor class
526 # For standard processors that are defined here, `**cross_attention_kwargs` is empty
--> 527 return self.processor(
528 self,
529 hidden_states,
530 encoder_hidden_states=encoder_hidden_states,
531 attention_mask=attention_mask,
532 **cross_attention_kwargs,
533 )
File /usr/local/lib/python3.10/dist-packages/diffusers/models/attention_processor.py:1252, in AttnProcessor2_0.__call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, temb, scale)
1249 inner_dim = key.shape[-1]
1250 head_dim = inner_dim // attn.heads
-> 1252 query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1254 key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1255 value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
RuntimeError: shape '[96, -1, 10, 64]' is invalid for input of size 41943040
```

I have tried carefully verifying whether I'm handling all necessary parameters correctly, and I've gotten the entire model structure from AnimateDiff to match the one here. Here's the diff (although there seem to be many differences, diffusers uses common pos_embeds and such instead of one per attention layer, and those cause most of the differences). |
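For what it's worth, the numbers in the traceback can be sanity-checked directly. The arithmetic below is an editorial note rather than part of the thread, and the interpretation of where the batch size of 96 comes from is only a guess.

```python
# The failing call is query.view(batch_size, -1, attn.heads, head_dim) with
# batch_size=96, heads=10, head_dim=64 and a query of 41_943_040 elements.
inner_dim = 10 * 64                    # 640 channels at the first cross-attention resolution
tokens = 41_943_040 // inner_dim       # 65_536 flattened tokens in the query
print(tokens % (2 * 32), tokens % 96)  # 0 and 64

# 65_536 splits evenly into 2 (CFG) * 32 frames = 64 "images" of 32 * 32 = 1024
# spatial tokens (512x512 input, 16x total downsampling at this block), but it
# does not split into 96. That suggests the batch size used for the view, which
# for cross-attention appears to come from encoder_hidden_states, is out of
# sync with the frame-expanded hidden_states batch.
```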
Hi @a-r-r-o-w I'll take a look |
I'll spend some more time here soon. After reviewing a few PRs that added support for new pipelines which required editing |
@DN6 seems like there's a bug with the conversion script for motion modules. fp16 files should be half the file size of fp32 files, but that doesn't seem to be the case (you can take a look at the official checkpoints here and observe the same). Adding the below fix here:

```python
adapter.to(dtype=torch.float16)
adapter.save_pretrained(args.output_path, variant="fp16", torch_dtype=torch.float16)
```

save_pretrained does not seem to make use of torch_dtype, and that's what we had before. |
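A quick way to confirm the fix took effect is to compare the saved file sizes. This is an editorial sketch; the paths are illustrative and depend on where the converted adapter was written.

```python
# Hedged check: after casting with .to(torch.float16) and saving with
# variant="fp16", the fp16 weights file should be roughly half the size of
# the fp32 one. Paths below are illustrative placeholders.
import os

fp32_path = "converted-motion-adapter/diffusion_pytorch_model.safetensors"
fp16_path = "converted-motion-adapter/diffusion_pytorch_model.fp16.safetensors"
print(os.path.getsize(fp16_path) / os.path.getsize(fp32_path))  # expect ~0.5
```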
What does this PR do?
Attempt at integrating https://github.com/guoyww/AnimateDiff/tree/sdxl.
Relevant discussion: #5928 (comment)
Fixes #6158.
Who can review?
@sayakpaul, @patrickvonplaten, @DN6