✉ Corresponding Author.
🔥 Hugging Face demo for Ghibli-style generation, supported by EasyControl.
⚡️ Hugging Face demo now supports text-to-image generation with SD3 and SD3.5.
💰 Bonus tip: You can even use pure zero-init (zeroing out the prediction of the first step) as a quick test: if it improves your flow-matching model a lot, that may indicate the model has not fully converged yet.
🧪 Usage Tip: Use optimized-scale and zero-init together. Adjust the number of zero-init steps based on the total number of inference steps; zeroing out roughly 4% of the steps is generally a good starting point (see the sketch below).
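As a rough illustration of that heuristic, here is a minimal sketch (the helper name and the rounding choice are ours, not part of the released code):

```python
def num_zero_init_steps(total_steps: int, fraction: float = 0.04) -> int:
    """Number of initial sampling steps to zero out (roughly 4% of the total)."""
    return max(1, round(total_steps * fraction))

print(num_zero_init_steps(50))  # 2 -> zero out the first 2 of 50 steps
```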
- [2025.4.14] HiDream is supported now!
- [2025.4.14] 🔥 Supported by sdnext now!
- [2025.4.6] 📙 Supported by EasyControl now!
- [2025.4.4] 🤗 Supported by Diffusers now!
- [2025.4.2] 🙌 Mentioned by Wan2.1!
- [2025.4.1] Qwen2.5-Omni is supported now!
- [2025.3.30] Hunyuan is officially supported now!
- [2025.3.29] Flux is officially supported now!
- [2025.3.29] Both Wan2.1-14B I2V & T2V are now supported!
- [2025.3.28] Wan2.1-14B T2V is now supported! (Note: The default setting has been updated to zero out 4% of total steps for this scenario.)
- [2025.3.27] 📙 Supported by ComfyUI-KJNodes now!
- [2025.3.26] 📙 Supported by Wan2.1GP now!
- [2025.3.25] Paper | Demo | Code have been officially released.
If you find that CFG-Zero* helps improve your model, we'd love to hear about it!
Thanks to the following projects and models for supporting our method!
- SD.Next
- EasyControl
- ComfyUI-KJNodes
- Wan2.1GP
- ComfyUI (note: ComfyUI's implementation differs from ours)
- Wan2.1
- 14B Text-to-Video
- 14B Image-to-Video
- Hunyuan
- Text-to-Video
- SD3/SD3.5
- Text-to-Image
- Flux
- Text-to-Image (Guidance-distilled version)
- LoRA
- CogView4
- Text-to-Image
- Qwen2.5-Omni
- Audio generation
- EasyControl
- Ghibli-Style Portrait Generation
- HiDream
- Text-to-Image pipeline
Note: You may want to adjust the CUDA version according to your driver version.
conda create -n CFG_Zero_Star python=3.10
conda activate CFG_Zero_Star
# Install PyTorch according to your CUDA version
conda install pytorch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 pytorch-cuda=12.4 -c pytorch -c nvidia
pip install -r requirements.txt
apt install -y ffmpeg
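As an optional sanity check (not part of the official setup), you can confirm that PyTorch sees your GPU:

```python
import torch

# Expect the installed version (e.g., 2.5.1) and True on a working CUDA setup
print(torch.__version__, torch.cuda.is_available())
```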
Host a demo on your local machine.
python demo.py
Simply run the following command to generate videos in the output folder. Note that the current version uses Wan-AI/Wan2.1-T2V-14B-Diffusers with the default settings.
Note that zero-steps for Wan2.1 T2V is set to 1, i.e., the first 2 of the 50 default steps (4% of the total) are zeroed out, as sketched below.
python models/wan/T2V_infer.py
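For reference, here is a minimal sketch (the function and variable names are ours, not the script's) of how the zero-steps value maps onto the sampling loop:

```python
import torch

def zero_init(noise_pred: torch.Tensor, sample_step: int, zero_steps: int = 1) -> torch.Tensor:
    """Zero out the velocity prediction for the first (zero_steps + 1) sampling steps."""
    if sample_step <= zero_steps:
        return torch.zeros_like(noise_pred)
    return noise_pred

# With zero_steps = 1 and 50 sampling steps, steps 0 and 1 (4% of the total) are zeroed.
```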
All results shown below were generated using this script on an H100 80G GPU.
Follow Wan2.1 to clone the repo and finish the installation, then copy 'models/wan/image2video_cfg_zero_star.py' from this repo into the Wan2.1 repo (Wan2.1/wan). Modify 'Wan2.1/wan/__init__.py': replace 'from .image2video import WanI2V' with 'from .image2video_cfg_zero_star import WanI2V', as shown below.
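After the edit, the import in 'Wan2.1/wan/__init__.py' should read:

```python
# Wan2.1/wan/__init__.py
# before: from .image2video import WanI2V
from .image2video_cfg_zero_star import WanI2V
```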
Note: For I2V, zero_init_steps is set to 0 by default (zeroing only the first step, about 2.5% of the default sampling steps) to ensure stable generation. If you prefer more creative results, you can set it to 1 (the first 2 steps, about 5%), though this may lead to instability in certain cases.
python generate.py --task i2v-14B --size 832*480 --ckpt_dir ./Wan2.1-I2V-14B-480P --image examples/i2v_input.JPG --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." --base_seed 0
All results shown below were generated using this script on an H100 80G GPU.
We used black-forest-labs/FLUX.1-dev for the following experiment. Since this model is guidance-distilled, we applied only zero-init from our CFG-Zero* method. All images below were generated with the same seed on an H100 80G GPU.
python models/flux/Guidance_distilled.py
| CFG | CFG-Zero* |
|---|---|
| ![]() | ![]() |

Prompt: "a tiny astronaut hatching from an egg on the moon." Seed: 105297965
We used black-forest-labs/FLUX.1-dev with different LoRAs for the following experiment. Since this model is guidance-distilled, we applied only zero-init from our CFG-Zero* method. All images below were generated with the same seed on an H100 80G GPU.
python models/flux/infer_lora.py
| CFG | CFG-Zero* |
|---|---|
| ![]() | ![]() |

Prompt: "Death Stranding Style. A solitary figure in a futuristic suit with a large, intricate backpack stands on a grassy cliff, gazing at a vast, mist-covered landscape composed of rugged mountains and low valleys beneath a rainy, overcast sky. Raindrops streak softly through the air, and puddles glisten on the uneven ground. Above the horizon, an ethereal, upside-down rainbow arcs downward through the gray clouds — its surreal, inverted shape adding an otherworldly touch to the haunting scene. A soft glow from distant structures illuminates the depth of the valley, enhancing the mysterious atmosphere. The contrast between the rain-soaked greenery and jagged rocky terrain adds texture and detail, amplifying the sense of solitude, exploration, and the anticipation of unknown adventures beyond the horizon." Seed: 875187112 LoRA: https://civitai.com/models/46080/death-stranding
We used hunyuanvideo-community/HunyuanVideo for the following experiment. Since this model is guidance-distilled, we applied only zero-init from our CFG-Zero* method. All results below were generated with the same seed on an H100 80G GPU.
python models/hunyuan/t2v.py
| CFG | CFG-Zero* |
|---|---|
| ![]() | ![]() |

Prompt: "In an ornate, historical hall, a massive tidal wave peaks and begins to crash. A man is surfing, cinematic film shot in 35mm. High quality, high defination." Seed: 376559893
We used stabilityai/stable-diffusion-3.5-large for the following experiment. All images below were generated with the same seed on an H100 80G GPU.
python models/sd/infer.py
| CFG | CFG-Zero* |
|---|---|
| ![]() | ![]() |

Prompt: "A capybara holding a sign that reads Hello World" Seed: 811677707
Install dependencies for Qwen2.5-Omni
pip install git+https://github.com/huggingface/transformers@f742a644ca32e65758c3adb36225aef1731bd2a8
pip install qwen-omni-utils[decord]
pip install flash-attn --no-build-isolation
Easy inference with CFG-Zero*
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6 python models/Qwen2.5/infer.py
The following audio clips were generated by the script:
| CFG | CFG-Zero* |
|---|---|
| 🔊 Click to download | 🔊 Click to download |
Ghibli-Style Portrait Generation: the number of zero-init steps is set to 1 by default; feel free to try other values.
python models/easycontrol/infer.py
| Source Image | CFG | CFG-Zero* |
|---|---|---|
| ![]() | ![]() | ![]() |
python models/Cogview4/infer.py
Git clone the HiDream repo and replace its hidream_pipeline with ours ('models/HiDream/pipeline.py'). Then adjust 'zero_steps' according to the total number of inference steps (the 4% heuristic above is a good starting point).
cd HiDream-I1
python ./inference.py --model_type full
| CFG | CFG-Zero* |
|---|---|
| ![]() | ![]() |

Prompt: "A cat holding a sign that says \"Hi-Dreams.ai\"." Seed: 0
You can use the following snippet to easily apply our method to any flow-matching-based model.
import torch

def optimized_scale(positive_flat, negative_flat):
    # Dot product between the conditional and unconditional predictions
    dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
    # Squared norm of the unconditional prediction
    squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
    # st_star = v_cond^T * v_uncond / ||v_uncond||^2
    st_star = dot_product / squared_norm
    return st_star

# Get the velocity predictions
noise_pred_uncond, noise_pred_text = model(...)
batch_size = noise_pred_text.shape[0]
positive = noise_pred_text.view(batch_size, -1)
negative = noise_pred_uncond.view(batch_size, -1)

# Calculate the optimized scale
st_star = optimized_scale(positive, negative)
# Reshape for broadcasting
st_star = st_star.view(batch_size, *([1] * (len(noise_pred_text.shape) - 1)))

# Perform CFG-Zero* sampling
if sample_step == 0:
    # Zero-init: zero out the prediction of the first step
    # (to zero out the first N steps, use sample_step <= zero_init_steps instead)
    noise_pred = noise_pred_uncond * 0.
else:
    # Apply the optimized scale within classifier-free guidance
    noise_pred = noise_pred_uncond * st_star + \
        guidance_scale * (noise_pred_text - noise_pred_uncond * st_star)
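For convenience, the same logic can be wrapped into a single helper. Below is a minimal sketch; the cfg_zero_star name and signature are ours, not part of the released code:

```python
import torch

def cfg_zero_star(noise_pred_text: torch.Tensor,
                  noise_pred_uncond: torch.Tensor,
                  guidance_scale: float,
                  sample_step: int,
                  zero_init_steps: int = 0) -> torch.Tensor:
    """Combine conditional/unconditional velocity predictions with CFG-Zero*."""
    if sample_step <= zero_init_steps:
        # Zero-init: zero out the prediction for the first step(s)
        return torch.zeros_like(noise_pred_text)
    batch_size = noise_pred_text.shape[0]
    positive = noise_pred_text.view(batch_size, -1)
    negative = noise_pred_uncond.view(batch_size, -1)
    # st_star = <v_cond, v_uncond> / ||v_uncond||^2, computed per batch element
    st_star = (positive * negative).sum(dim=1, keepdim=True) / \
              (negative.pow(2).sum(dim=1, keepdim=True) + 1e-8)
    st_star = st_star.view(batch_size, *([1] * (noise_pred_text.dim() - 1)))
    return noise_pred_uncond * st_star + \
        guidance_scale * (noise_pred_text - noise_pred_uncond * st_star)
```

The returned prediction can then be passed to the scheduler step as usual.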
@misc{fan2025cfgzerostar,
title={CFG-Zero*: Improved Classifier-Free Guidance for Flow Matching Models},
author={Weichen Fan and Amber Yijia Zheng and Raymond A. Yeh and Ziwei Liu},
year={2025},
eprint={2503.18886},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2503.18886},
}
This code is licensed under Apache-2.0. The framework is fully open for academic research and also permits commercial use.
We disclaim responsibility for user-generated content. The model was not trained to realistically represent people or events, so using it to generate such content is beyond its capabilities. Generating pornographic, violent, or gory content is prohibited, as is generating content that demeans or harms people or their environment, culture, or religion. Users are solely liable for their actions. The project contributors are not legally affiliated with, nor accountable for, users' behavior. Use the generative model responsibly, adhering to ethical and legal standards.