[WIP] AnimateDiff SDXL #6195

Closed
wants to merge 20 commits
Changes from 1 commit
add missing upcast_vae function to pipeline
a-r-r-o-w committed Dec 18, 2023
commit 7c751d8a830ca3ce511db7267d4c3a4ce787a0d2
30 changes: 28 additions & 2 deletions src/diffusers/pipelines/animatediff/pipeline_animatediff_xl.py
@@ -23,6 +23,12 @@
 from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import IPAdapterMixin, LoraLoaderMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
+from ...models.attention_processor import (
+    AttnProcessor2_0,
+    LoRAAttnProcessor2_0,
+    LoRAXFormersAttnProcessor,
+    XFormersAttnProcessor,
+)
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...models.unet_motion_model import MotionAdapter
 from ...schedulers import (
@@ -124,8 +130,8 @@ def __init__(
             EulerAncestralDiscreteScheduler,
             DPMSolverMultistepScheduler,
         ],
-        feature_extractor: CLIPImageProcessor = None,
-        image_encoder: CLIPVisionModelWithProjection = None,
+        feature_extractor: Optional[CLIPImageProcessor] = None,
+        image_encoder: Optional[CLIPVisionModelWithProjection] = None,
     ):
         super().__init__()
         unet = UNetMotionModel.from_unet2d(unet, motion_adapter)
@@ -628,6 +634,26 @@ def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, d
         add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
         return add_time_ids

+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
+    def upcast_vae(self):
+        dtype = self.vae.dtype
+        self.vae.to(dtype=torch.float32)
+        use_torch_2_0_or_xformers = isinstance(
+            self.vae.decoder.mid_block.attentions[0].processor,
+            (
+                AttnProcessor2_0,
+                XFormersAttnProcessor,
+                LoRAXFormersAttnProcessor,
+                LoRAAttnProcessor2_0,
+            ),
+        )
+        # if xformers or torch_2_0 is used attention block does not need
+        # to be in float32 which can save lots of memory
+        if use_torch_2_0_or_xformers:
+            self.vae.post_quant_conv.to(dtype)
+            self.vae.decoder.conv_in.to(dtype)
+            self.vae.decoder.mid_block.to(dtype)
+
     @torch.no_grad()
     def __call__(
         self,
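For context on how this lands in the pipeline: `upcast_vae` only moves the VAE weights, and the SDXL pipelines pair it with a dtype check at decode time. Below is a minimal sketch of what the decode step inside this pipeline's `__call__` might look like, following the pattern in `StableDiffusionXLPipeline`. The `needs_upcasting` name and the exact call site are illustrative, not part of this diff, and the sketch assumes the fp16 SDXL VAE (which sets `force_upcast` in its config because it is prone to overflow in half precision) and AnimateDiff's existing `decode_latents` helper.

```python
# Illustrative decode step, mirroring StableDiffusionXLPipeline; not part of this commit.
# make sure the VAE is in float32 mode, as the SDXL VAE overflows in float16
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast

if needs_upcasting:
    self.upcast_vae()
    # match the latents to the dtype of the (partially downcast) VAE parameters
    latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)

# decode_latents reshapes the (batch, channels, frames, height, width) video
# latents into frames before calling self.vae.decode on them
video = self.decode_latents(latents)

if needs_upcasting:
    # cast back so subsequent calls keep the fp16 memory footprint
    self.vae.to(dtype=torch.float16)
```

Casting back to fp16 after decoding keeps peak memory bounded across repeated calls, which is the same trade-off the copied `upcast_vae` makes by downcasting `post_quant_conv`, `conv_in`, and `mid_block` whenever a torch 2.0 SDPA or xFormers attention processor is active.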