Add cross attention type for Sana-Sprint training in diffusers. #11514
sayakpaul merged 12 commits into huggingface:main
Conversation
| elif cross_attention_type == "vanilla":
|     cross_attention_processor = SanaAttnProcessor3_0()
Can't we modify the SanaAttnProcessor2_0() class to handle the changes in SanaAttnProcessor3_0?
If we merge 2_0 and 3_0, we would then need a variable to check when to use the function here,
which would be similar to
cross_attention_type: str = "flash",
| guidance_embeds_scale: float = 0.1,
| qk_norm: Optional[str] = None,
| timestep_scale: float = 1.0,
| cross_attention_type: str = "flash",
This goes a bit against our design.
Then can we just separate it into two classes, and could you help with a better implementation?
Actually, the only difference is that F.scaled_dot_product_attention is not supported by torch's JVP (forward-mode autodiff). Therefore, during training we need to replace it with the vanilla attention implementation. Any good ideas on how to merge these two? @sayakpaul
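For context, here is a minimal sketch (not from the PR) of how one might check that incompatibility locally, assuming torch.func.jvp as the forward-mode entry point; the function names below are purely illustrative:

```python
import torch
import torch.nn.functional as F


def sdpa_attention(q, k, v):
    # Fused kernel; depending on the backend it may not have a forward-mode (JVP) rule.
    return F.scaled_dot_product_attention(q, k, v)


def vanilla_attention(q, k, v):
    # Explicit QK^T / sqrt(d) -> softmax -> V, built only from ops with JVP support.
    scores = q @ k.transpose(-1, -2) / (q.shape[-1] ** 0.5)
    return scores.softmax(dim=-1) @ v


q = k = v = torch.randn(1, 8, 16, 64)  # (batch, heads, seq, head_dim)
tangent = torch.ones_like(q)

# The vanilla version composes from primitives with JVP rules, so this works.
torch.func.jvp(lambda x: vanilla_attention(x, k, v), (q,), (tangent,))

# The fused SDPA path may raise here, which is why training swaps in the vanilla processor.
try:
    torch.func.jvp(lambda x: sdpa_attention(x, k, v), (q,), (tangent,))
except Exception as err:
    print("SDPA not usable under jvp:", err)
```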
Ah I see. If that is the case, I think we should go through the attention processor mechanism, wherein we use something like set_attn_processor with the vanilla attention processor class.
If this is only needed for training, I think we should have the following methods added to the model class:
We can then just include the vanilla attention processor implementation in the training utility and do something like
model = SanaTransformer2DModel(...)
model.set_attn_processor(SanaVanillaAttnProcessor())

WDYT? @DN6 any suggestions?
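(A rough sketch of what such a vanilla processor could look like, assuming the diffusers Attention module exposes to_q, to_k, to_v, to_out, and heads the way existing processors use them; this is illustrative, not the implementation that landed in the PR.)

```python
import torch


class SanaVanillaAttnProcessor:
    """Sketch of a processor that avoids F.scaled_dot_product_attention for JVP compatibility."""

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
        # Project with the attention module's existing linear layers.
        context = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        query = attn.to_q(hidden_states)
        key = attn.to_k(context)
        value = attn.to_v(context)

        # Split heads: (batch, seq, inner_dim) -> (batch, heads, seq, head_dim).
        batch, seq_len, _ = query.shape
        head_dim = query.shape[-1] // attn.heads
        query = query.view(batch, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch, -1, attn.heads, head_dim).transpose(1, 2)

        # Explicit scaled dot-product attention: QK^T / sqrt(d), softmax, then V.
        scores = torch.matmul(query, key.transpose(-1, -2)) / (head_dim ** 0.5)
        if attention_mask is not None:
            scores = scores + attention_mask
        hidden_states = torch.matmul(scores.softmax(dim=-1), value)

        # Merge heads and apply the output projection (and dropout).
        hidden_states = hidden_states.transpose(1, 2).reshape(batch, seq_len, attn.heads * head_dim)
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states
```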
Oh this is cool and nifty IMO, thanks. I'll change the code.
| @@ -0,0 +1,1656 @@
| #!/usr/bin/env python
This is perfect! This is 100 percent the way to go here. We can include the attention processor here in a file (attention_processor.py) and use it from there in the training script.
Based on https://github.com/huggingface/diffusers/pull/11514/files#r2077921763.
Since we're using a folder for the training script, I wouldn't mind moving the dataloader out into a separate script and the utilities into another. But that's completely up to you.
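As a small illustration of that layout (the file name attention_processor.py comes from the suggestion above; the checkpoint path is hypothetical):

```python
# Hypothetical layout: the vanilla processor lives next to the training script
# and is imported from there instead of from diffusers itself.
from attention_processor import SanaVanillaAttnProcessor

from diffusers import SanaTransformer2DModel

model = SanaTransformer2DModel.from_pretrained(
    "path/to/SANA_Sprint_teacher", subfolder="transformer"  # hypothetical local path
)
model.set_attn_processor(SanaVanillaAttnProcessor())
```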
I don't mind it. Could you help with this one? :)
Yes, after the https://github.com/huggingface/diffusers/pull/11514/files#r2077921763 comments are addressed, I will help with that
…SanaAttnProcessor3_0` to `SanaVanillaAttnProcessor`
I have changed the code as recommended here: https://github.com/huggingface/diffusers/pull/11514/files#r2077921763. I hope it's what you mean. @sayakpaul

Tested locally after adding the SanaVanillaAttnProcessor imports; the changes work as expected. LGTM! @lawrence-cj @sayakpaul
| huggingface-cli download Efficient-Large-Model/SANA_Sprint_1.6B_1024px_teacher_diffusers --local-dir $your_local_path/SANA_Sprint_1.6B_1024px_teacher_diffusers

| python train_sana_sprint_diffusers.py \
This is perfect!
Do we want to move the dataset class into a separate file dataset.py? I am okay if we want to do that since it's already under research_projects.
Also, let's add a readme with instructions on how to acquire the dataset, etc. Currently, we're only using three shards I think.
I agree with you. Please help with this separated script!
@scxue, please help with the readme part!
README updated! @sayakpaul – Feel free to share any feedback or suggestions.
sayakpaul left a comment
Looking very nice. Some minor comments and we should be able to merge soon.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Has anyone successfully run this? I'm encountering various runtime errors, especially when using the log_validation function. The …
Thanks for pointing it out! It looks like the log_validation function has some bugs in its implementation. I will fix it with @lawrence-cj and @sayakpaul.
@PeiqinSun mistakes can happen, and there are ways to point them out. This is the point of making things openly available.
Thanks to @scxue and @sayakpaul for the prompt response. I look forward to your fixes.