
Commit 4e5a9ef

update conversion script for SANA-1.5 and SANA-Sprint (#11082)
* 1. update conversion script for sana1.5; 2. add conversion script for sana-sprint;
* separate sana and sana-sprint conversion scripts;
* update for upstream
* fix the } bug
* add a doc for SanaSprintPipeline;
* minor update;
* make style && make quality
1 parent 398ca0c commit 4e5a9ef

File tree

2 files changed: +155 -16 lines changed

Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->

# SanaSprintPipeline

<div class="flex flex-wrap space-x-1">
  <img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
</div>

[SANA-Sprint: One-Step Diffusion with Continuous-Time Consistency Distillation](https://huggingface.co/papers/2503.09641) from NVIDIA and MIT HAN Lab, by Junsong Chen, Shuchen Xue, Yuyang Zhao, Jincheng Yu, Sayak Paul, Junyu Chen, Han Cai, Enze Xie, Song Han

The abstract from the paper is:

*This paper presents SANA-Sprint, an efficient diffusion model for ultra-fast text-to-image (T2I) generation. SANA-Sprint is built on a pre-trained foundation model and augmented with hybrid distillation, dramatically reducing inference steps from 20 to 1-4. We introduce three key innovations: (1) We propose a training-free approach that transforms a pre-trained flow-matching model for continuous-time consistency distillation (sCM), eliminating costly training from scratch and achieving high training efficiency. Our hybrid distillation strategy combines sCM with latent adversarial distillation (LADD): sCM ensures alignment with the teacher model, while LADD enhances single-step generation fidelity. (2) SANA-Sprint is a unified step-adaptive model that achieves high-quality generation in 1-4 steps, eliminating step-specific training and improving efficiency. (3) We integrate ControlNet with SANA-Sprint for real-time interactive image generation, enabling instant visual feedback for user interaction. SANA-Sprint establishes a new Pareto frontier in speed-quality tradeoffs, achieving state-of-the-art performance with 7.59 FID and 0.74 GenEval in only 1 step — outperforming FLUX-schnell (7.94 FID / 0.71 GenEval) while being 10× faster (0.1s vs 1.1s on H100). It also achieves 0.1s (T2I) and 0.25s (ControlNet) latency for 1024×1024 images on H100, and 0.31s (T2I) on an RTX 4090, showcasing its exceptional efficiency and potential for AI-powered consumer applications (AIPC). Code and pre-trained models will be open-sourced.*

<Tip>

Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.

</Tip>

This pipeline was contributed by [lawrence-cj](https://github.com/lawrence-cj), [Shuchen Xue](https://github.com/scxue) and [Enze Xie](https://github.com/xieenze). The original codebase can be found [here](https://github.com/NVlabs/Sana). The original weights can be found under [hf.co/Efficient-Large-Model](https://huggingface.co/Efficient-Large-Model/).

Available models:

| Model | Recommended dtype |
|:-----:|:-----------------:|
| [`Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers) | `torch.bfloat16` |
| [`Efficient-Large-Model/Sana_Sprint_0.6B_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_Sprint_0.6B_1024px_diffusers) | `torch.bfloat16` |

Refer to [this](https://huggingface.co/collections/Efficient-Large-Model/sana-sprint-67d6810d65235085b3b17c76) collection for more information.

Note: The recommended dtype is for the transformer weights. The text encoder must stay in `torch.bfloat16`, and the VAE weights must stay in `torch.bfloat16` or `torch.float32` for the model to work correctly. Please refer to the inference example below to see how to load the model with the recommended dtype.
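
A minimal inference sketch with the recommended dtype might look like the following (the prompt and step count are illustrative; SANA-Sprint is designed for 1-4 step generation):

```py
import torch
from diffusers import SanaSprintPipeline

pipeline = SanaSprintPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",
    torch_dtype=torch.bfloat16,
)
pipeline.to("cuda")

# SANA-Sprint is a few-step model, so a small number of inference steps is typical.
image = pipeline(
    prompt="a tiny astronaut hatching from an egg on the moon",
    num_inference_steps=2,
).images[0]
image.save("sana_sprint.png")
```
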
## Quantization

Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have a varying impact on the quality of generated images depending on the model.

Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`SanaSprintPipeline`] for inference with bitsandbytes.

```py
import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, SanaTransformer2DModel, SanaSprintPipeline
from transformers import BitsAndBytesConfig as BitsAndBytesConfig, AutoModel

quant_config = BitsAndBytesConfig(load_in_8bit=True)
text_encoder_8bit = AutoModel.from_pretrained(
    "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",
    subfolder="text_encoder",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)

quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
transformer_8bit = SanaTransformer2DModel.from_pretrained(
    "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)

pipeline = SanaSprintPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",
    text_encoder=text_encoder_8bit,
    transformer=transformer_8bit,
    torch_dtype=torch.bfloat16,
    device_map="balanced",
)

prompt = "a tiny astronaut hatching from an egg on the moon"
image = pipeline(prompt).images[0]
image.save("sana.png")
```

## SanaSprintPipeline

[[autodoc]] SanaSprintPipeline
  - all
  - __call__

## SanaPipelineOutput

[[autodoc]] pipelines.sana.pipeline_output.SanaPipelineOutput

scripts/convert_sana_to_diffusers.py

Lines changed: 59 additions & 16 deletions
@@ -27,6 +27,7 @@
 CTX = init_empty_weights if is_accelerate_available else nullcontext
 
 ckpt_ids = [
+    "Efficient-Large-Model/SANA1.5_4.8B_1024px/checkpoints/SANA1.5_4.8B_1024px.pth",
     "Efficient-Large-Model/Sana_1600M_4Kpx_BF16/checkpoints/Sana_1600M_4Kpx_BF16.pth",
     "Efficient-Large-Model/Sana_1600M_2Kpx_BF16/checkpoints/Sana_1600M_2Kpx_BF16.pth",
     "Efficient-Large-Model/Sana_1600M_1024px_MultiLing/checkpoints/Sana_1600M_1024px_MultiLing.pth",
@@ -75,7 +76,8 @@ def main(args):
     converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias")
 
     # Handle different time embedding structure based on model type
-    if args.model_type == "SanaSprint_1600M_P1_D20":
+
+    if args.model_type in ["SanaSprint_1600M_P1_D20", "SanaSprint_600M_P1_D28"]:
         # For Sana Sprint, the time embedding structure is different
         converted_state_dict["time_embed.timestep_embedder.linear_1.weight"] = state_dict.pop(
             "t_embedder.mlp.0.weight"
@@ -128,10 +130,18 @@ def main(args):
         layer_num = 20
     elif args.model_type == "SanaMS_600M_P1_D28":
         layer_num = 28
+    elif args.model_type == "SanaMS_4800M_P1_D60":
+        layer_num = 60
     else:
         raise ValueError(f"{args.model_type} is not supported.")
     # Positional embedding interpolation scale.
     interpolation_scale = {512: None, 1024: None, 2048: 1.0, 4096: 2.0}
+    qk_norm = "rms_norm_across_heads" if args.model_type in [
+        "SanaMS1.5_1600M_P1_D20",
+        "SanaMS1.5_4800M_P1_D60",
+        "SanaSprint_600M_P1_D28",
+        "SanaSprint_1600M_P1_D20"
+    ] else None
 
     for depth in range(layer_num):
         # Transformer blocks.
@@ -145,6 +155,14 @@ def main(args):
         converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
         converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
         converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
+        if qk_norm is not None:
+            # Add Q/K normalization for self-attention (attn1) - needed for Sana-Sprint and Sana-1.5
+            converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_q.weight"] = state_dict.pop(
+                f"blocks.{depth}.attn.q_norm.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_k.weight"] = state_dict.pop(
+                f"blocks.{depth}.attn.k_norm.weight"
+            )
         # Projection.
         converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict.pop(
             f"blocks.{depth}.attn.proj.weight"
@@ -191,6 +209,14 @@ def main(args):
         converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.bias"] = k_bias
         converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.weight"] = v
         converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.bias"] = v_bias
+        if qk_norm is not None:
+            # Add Q/K normalization for cross-attention (attn2) - needed for Sana-Sprint and Sana-1.5
+            converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_q.weight"] = state_dict.pop(
+                f"blocks.{depth}.cross_attn.q_norm.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_k.weight"] = state_dict.pop(
+                f"blocks.{depth}.cross_attn.k_norm.weight"
+            )
 
         # Add Q/K normalization for cross-attention (attn2) - needed for Sana Sprint
         if args.model_type == "SanaSprint_1600M_P1_D20":
@@ -235,8 +261,7 @@ def main(args):
     }
 
     # Add qk_norm parameter for Sana Sprint
-    if args.model_type == "SanaSprint_1600M_P1_D20":
-        transformer_kwargs["qk_norm"] = "rms_norm_across_heads"
+    if args.model_type in ["SanaSprint_1600M_P1_D20", "SanaSprint_600M_P1_D28"]:
         transformer_kwargs["guidance_embeds"] = True
 
     transformer = SanaTransformer2DModel(**transformer_kwargs)
@@ -271,23 +296,24 @@ def main(args):
             )
         )
         transformer.save_pretrained(
-            os.path.join(args.dump_path, "transformer"), safe_serialization=True, max_shard_size="5GB", variant=variant
+            os.path.join(args.dump_path, "transformer"), safe_serialization=True, max_shard_size="5GB"
         )
     else:
         print(colored(f"Saving the whole Pipeline containing {args.model_type}", "green", attrs=["bold"]))
         # VAE
-        ae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers", torch_dtype=torch.float32)
+        ae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers", torch_dtype=torch.float32)
 
         # Text Encoder
-        text_encoder_model_path = "google/gemma-2-2b-it"
+        text_encoder_model_path = "Efficient-Large-Model/gemma-2-2b-it"
         tokenizer = AutoTokenizer.from_pretrained(text_encoder_model_path)
         tokenizer.padding_side = "right"
         text_encoder = AutoModelForCausalLM.from_pretrained(
             text_encoder_model_path, torch_dtype=torch.bfloat16
         ).get_decoder()
 
         # Choose the appropriate pipeline and scheduler based on model type
-        if args.model_type == "SanaSprint_1600M_P1_D20":
+        if args.model_type in ["SanaSprint_1600M_P1_D20", "SanaSprint_600M_P1_D28"]:
+
             # Force SCM Scheduler for Sana Sprint regardless of scheduler_type
             if args.scheduler_type != "scm":
                 print(
@@ -335,7 +361,7 @@ def main(args):
             scheduler=scheduler,
         )
 
-        pipe.save_pretrained(args.dump_path, safe_serialization=True, max_shard_size="5GB", variant=variant)
+        pipe.save_pretrained(args.dump_path, safe_serialization=True, max_shard_size="5GB")
 
 
 DTYPE_MAPPING = {
@@ -344,12 +370,6 @@ def main(args):
     "bf16": torch.bfloat16,
 }
 
-VARIANT_MAPPING = {
-    "fp32": None,
-    "fp16": "fp16",
-    "bf16": "bf16",
-}
-
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
@@ -369,7 +389,7 @@ def main(args):
         "--model_type",
         default="SanaMS_1600M_P1_D20",
         type=str,
-        choices=["SanaMS_1600M_P1_D20", "SanaMS_600M_P1_D28", "SanaSprint_1600M_P1_D20"],
+        choices=["SanaMS_1600M_P1_D20", "SanaMS_600M_P1_D28", "SanaMS_4800M_P1_D60", "SanaSprint_1600M_P1_D20", "SanaSprint_600M_P1_D28"],
     )
     parser.add_argument(
         "--scheduler_type",
@@ -400,6 +420,30 @@ def main(args):
             "cross_attention_head_dim": 72,
             "cross_attention_dim": 1152,
             "num_layers": 28,
+        },
+        "SanaMS1.5_1600M_P1_D20": {
+            "num_attention_heads": 70,
+            "attention_head_dim": 32,
+            "num_cross_attention_heads": 20,
+            "cross_attention_head_dim": 112,
+            "cross_attention_dim": 2240,
+            "num_layers": 20,
+        },
+        "SanaMS1.5__4800M_P1_D60": {
+            "num_attention_heads": 70,
+            "attention_head_dim": 32,
+            "num_cross_attention_heads": 20,
+            "cross_attention_head_dim": 112,
+            "cross_attention_dim": 2240,
+            "num_layers": 60,
+        },
+        "SanaSprint_600M_P1_D28": {
+            "num_attention_heads": 36,
+            "attention_head_dim": 32,
+            "num_cross_attention_heads": 16,
+            "cross_attention_head_dim": 72,
+            "cross_attention_dim": 1152,
+            "num_layers": 28,
         },
         "SanaSprint_1600M_P1_D20": {
             "num_attention_heads": 70,
@@ -413,6 +457,5 @@ def main(args):
 
     device = "cuda" if torch.cuda.is_available() else "cpu"
     weight_dtype = DTYPE_MAPPING[args.dtype]
-    variant = VARIANT_MAPPING[args.dtype]
 
     main(args)
