[ControlNet] Adds controlnet for SanaTransformer #11040
Conversation
Force-pushed from eef6240 to 6a62c3e
Force-pushed from 6c71ca6 to d698d81
Force-pushed from fc00d13 to 7f3cbc5
@ishan-modi Sorry about the slow review here. The team is at an offsite this week and taking a break, but we'll try to merge ASAP once we're back next week. Thanks for the awesome work!

All good man! Enjoy your offsite.

gentle ping @a-r-r-o-w

Hi, sorry for the delay. Testing now and hopefully can merge soon 🤗
Thanks for the awesome work! Looks very close to merge except for a few more changes. LMK if I can help with any 🤗
```python
timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)

# controlnet(s) inference
controlnet_block_samples = self.controlnet(
```
The inference example in the docstring errors out for me here. This is because ControlNet is loaded in bf16, but latent_model_input is moved to self.transformer.dtype, which is fp16, due to the following code in encode_prompt:
```python
if self.transformer is not None:
    dtype = self.transformer.dtype
elif self.text_encoder is not None:
    dtype = self.text_encoder.dtype
else:
    dtype = None

prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
```
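For illustration, here is a tiny, hypothetical repro of the kind of dtype mismatch this causes (a plain `nn.Linear` stands in for the bf16 controlnet; this is not the actual pipeline code):

```python
import torch

# Minimal illustration (assumption: standard PyTorch dtype rules, not pipeline code).
# A bf16 layer stands in for the controlnet; an fp16 tensor stands in for
# latent_model_input after it was cast to self.transformer.dtype.
linear_bf16 = torch.nn.Linear(8, 8).to(torch.bfloat16)
hidden_fp16 = torch.randn(1, 8, dtype=torch.float16)

try:
    linear_bf16(hidden_fp16)  # fp16 input into bf16 weights
    mismatch_raised = False
except RuntimeError:
    # PyTorch does not silently promote between the two half-precision types here
    mismatch_raised = True
```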
Let's do this:
- Remove the above code occurrence for determining dtype in encode_prompt
- Return the prompt_embeds in the same dtype as the text encoder
- Perform any dtype casting within the __call__ method based on the controlnet_dtype and transformer_dtype (create these variables similar to how it's done in Wan)
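A minimal sketch of that proposal (the `transformer_dtype`/`controlnet_dtype` variable names follow the Wan pipeline pattern; the shapes and dtypes below are placeholders, not the real pipeline values):

```python
import torch

# Sketch only: stand-ins for self.transformer.dtype and self.controlnet.dtype
transformer_dtype = torch.float16
controlnet_dtype = torch.bfloat16

# encode_prompt would now return embeds in the text encoder's dtype ...
prompt_embeds = torch.randn(1, 77, 32, dtype=torch.bfloat16)
latent_model_input = torch.randn(1, 4, 8, 8, dtype=torch.float32)

# ... and __call__ casts per call site instead of pinning everything
# to self.transformer.dtype:
controlnet_hidden = latent_model_input.to(controlnet_dtype)
controlnet_embeds = prompt_embeds.to(controlnet_dtype)

transformer_hidden = latent_model_input.to(transformer_dtype)
transformer_embeds = prompt_embeds.to(transformer_dtype)
```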
Also, the following are the dtypes of each component:
- vae: bf16
- text_encoder: bf16
- transformer: fp16
- controlnet: bf16
Just so that I'm up to speed, is this expected?
Can you point me to the docstring that loads controlnet in bf16?
I think I overlooked the following:
- controlnet for SANA_600M is supposed to use fp16 here
- controlnet for SANA_1600M is supposed to use bf16 here
Most of the doc loads controlnet into fp16, but I guess it needs to be more generic as you mentioned
I'm referring to the example code that's at the top of the file:
```python
import torch
from diffusers import SanaControlNetModel, SanaControlNetPipeline
from diffusers.utils import load_image

controlnet = SanaControlNetModel.from_pretrained(
    "ishan24/Sana_600M_1024px_ControlNet_diffusers", torch_dtype=torch.bfloat16
)
pipe = SanaControlNetPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_600M_1024px_diffusers",
    variant="fp16",
    torch_dtype=torch.float16,
    controlnet=controlnet,
)
pipe.to("cuda")
pipe.vae.to(torch.bfloat16)
pipe.text_encoder.to(torch.bfloat16)

cond_image = load_image(
    "https://huggingface.co/ishan24/Sana_600M_1024px_ControlNet_diffusers/resolve/main/hed_example.png"
)
prompt = 'a cat with a neon sign that says "Sana"'
image = pipe(
    prompt,
    control_image=cond_image,
).images[0]
image.save("output.png")
```
I think:
- we should update the docstring example to use fp16 for both controlnet and transformer, unless there is a special reason to do it this way, i.e. controlnet in bf16 while transformer in fp16
- the changes @a-r-r-o-w proposed in [ControlNet] Adds controlnet for SanaTransformer #11040 (comment) sound good. I think all the Sana pipelines should have the same encode_prompt method, no? If so, let's not remove the `# Copied from` in encode_prompt; update the one in pipeline_sana.py and make sure the changes are applied to all Sana pipelines.
```python
if self.transformer is not None:
    dtype = self.transformer.dtype
elif self.text_encoder is not None:
    dtype = self.text_encoder.dtype
else:
    dtype = None
```
text_encoder cannot be None. transformer can be None, since we should be able to run encode_prompt without loading the transformer.
Let's make sure this method returns embeds in the same dtype as the text encoder and do the casting in __call__.
Ohh, I think we need to make sure the `prompt_embeds is not None` code path can work without text_encoders loaded too, no?
In modular, we started to make it so that you only run encode_prompt when you actually need to encode the prompt; that's not the case here yet.
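A hedged sketch of how those two points could combine (the `EncodePromptSketch`/`select_dtype` names are invented for illustration and are not diffusers API):

```python
# Sketch only: prompt_embeds may already be provided by the caller, and the
# text encoder may be unloaded, so neither can be assumed to be present.
class EncodePromptSketch:
    def __init__(self, text_encoder=None):
        self.text_encoder = text_encoder

    def select_dtype(self, prompt_embeds=None):
        # If embeds are precomputed, keep their dtype as-is.
        if prompt_embeds is not None:
            return prompt_embeds.dtype
        # Otherwise fall back to the text encoder's dtype, if it is loaded.
        if self.text_encoder is not None:
            return self.text_encoder.dtype
        return None
```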
@a-r-r-o-w let me know if I should change the current version to

```python
if self.text_encoder is not None:
    dtype = self.text_encoder.dtype
else:
    dtype = None
```

Once confirmed, I will make similar changes to pipeline_sana.py.
Oh okay, based on YiYi's comment, let's do this for both pipelines
```python
>>> from diffusers.utils import load_image

>>> controlnet = SanaControlNetModel.from_pretrained(
...     "ishan24/Sana_600M_1024px_ControlNet_diffusers", torch_dtype=torch.bfloat16
```
@lawrence-cj Could we host the controlnet checkpoint in the Efficient-Large-Model org? We generally don't merge without officially hosted weights unless it's necessary for a quick release (which we then update later anyway).
@ishan-modi Please feel free to mention your hosted controlnet model in the docs 🤗
Yah, I would like to do it and test the PR at the same time. @a-r-r-o-w
@a-r-r-o-w We hosted the official ckpt here: https://huggingface.co/Efficient-Large-Model/Sana_600M_1024px_ControlNet_diffusers.
Awesome, thank you @lawrence-cj! I'll run some final tests and merge the PR in a few hours
@lawrence-cj The checkpoint does not seem accessible yet and I get a 404. I'll go ahead and merge this PR for now, and we can update the docs/examples with the official checkpoint in a follow up
Force-pushed from ff92747 to 3d085a2
Force-pushed from 09359e6 to dea5de5
Thanks, LGTM! Going to try testing out the model again and we can merge once we have the official hosted checkpoint by Junsong 🤗
LMK if you need any help with the failing tests. They seem to be because of a misplaced …
Force-pushed from fbe517b to b973cd0
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Awesome work @ishan-modi, thanks a lot!
Just some final changes
Congrats! Thank you so much @ishan-modi
What does this PR do?
Fixes #10772, #11019, #11116
Who can review?
@yiyixuxu