Commit cbc2ec8

AnimateDiff prompt travel (#9231)
* update
* implement prompt interpolation
* make style
* resnet memory optimizations
* more memory optimizations; todo: refactor
* update
* update animatediff controlnet with latest changes
* refactor chunked inference changes
* remove print statements
* undo memory optimization changes
* update docstrings
* fix tests
* fix pia tests
* apply suggestions from review
* add tests
* update comment
1 parent b5f591f commit cbc2ec8

15 files changed: +469 -119 lines
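
The headline change: the AnimateDiff pipelines now accept a dict `prompt` mapping frame indices to prompts ("prompt travel"), with embeddings interpolated between keyframes when FreeNoise is enabled. A minimal usage sketch — the checkpoints, frame indices, and prompts below are illustrative, not taken from this commit:

```python
import torch
from diffusers import AnimateDiffPipeline, MotionAdapter
from diffusers.utils import export_to_gif

# Illustrative checkpoints; any AnimateDiff-compatible base model + motion adapter works.
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)
pipe = AnimateDiffPipeline.from_pretrained(
    "SG161222/Realistic_Vision_V5.1_noVAE", motion_adapter=adapter, torch_dtype=torch.float16
).to("cuda")

# Prompt travel rides on FreeNoise; the dict maps frame index -> prompt for that keyframe.
pipe.enable_free_noise(context_length=16, context_stride=4)

output = pipe(
    prompt={
        0: "a caterpillar crawling along a leaf",
        24: "a chrysalis hanging from a branch",
        48: "a butterfly taking flight",
    },
    negative_prompt="bad quality, worst quality",
    num_frames=64,
    guidance_scale=7.5,
    num_inference_steps=25,
)
export_to_gif(output.frames[0], "prompt_travel.gif")
```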

src/diffusers/models/attention.py

Lines changed: 21 additions & 4 deletions
@@ -972,15 +972,32 @@ def _get_frame_indices(self, num_frames: int) -> List[Tuple[int, int]]:
         return frame_indices
 
     def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]:
-        if weighting_scheme == "pyramid":
+        if weighting_scheme == "flat":
+            weights = [1.0] * num_frames
+
+        elif weighting_scheme == "pyramid":
             if num_frames % 2 == 0:
                 # num_frames = 4 => [1, 2, 2, 1]
-                weights = list(range(1, num_frames // 2 + 1))
+                mid = num_frames // 2
+                weights = list(range(1, mid + 1))
                 weights = weights + weights[::-1]
             else:
                 # num_frames = 5 => [1, 2, 3, 2, 1]
-                weights = list(range(1, num_frames // 2 + 1))
-                weights = weights + [num_frames // 2 + 1] + weights[::-1]
+                mid = (num_frames + 1) // 2
+                weights = list(range(1, mid))
+                weights = weights + [mid] + weights[::-1]
+
+        elif weighting_scheme == "delayed_reverse_sawtooth":
+            if num_frames % 2 == 0:
+                # num_frames = 4 => [0.01, 2, 2, 1]
+                mid = num_frames // 2
+                weights = [0.01] * (mid - 1) + [mid]
+                weights = weights + list(range(mid, 0, -1))
+            else:
+                # num_frames = 5 => [0.01, 0.01, 3, 2, 1]
+                mid = (num_frames + 1) // 2
+                weights = [0.01] * (mid - 1)
+                weights = weights + list(range(mid, 0, -1))
         else:
             raise ValueError(f"Unsupported value for weighting_scheme={weighting_scheme}")
 
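
For intuition, a standalone sketch of the weighting logic above — these weights blend overlapping FreeNoise context windows, and the printed lists match the inline comments in the hunk:

```python
def get_frame_weights(num_frames: int, weighting_scheme: str = "pyramid") -> list:
    # Mirrors FreeNoiseTransformerBlock._get_frame_weights in the hunk above.
    if weighting_scheme == "flat":
        return [1.0] * num_frames
    if weighting_scheme == "pyramid":
        if num_frames % 2 == 0:
            mid = num_frames // 2
            weights = list(range(1, mid + 1))
            return weights + weights[::-1]
        mid = (num_frames + 1) // 2
        weights = list(range(1, mid))
        return weights + [mid] + weights[::-1]
    if weighting_scheme == "delayed_reverse_sawtooth":
        # Near-zero weight early in the window, then a reverse ramp.
        if num_frames % 2 == 0:
            mid = num_frames // 2
            return [0.01] * (mid - 1) + [mid] + list(range(mid, 0, -1))
        mid = (num_frames + 1) // 2
        return [0.01] * (mid - 1) + list(range(mid, 0, -1))
    raise ValueError(f"Unsupported value for weighting_scheme={weighting_scheme}")

print(get_frame_weights(4, "pyramid"))                   # [1, 2, 2, 1]
print(get_frame_weights(5, "pyramid"))                   # [1, 2, 3, 2, 1]
print(get_frame_weights(4, "delayed_reverse_sawtooth"))  # [0.01, 2, 2, 1]
print(get_frame_weights(5, "delayed_reverse_sawtooth"))  # [0.01, 0.01, 3, 2, 1]
```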

src/diffusers/models/controlnet_sparsectrl.py

Lines changed: 0 additions & 1 deletion
@@ -691,7 +691,6 @@ def forward(
 
         emb = self.time_embedding(t_emb, timestep_cond)
         emb = emb.repeat_interleave(sample_num_frames, dim=0)
-        encoder_hidden_states = encoder_hidden_states.repeat_interleave(sample_num_frames, dim=0)
 
         # 2. pre-process
         batch_size, channels, num_frames, height, width = sample.shape
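
This deletion (mirrored in `unet_motion_model.py` below) moves the per-frame duplication of `encoder_hidden_states` out of the model's `forward` and into the pipelines, where the new `prompt_embeds.repeat_interleave(...)` lines later in this diff perform it once per call. The operation is a plain batch-dimension expansion — a minimal sketch with illustrative shapes:

```python
import torch

batch_size, seq_len, dim, num_frames = 2, 77, 768, 16  # illustrative shapes

# One text embedding per video...
prompt_embeds = torch.randn(batch_size, seq_len, dim)

# ...becomes one per frame, since the motion UNet folds frames into the batch dimension.
prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)
print(prompt_embeds.shape)  # torch.Size([32, 77, 768]) == (batch_size * num_frames, seq_len, dim)
```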

src/diffusers/models/unets/unet_motion_model.py

Lines changed: 1 addition & 2 deletions
@@ -116,7 +116,7 @@ def __init__(
 
         self.in_channels = in_channels
 
-        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+        self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
         self.proj_in = nn.Linear(in_channels, inner_dim)
 
         # 3. Define transformers blocks
@@ -2178,7 +2178,6 @@ def forward(
 
         emb = emb if aug_emb is None else emb + aug_emb
         emb = emb.repeat_interleave(repeats=num_frames, dim=0)
-        encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
 
         if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
             if "image_embeds" not in added_cond_kwargs:

src/diffusers/pipelines/animatediff/pipeline_animatediff.py

Lines changed: 45 additions & 21 deletions
@@ -432,7 +432,6 @@ def prepare_extra_step_kwargs(self, generator, eta):
             extra_step_kwargs["generator"] = generator
         return extra_step_kwargs
 
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
     def check_inputs(
         self,
         prompt,
@@ -470,8 +469,8 @@ def check_inputs(
             raise ValueError(
                 "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
             )
-        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
-            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+        elif prompt is not None and not isinstance(prompt, (str, list, dict)):
+            raise ValueError(f"`prompt` has to be of type `str`, `list` or `dict` but is {type(prompt)=}")
 
         if negative_prompt is not None and negative_prompt_embeds is not None:
             raise ValueError(
@@ -557,11 +556,15 @@ def cross_attention_kwargs(self):
     def num_timesteps(self):
         return self._num_timesteps
 
+    @property
+    def interrupt(self):
+        return self._interrupt
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
-        prompt: Union[str, List[str]] = None,
+        prompt: Optional[Union[str, List[str]]] = None,
         num_frames: Optional[int] = 16,
         height: Optional[int] = None,
         width: Optional[int] = None,
@@ -701,9 +704,10 @@ def __call__(
         self._guidance_scale = guidance_scale
         self._clip_skip = clip_skip
         self._cross_attention_kwargs = cross_attention_kwargs
+        self._interrupt = False
 
         # 2. Define call parameters
-        if prompt is not None and isinstance(prompt, str):
+        if prompt is not None and isinstance(prompt, (str, dict)):
             batch_size = 1
         elif prompt is not None and isinstance(prompt, list):
             batch_size = len(prompt)
@@ -716,22 +720,39 @@ def __call__(
         text_encoder_lora_scale = (
             self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
         )
-        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
-            prompt,
-            device,
-            num_videos_per_prompt,
-            self.do_classifier_free_guidance,
-            negative_prompt,
-            prompt_embeds=prompt_embeds,
-            negative_prompt_embeds=negative_prompt_embeds,
-            lora_scale=text_encoder_lora_scale,
-            clip_skip=self.clip_skip,
-        )
-        # For classifier free guidance, we need to do two forward passes.
-        # Here we concatenate the unconditional and text embeddings into a single batch
-        # to avoid doing two forward passes
-        if self.do_classifier_free_guidance:
-            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+        if self.free_noise_enabled:
+            prompt_embeds, negative_prompt_embeds = self._encode_prompt_free_noise(
+                prompt=prompt,
+                num_frames=num_frames,
+                device=device,
+                num_videos_per_prompt=num_videos_per_prompt,
+                do_classifier_free_guidance=self.do_classifier_free_guidance,
+                negative_prompt=negative_prompt,
+                prompt_embeds=prompt_embeds,
+                negative_prompt_embeds=negative_prompt_embeds,
+                lora_scale=text_encoder_lora_scale,
+                clip_skip=self.clip_skip,
+            )
+        else:
+            prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+                prompt,
+                device,
+                num_videos_per_prompt,
+                self.do_classifier_free_guidance,
+                negative_prompt,
+                prompt_embeds=prompt_embeds,
+                negative_prompt_embeds=negative_prompt_embeds,
+                lora_scale=text_encoder_lora_scale,
+                clip_skip=self.clip_skip,
+            )
+
+            # For classifier free guidance, we need to do two forward passes.
+            # Here we concatenate the unconditional and text embeddings into a single batch
+            # to avoid doing two forward passes
+            if self.do_classifier_free_guidance:
+                prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+        prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)
 
         if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
             image_embeds = self.prepare_ip_adapter_image_embeds(
@@ -783,6 +804,9 @@ def __call__(
         # 8. Denoising loop
         with self.progress_bar(total=self._num_timesteps) as progress_bar:
             for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
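
When FreeNoise is enabled, the new `_encode_prompt_free_noise` path (defined on the FreeNoise mixin, not part of this diff) encodes each keyframe prompt and fills in per-frame embeddings between keyframes. A heavily simplified, hypothetical sketch of that interpolation idea — plain linear blending; the real method additionally handles classifier-free guidance, LoRA scales, and negative prompts:

```python
import torch

def interpolate_prompt_embeds(keyframe_embeds: dict, num_frames: int) -> torch.Tensor:
    # Hypothetical helper: keyframe_embeds maps frame_index -> (seq_len, dim) embedding.
    frame_indices = sorted(keyframe_embeds)
    out = []
    for frame in range(num_frames):
        if frame <= frame_indices[0]:      # before the first keyframe: hold it
            out.append(keyframe_embeds[frame_indices[0]])
        elif frame >= frame_indices[-1]:   # after the last keyframe: hold it
            out.append(keyframe_embeds[frame_indices[-1]])
        else:                              # between keyframes: linear blend
            pos = next(i for i, idx in enumerate(frame_indices) if idx >= frame)
            start, end = frame_indices[pos - 1], frame_indices[pos]
            t = (frame - start) / (end - start)
            out.append(torch.lerp(keyframe_embeds[start], keyframe_embeds[end], t))
    return torch.stack(out)  # (num_frames, seq_len, dim)

keyframes = {0: torch.randn(77, 768), 15: torch.randn(77, 768)}
print(interpolate_prompt_embeds(keyframes, num_frames=16).shape)  # torch.Size([16, 77, 768])
```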

src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py

Lines changed: 44 additions & 20 deletions
@@ -505,8 +505,8 @@ def check_inputs(
             raise ValueError(
                 "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
             )
-        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
-            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+        elif prompt is not None and not isinstance(prompt, (str, list, dict)):
+            raise ValueError(f"`prompt` has to be of type `str`, `list` or `dict` but is {type(prompt)}")
 
         if negative_prompt is not None and negative_prompt_embeds is not None:
             raise ValueError(
@@ -699,6 +699,10 @@ def cross_attention_kwargs(self):
     def num_timesteps(self):
         return self._num_timesteps
 
+    @property
+    def interrupt(self):
+        return self._interrupt
+
     @torch.no_grad()
     def __call__(
         self,
@@ -858,9 +862,10 @@ def __call__(
         self._guidance_scale = guidance_scale
         self._clip_skip = clip_skip
         self._cross_attention_kwargs = cross_attention_kwargs
+        self._interrupt = False
 
         # 2. Define call parameters
-        if prompt is not None and isinstance(prompt, str):
+        if prompt is not None and isinstance(prompt, (str, dict)):
             batch_size = 1
         elif prompt is not None and isinstance(prompt, list):
             batch_size = len(prompt)
@@ -883,22 +888,39 @@ def __call__(
         text_encoder_lora_scale = (
             cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
         )
-        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
-            prompt,
-            device,
-            num_videos_per_prompt,
-            self.do_classifier_free_guidance,
-            negative_prompt,
-            prompt_embeds=prompt_embeds,
-            negative_prompt_embeds=negative_prompt_embeds,
-            lora_scale=text_encoder_lora_scale,
-            clip_skip=self.clip_skip,
-        )
-        # For classifier free guidance, we need to do two forward passes.
-        # Here we concatenate the unconditional and text embeddings into a single batch
-        # to avoid doing two forward passes
-        if self.do_classifier_free_guidance:
-            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+        if self.free_noise_enabled:
+            prompt_embeds, negative_prompt_embeds = self._encode_prompt_free_noise(
+                prompt=prompt,
+                num_frames=num_frames,
+                device=device,
+                num_videos_per_prompt=num_videos_per_prompt,
+                do_classifier_free_guidance=self.do_classifier_free_guidance,
+                negative_prompt=negative_prompt,
+                prompt_embeds=prompt_embeds,
+                negative_prompt_embeds=negative_prompt_embeds,
+                lora_scale=text_encoder_lora_scale,
+                clip_skip=self.clip_skip,
+            )
+        else:
+            prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+                prompt,
+                device,
+                num_videos_per_prompt,
+                self.do_classifier_free_guidance,
+                negative_prompt,
+                prompt_embeds=prompt_embeds,
+                negative_prompt_embeds=negative_prompt_embeds,
+                lora_scale=text_encoder_lora_scale,
+                clip_skip=self.clip_skip,
+            )
+
+            # For classifier free guidance, we need to do two forward passes.
+            # Here we concatenate the unconditional and text embeddings into a single batch
+            # to avoid doing two forward passes
+            if self.do_classifier_free_guidance:
+                prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+        prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)
 
         if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
             image_embeds = self.prepare_ip_adapter_image_embeds(
@@ -990,6 +1012,9 @@ def __call__(
         # 8. Denoising loop
         with self.progress_bar(total=self._num_timesteps) as progress_bar:
             for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -1002,7 +1027,6 @@ def __call__(
                 else:
                     control_model_input = latent_model_input
                     controlnet_prompt_embeds = prompt_embeds
-                    controlnet_prompt_embeds = controlnet_prompt_embeds.repeat_interleave(num_frames, dim=0)
 
                 if isinstance(controlnet_keep[i], list):
                     cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
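
Both pipelines also gain an `interrupt` property: once `self._interrupt` is set, the denoising loop skips its remaining iterations, so generation can be cancelled mid-run. A sketch using the pipelines' existing `callback_on_step_end` hook — it assumes `pipe` is an `AnimateDiffPipeline` built as in the first example:

```python
# Cancel denoising after 10 steps, e.g. in response to a user pressing "stop".
def stop_early(pipeline, step_index, timestep, callback_kwargs):
    if step_index >= 10:
        pipeline._interrupt = True  # read back through the `interrupt` property each step
    return callback_kwargs

output = pipe(
    prompt="a field of sunflowers swaying in the wind",
    num_frames=16,
    callback_on_step_end=stop_early,
)
```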

src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py

Lines changed: 2 additions & 0 deletions
@@ -1143,6 +1143,8 @@ def __call__(
             add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
             add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
 
+        prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)
+
         prompt_embeds = prompt_embeds.to(device)
         add_text_embeds = add_text_embeds.to(device)
         add_time_ids = add_time_ids.to(device).repeat(batch_size * num_videos_per_prompt, 1)

src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py

Lines changed: 2 additions & 0 deletions
@@ -878,6 +878,8 @@ def __call__(
         if self.do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
 
+        prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)
+
         # 4. Prepare IP-Adapter embeddings
         if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
             image_embeds = self.prepare_ip_adapter_image_embeds(
