[ControlNet] Adds controlnet for SanaTransformer #11040


Merged
merged 24 commits into huggingface:main from ishan-modi:fixes-issue-10772 on Apr 13, 2025

Conversation

ishan-modi
Contributor

@ishan-modi ishan-modi commented Mar 12, 2025

What does this PR do?

Fixes #10772, #11019, #11116

Who can review?

@yiyixuxu

@ishan-modi ishan-modi marked this pull request as ready for review March 13, 2025 15:34
@ishan-modi ishan-modi changed the title from [WIP] Adds ControlNet for SanaTransformer to [ControlNet] Adds controlnet for SanaTransformer Mar 13, 2025
@ishan-modi ishan-modi mentioned this pull request Mar 17, 2025
@a-r-r-o-w
Member

@ishan-modi Sorry about the slow review here. The team is at an offsite for this week and taking a break, but we'll try to merge asap once we're back next week. Thanks for the awesome work

@ishan-modi
Contributor Author

All good, man! Enjoy your offsite.

@lawrence-cj
Contributor

gentle ping @a-r-r-o-w

@a-r-r-o-w
Member

Hi, sorry for the delay. Testing now and hopefully can merge soon 🤗

Member

@a-r-r-o-w a-r-r-o-w left a comment


Thanks for the awesome work! Looks very close to merge except for a few more changes. LMK if I can help with any 🤗

timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)

# controlnet(s) inference
controlnet_block_samples = self.controlnet(
Member


The inference example in the docstring errors out for me here. This is because ControlNet is loaded in bf16, but latent_model_input is moved to self.transformer.dtype, which is fp16, due to the following code in encode_prompt:

        if self.transformer is not None:
            dtype = self.transformer.dtype
        elif self.text_encoder is not None:
            dtype = self.text_encoder.dtype
        else:
            dtype = None

        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

Let's do this:

  • Remove the above code occurrence for determining dtype in encode_prompt
  • Return the prompt_embeds in the same dtype as the text encoder
  • Perform any dtype casting within the __call__ method based on the controlnet_dtype and transformer_dtype (create these variables similarly to how it's done in Wan); a rough sketch follows
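As a minimal sketch of that per-module casting (the helper and keyword names here are illustrative assumptions, not the actual pipeline code):

def denoise_step(transformer, controlnet, latents, prompt_embeds, t):
    # Hypothetical helper for illustration; the real forward signatures may differ.
    transformer_dtype = transformer.dtype
    controlnet_dtype = controlnet.dtype

    latent_model_input = latents.to(transformer_dtype)
    timestep = t.expand(latent_model_input.shape[0])

    # Run the controlnet in its own dtype ...
    controlnet_block_samples = controlnet(
        hidden_states=latent_model_input.to(controlnet_dtype),
        encoder_hidden_states=prompt_embeds.to(controlnet_dtype),
        timestep=timestep,
    )

    # ... and cast its residuals back to the transformer's dtype before use.
    noise_pred = transformer(
        hidden_states=latent_model_input,
        encoder_hidden_states=prompt_embeds.to(transformer_dtype),
        timestep=timestep,
        controlnet_block_samples=[s.to(transformer_dtype) for s in controlnet_block_samples],
    )
    return noise_pred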

Also, the following are the dtypes of each component:

  • vae: bf16
  • text_encoder: bf16
  • transformer: fp16
  • controlnet: bf16

Just so that I'm up to speed, is this expected?

Contributor Author

@ishan-modi ishan-modi Apr 7, 2025


Can you point me to the docstring that loads the controlnet in bf16?

I think I overlooked the following:

  • controlnet for SANA_600M is supposed to use fp16 here
  • controlnet for SANA_1600M is supposed to use bf16 here

Most of the docs load the controlnet in fp16, but I guess it needs to be more generic, as you mentioned.

Member


I'm referring to the example code that's at the top of the file:

import torch
from diffusers import SanaControlNetModel, SanaControlNetPipeline
from diffusers.utils import load_image

controlnet = SanaControlNetModel.from_pretrained(
    "ishan24/Sana_600M_1024px_ControlNet_diffusers", torch_dtype=torch.bfloat16
)
pipe = SanaControlNetPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_600M_1024px_diffusers",
    variant="fp16",
    torch_dtype=torch.float16,
    controlnet=controlnet,
)
pipe.to("cuda")
pipe.vae.to(torch.bfloat16)
pipe.text_encoder.to(torch.bfloat16)
cond_image = load_image(
    "https://huggingface.co/ishan24/Sana_600M_1024px_ControlNet_diffusers/resolve/main/hed_example.png"
)
prompt = 'a cat with a neon sign that says "Sana"'
image = pipe(
    prompt,
    control_image=cond_image,
).images[0]
image.save("output.png")

Collaborator

@yiyixuxu yiyixuxu Apr 7, 2025


I think:

  1. We should update the docstring example to use fp16 for both the controlnet and the transformer, unless there is a special reason to do it this way, i.e. controlnet in bf16 while the transformer is in fp16 (a sketch of such an example follows this list).
  2. The changes @a-r-r-o-w proposed in [ControlNet] Adds controlnet for SanaTransformer #11040 (comment) sound good. I think all the Sana pipelines should have the same encode_prompt method, no? If so, let's not remove the # Copied from in encode_prompt; update the one in pipeline_sana.py and make sure the changes are applied to all Sana pipelines.
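For reference, a sketch of the example with the controlnet in the same fp16 dtype as the transformer (same checkpoints as the current docstring; only the controlnet dtype changes, and the final docstring wording is up to the PR):

import torch
from diffusers import SanaControlNetModel, SanaControlNetPipeline
from diffusers.utils import load_image

# Sketch only: controlnet and transformer now share fp16.
controlnet = SanaControlNetModel.from_pretrained(
    "ishan24/Sana_600M_1024px_ControlNet_diffusers", torch_dtype=torch.float16
)
pipe = SanaControlNetPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_600M_1024px_diffusers",
    variant="fp16",
    torch_dtype=torch.float16,
    controlnet=controlnet,
)
pipe.to("cuda")
pipe.vae.to(torch.bfloat16)
pipe.text_encoder.to(torch.bfloat16)
cond_image = load_image(
    "https://huggingface.co/ishan24/Sana_600M_1024px_ControlNet_diffusers/resolve/main/hed_example.png"
)
prompt = 'a cat with a neon sign that says "Sana"'
image = pipe(
    prompt,
    control_image=cond_image,
).images[0]
image.save("output.png")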

Comment on lines 370 to 375
        if self.transformer is not None:
            dtype = self.transformer.dtype
        elif self.text_encoder is not None:
            dtype = self.text_encoder.dtype
        else:
            dtype = None
Member


text_encoder cannot be None. transformer can be None since we should be able to run encode_prompt without loading the transformer.

Let's make sure this method returns embeds in the same dtype as the text encoder, and do the casting in __call__.

Collaborator


Ohh, I think we need to make sure the prompt_embeds is not None code path can work without the text encoder loaded too, no? (Rough sketch below.)

In modular, we started making it so that you only run encode_prompt when you actually need to encode the prompt; that's not the case here yet.
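Purely as an illustration of that idea (the helper is hypothetical and the exact encode_prompt return signature may differ):

def get_prompt_embeds(pipe, prompt, prompt_embeds, device):
    # Illustrative helper, not actual diffusers code: only touch the text encoder
    # when no precomputed embeddings were supplied, so a pipeline loaded with
    # text_encoder=None still works when the user passes prompt_embeds.
    if prompt_embeds is None:
        prompt_embeds = pipe.encode_prompt(prompt, device=device)[0]
    # Cast once here (i.e. in __call__) instead of inside encode_prompt.
    return prompt_embeds.to(device=device, dtype=pipe.transformer.dtype)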

Contributor Author


@a-r-r-o-w let me know if I should change the current version to

if self.text_encoder is not None:
    dtype = self.text_encoder.dtype
else:
    dtype = None

Once confirmed, I will make similar changes to pipeline_sana.py.

Member


Oh okay, based on YiYi's comment, let's do this for both pipelines

>>> from diffusers.utils import load_image

>>> controlnet = SanaControlNetModel.from_pretrained(
... "ishan24/Sana_600M_1024px_ControlNet_diffusers", torch_dtype=torch.bfloat16
Member


@lawrence-cj Could we host the controlnet checkpoint in the Efficient-Large-Model org? We generally don't merge without officially hosted weights unless it's necessary for a quick release (which we then update later anyway).

@ishan-modi Please feel free to mention your hosted controlnet model in the docs 🤗

Contributor

@lawrence-cj lawrence-cj Apr 7, 2025


Yah, I would like to do it and test the PR at the same time. @a-r-r-o-w

Member


Awesome, thank you @lawrence-cj! I'll run some final tests and merge the PR in a few hours

Member


@lawrence-cj The checkpoint does not seem accessible yet and I get a 404. I'll go ahead and merge this PR for now, and we can update the docs/examples with the official checkpoint in a follow-up.


Member

@a-r-r-o-w a-r-r-o-w left a comment


Thanks, LGTM! Going to try testing out the model again, and we can merge once we have the official checkpoint hosted by Junsong 🤗

@a-r-r-o-w
Member

LMK if you need any help with the failing tests. They seem to be because of a misplaced .to() on the controlnet outputs, and due to SanaControlNetPipelineOutput being used instead of SanaControlNetOutput in the documentation.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Member

@a-r-r-o-w a-r-r-o-w left a comment


Awesome work @ishan-modi, thanks a lot!

Just some final changes

@a-r-r-o-w a-r-r-o-w merged commit f1f38ff into huggingface:main Apr 13, 2025
12 checks passed
@lawrence-cj
Contributor

Congrats! Thank you so much @ishan-modi

@ishan-modi ishan-modi deleted the fixes-issue-10772 branch April 19, 2025 05:11
Development

Successfully merging this pull request may close these issues.

Sana Controlnet Support
6 participants