
feat&fix(diffusers): add QwenImage lora finetune, new models and pipes and fix bugs #1394

Merged
vigo999 merged 21 commits into mindspore-lab:master from Dong1017:qwenimage_update on Dec 24, 2025
Conversation


@Dong1017 (Contributor) commented Oct 27, 2025

What does this PR do?

Adds

  1. A LoRA finetuning script for QwenImage, using the lambdalabs/pokemon-blip-captions dataset. A DreamBooth script is coming soon.
  2. A new pipeline, QwenImageEditPlusPipeline, targeting the upstream PR "feat: Add QwenImageEditPlus to support future feature upgrades".
  3. New models QwenImageControlNetModel and QwenImageMultiControlNetModel, plus pipelines QwenImageControlNetPipeline and QwenImageControlNetInpaintPipeline, targeting the upstream PRs "Support ControlNet for Qwen-Image" and "Support ControlNet-Inpainting for Qwen-Image".
  4. New modular blocks and pipelines: QwenImageAutoBlocks, QwenImageEditAutoBlocks, QwenImageEditPlusAutoBlocks, QwenImageModularPipeline, QwenImageEditModularPipeline, and QwenImageEditPlusModularPipeline, targeting the upstream PRs "[Modular] Qwen", "[modular] add tests for qwen modular", and "[core] support QwenImage Edit Plus in modular".

Fixes, ported from merged diffusers PRs

  1. [LoRA] feat: support more Qwen LoRAs from the community.
  2. [docs] Clarify guidance scale in Qwen pipelines
  3. [chore] add lora button to qwenimage docs
  4. Emergency fix for Qwen-Image-Edit
  5. Performance Improve for Qwen Image Edit
  6. add attentionmixin to qwen image
  7. [Qwen-Image] adding validation for guidance_scale, true_cfg_scale and negative_prompt (see the usage sketch after this list)
  8. [QwenImageEditPipeline] Add image entry in call function
  9. Fix lora conversion function for ai-toolkit Qwen Image LoRAs
  10. [refactor] Make guiders return their inputs
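
As a usage note for fix 7 above, Qwen pipelines distinguish the distilled guidance_scale from true classifier-free guidance (true_cfg_scale). Below is a minimal sketch of the validated call pattern; the model id and prompts are illustrative:

```python
import mindspore as ms
from mindone.diffusers import QwenImagePipeline

pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", dtype=ms.bfloat16)
# true_cfg_scale > 1.0 turns on true CFG and requires a negative prompt;
# the added validation rejects inconsistent combinations of the three arguments.
image = pipe(
    prompt="a photo of a corgi riding a skateboard",
    negative_prompt=" ",
    true_cfg_scale=4.0,
    num_inference_steps=50,
)[0][0]
```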

Usage

  1. Finetune

```bash
export ASCEND_RT_VISIBLE_DEVICES=0,1
NPUS=2
MASTER_PORT=9000
LOG_DIR=outputs/lora
msrun --bind_core=True --worker_num=${NPUS} --local_worker_num=${NPUS} --master_port=${MASTER_PORT} --log_dir=${LOG_DIR}/parallel_logs \
python finetune_lora_with_mindspore_trainer.py \
    --output_dir ${LOG_DIR} \
    --num_train_epochs 1 \
    --learning_rate 1e-5 \
    --save_strategy no \
    --bf16
```

  2. Edit-Plus

```python
import mindspore as ms 
from PIL import Image 
from mindone.diffusers import QwenImageEditPlusPipeline
from mindone.diffusers.utils import load_image 

pipe = QwenImageEditPlusPipeline.from_pretrained("Qwen/Qwen-Image-Edit-2509", dtype=ms.bfloat16)
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png").convert("RGB") 
prompt = ("Make Pikachu hold a sign that says 'Qwen Edit is awesome', yarn art style, detailed, vibrant colors")
# Depending on the variant being used, the pipeline call will slightly vary. 
# Refer to the pipeline documentation for more details. 
image = pipe(image, prompt, num_inference_steps=50)[0][0] 
image.save("qwenimage_edit_plus.png") 
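
Community LoRAs, which the conversion fixes in this PR target, can then be attached to the pipeline. A minimal sketch, where the LoRA repo id is a placeholder:

```python
# load_lora_weights follows the usual diffusers LoRA-loader interface;
# "some-user/qwen-image-edit-lora" is a hypothetical repo id.
pipe.load_lora_weights("some-user/qwen-image-edit-lora")
image = pipe(image, prompt, num_inference_steps=50)[0][0]
```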

  3. ControlNet (multi-NPU)

```bash
#!/bin/bash
export ASCEND_RT_VISIBLE_DEVICES=0,1

# Distributed launch configuration
MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
MASTER_PORT=${MASTER_PORT:-$(shuf -i 20001-29999 -n 1)}
NPROC_PER_NODE=${WORLD_SIZE:-2}

entry_file="controlnet_script.py"

msrun --worker_num=${NPROC_PER_NODE} \
    --local_worker_num=${NPROC_PER_NODE} \
    --master_addr=${MASTER_ADDR} \
    --master_port=${MASTER_PORT} \
    --log_dir="logs/qwenimage_control" \
    --join=True \
    ${entry_file}
```

The controlnet_script.py referenced above is shown below in three variants.

- Single control image

```python
from functools import partial

import mindspore as ms

import mindspore.mint.distributed as dist
from mindspore.communication import GlobalComm
from mindone.trainers.zero import prepare_network

from mindone.diffusers.utils import load_image
from mindone.diffusers import QwenImageControlNetModel, QwenImageControlNetPipeline

dist.init_process_group()
ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.DATA_PARALLEL)
local_rank = dist.get_rank()

ms_dtype = ms.bfloat16
control_image = load_image("https://huggingface.co/InstantX/Qwen-Image-ControlNet-Union/resolve/main/conds/canny.png")
prompt = "Aesthetics art, traditional asian pagoda, elaborate golden accents, sky blue and white color palette, swirling cloud pattern, digital illustration, east asian architecture, ornamental rooftop, intricate detailing on building, cultural representation."
negative_prompt = " "

# single control image
controlnet = QwenImageControlNetModel.from_pretrained(
    "InstantX/Qwen-Image-ControlNet-Union",
    dtype=ms_dtype,
)
pipe = QwenImageControlNetPipeline.from_pretrained(
    "Qwen/Qwen-Image", 
    controlnet=controlnet, 
    dtype=ms_dtype
)

shard_fn = partial(prepare_network, zero_stage=3, optimizer_parallel_group=GlobalComm.WORLD_COMM_GROUP)
pipe.transformer = shard_fn(pipe.transformer)

dist.barrier()

image = pipe(
    prompt,
    negative_prompt=negative_prompt,
    control_image=control_image,
    controlnet_conditioning_scale=1.0,
    num_inference_steps=30,
    true_cfg_scale=4.0
)[0][0]

if local_rank == 0:
    image.save("qwenimage_cn_union.png")

- Multiple control images

```python
from functools import partial

import mindspore as ms

import mindspore.mint.distributed as dist
from mindspore.communication import GlobalComm
from mindone.trainers.zero import prepare_network

from mindone.diffusers.utils import load_image
from mindone.diffusers import QwenImageControlNetModel, QwenImageMultiControlNetModel, QwenImageControlNetPipeline

dist.init_process_group()
ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.DATA_PARALLEL)
local_rank = dist.get_rank()

ms_dtype = ms.bfloat16
control_image = load_image("https://huggingface.co/InstantX/Qwen-Image-ControlNet-Union/resolve/main/conds/canny.png")
prompt = "Aesthetics art, traditional asian pagoda, elaborate golden accents, sky blue and white color palette, swirling cloud pattern, digital illustration, east asian architecture, ornamental rooftop, intricate detailing on building, cultural representation."
negative_prompt = " "

# multiple control images
controlnet = QwenImageControlNetModel.from_pretrained(
    "InstantX/Qwen-Image-ControlNet-Union",
    dtype=ms_dtype,
)
controlnet = QwenImageMultiControlNetModel([controlnet])
pipe = QwenImageControlNetPipeline.from_pretrained(
    "Qwen/Qwen-Image", 
    controlnet=controlnet, 
    dtype=ms_dtype
)

shard_fn = partial(prepare_network, zero_stage=3, optimizer_parallel_group=GlobalComm.WORLD_COMM_GROUP)
pipe.transformer = shard_fn(pipe.transformer)

dist.barrier()

image = pipe(
    prompt,
    negative_prompt=negative_prompt,
    control_image=[control_image, control_image],
    controlnet_conditioning_scale=[0.5, 0.5],
    num_inference_steps=30,
    true_cfg_scale=4.0,
)[0][0]

if local_rank == 0:
    image.save("qwenimage_cn_union_multi.png")

- ControlNet inpainting

```python
from functools import partial

import mindspore as ms

import mindspore.mint.distributed as dist
from mindspore.communication import GlobalComm
from mindone.trainers.zero import prepare_network

from mindone.diffusers.utils import load_image
from mindone.diffusers import QwenImageControlNetModel, QwenImageControlNetInpaintPipeline

dist.init_process_group()
ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.DATA_PARALLEL)
local_rank = dist.get_rank()

ms_dtype = ms.bfloat16
control_image = load_image("https://huggingface.co/InstantX/Qwen-Image-ControlNet-Inpainting/resolve/main/assets/images/image1.png")
control_mask = load_image("https://huggingface.co/InstantX/Qwen-Image-ControlNet-Inpainting/resolve/main/assets/masks/mask1.png")
prompt = "一辆绿色的出租车行驶在路上"

controlnet = QwenImageControlNetModel.from_pretrained(
    "InstantX/Qwen-Image-ControlNet-Inpainting",
    dtype=ms_dtype,
)
pipe = QwenImageControlNetInpaintPipeline.from_pretrained(
    "Qwen/Qwen-Image", 
    controlnet=controlnet, 
    dtype=ms_dtype
)

shard_fn = partial(prepare_network, zero_stage=3, optimizer_parallel_group=GlobalComm.WORLD_COMM_GROUP)
pipe.transformer = shard_fn(pipe.transformer)

dist.barrier()

image = pipe(
    prompt=prompt,
    control_image=control_image,
    control_mask=control_mask,
    controlnet_conditioning_scale=1.0,
    width=control_mask.size[0],
    height=control_mask.size[1],
    true_cfg_scale=4.0,
)[0][0]

if local_rank == 0:
    image.save("qwenimage_cn_control_inpaint.png")

  4. Modular

```bash
#!/bin/bash
export ASCEND_RT_VISIBLE_DEVICES=0,1

# Distributed launch configuration
MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
MASTER_PORT=${MASTER_PORT:-$(shuf -i 20001-29999 -n 1)}
NPROC_PER_NODE=${WORLD_SIZE:-2}

entry_file="test_modular.py"

msrun --worker_num=${NPROC_PER_NODE} \
    --local_worker_num=${NPROC_PER_NODE} \
    --master_addr=${MASTER_ADDR} \
    --master_port=${MASTER_PORT} \
    --log_dir="logs/modular" \
    --join=True \
    ${entry_file}
```

The test_modular.py referenced above is shown below.

- QwenImageModularPipeline

```python
from functools import partial

import mindspore as ms
from mindone.diffusers.modular_pipelines import SequentialPipelineBlocks
from mindone.diffusers.modular_pipelines.qwenimage import TEXT2IMAGE_BLOCKS

import mindspore.mint.distributed as dist
from mindspore.communication import GlobalComm
from mindone.trainers.zero import prepare_network

dist.init_process_group()
ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.DATA_PARALLEL)
local_rank = dist.get_rank()

blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)

modular_repo_id = "YiYiXu/QwenImage-modular"
pipeline = blocks.init_pipeline(modular_repo_id)

pipeline.load_default_components(dtype=ms.bfloat16)

shard_fn = partial(prepare_network, zero_stage=3, optimizer_parallel_group=GlobalComm.WORLD_COMM_GROUP)
pipeline.transformer = shard_fn(pipeline.transformer)

dist.barrier()

image = pipeline(
    prompt="Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", 
    output="images"
)[0]

if local_rank == 0:
    image.save("modular_t2i.png")

Performance

(Inference experiments were run on Ascend Atlas 800T A2 machines with MindSpore 2.7.1; finetuning experiments were run on Ascend Atlas 800T A2 machines with MindSpore 2.7.0.)

  1. Finetune

| Type | Mode | Trainable ratio | Speed per step (s/it) |
| --- | --- | --- | --- |
| Finetune | pynative | 0.0577% | 67.34 |

  2. Edit-Plus

| Pipeline | Weight loading time | Mode | Speed (s/it) |
| --- | --- | --- | --- |
| QwenImageEditPlusPipeline | 87.94 s | pynative / jit | 8.41 / 8.57 |

  3. ControlNet (two NPUs)

| Pipeline | Weight loading time | Mode | Speed (s/it) |
| --- | --- | --- | --- |
| QwenImageControlNetPipeline | 58.86 s | pynative / jit | 9.36 / 9.17 |
| QwenImageControlNetInpaintPipeline | 58.32 s | pynative / jit | 6.34 / 6.28 |

  4. Modular (two NPUs)

| Pipeline | Weight loading time | Mode | Speed (s/it) |
| --- | --- | --- | --- |
| QwenImageModularPipeline | 57.78 s | pynative / jit | 8.31 / 8.98 |

Limitations

  1. MindSpore does not support gradient computation for complex numbers (PyTorch does). Fortunately, the rotary-embedding computation involves only non-learnable quantities and needs no gradients, so we bypass the issue by wrapping it in ms._no_grad() (see the first sketch after this list).
  2. compute_loss in mindone/transformers/trainer.py has not been implemented yet, so a loss function cannot simply be handed to the trainer for automatic computation. Instead, the script defines the loss computation manually in a TrainStepForQwenImage class.
  3. merge_and_unload in mindone/peft/tuners/lora/model.py is used to save the fine-tuned weights, but in practice it does not automatically merge weights that ZeRO-3 has sharded across devices. To address this, the fine-tuning script uses ops.AllGather() to reassemble the shards and save a complete set of weights (see the second sketch after this list).
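
To illustrate limitation 1, here is a minimal sketch (not the exact code from this PR) of computing rotary-embedding frequencies under ms._no_grad(); the name rope_freqs and its signature are illustrative:

```python
import mindspore as ms
from mindspore import ops

def rope_freqs(positions, dim, theta=10000.0):
    # The frequency table depends only on token positions and constants,
    # never on learnable parameters, so it is safe to compute it with
    # gradient tracking disabled, sidestepping MindSpore's lack of
    # complex-number gradients.
    with ms._no_grad():
        exponents = ops.arange(0, dim, 2, dtype=ms.float32) / dim
        inv_freq = 1.0 / ops.pow(theta, exponents)
        angles = ops.outer(positions.astype(ms.float32), inv_freq)
        freqs_cis = ops.polar(ops.ones_like(angles), angles)  # complex64
    return freqs_cis
```

And for limitation 3, a sketch of gathering ZeRO-3 weight shards before saving; save_merged_lora is a hypothetical helper, not the script's actual function:

```python
import mindspore as ms
from mindspore import ops

all_gather = ops.AllGather()

def save_merged_lora(network, path):
    # Under ZeRO-3 each rank owns a slice of every sharded parameter
    # (split along axis 0); AllGather concatenates the slices so a
    # complete checkpoint can be written.
    merged = []
    for param in network.trainable_params():
        full = all_gather(param.value())  # restores the full tensor
        merged.append({"name": param.name, "data": full})
    ms.save_checkpoint(merged, path)
```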

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline?
  • Did you make sure to update the documentation with your changes? E.g. record bug fixes or new features in What's New. Here are the documentation guidelines.
  • Did you build and run the code without any errors?
  • Did you report the running environment (NPU type/MS version) and performance in the doc? (better record it for data loading, model inference, or training tasks)
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@xxx

@Dong1017 requested a review from vigo999 as a code owner, October 27, 2025 07:46
@gemini-code-assist (Contributor) commented

Summary of Changes

Hello @Dong1017, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the Qwen-Image model's training capabilities by introducing a dedicated LoRA fine-tuning script. It addresses several underlying issues in LoRA weight conversion and rotary embedding application, ensuring smoother and more robust training workflows. Additionally, the integration of gradient checkpointing and DeepSpeed Zero3 support aims to improve memory efficiency and scalability for distributed training environments.

Highlights

  • QwenImage LoRA Finetuning Script: A new script has been added to enable LoRA fine-tuning for the Qwen-Image model using the MindSpore Trainer, supporting both single-device and multi-card distributed training.
  • LoRA Conversion Improvements: Bug fixes and enhancements have been made to the LoRA conversion utility to correctly handle various naming conventions (e.g., 'diffusion_model.' prefix) and existing Diffusers-format LoRA weights during loading.
  • Gradient Checkpointing for QwenImage Transformer: Gradient checkpointing has been implemented for the QwenImageTransformer2DModel to optimize memory usage during training, especially for larger models.
  • Rotary Embedding Training Support: The rotary embedding application in the QwenImage transformer has been adjusted to properly support training by wrapping the relevant operations in a ms._no_grad() context.
  • DeepSpeed Zero3 Configuration: A zero3.json configuration file has been added to facilitate DeepSpeed Zero3 optimization, enabling more efficient memory management for large-scale distributed training.
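
For reference, gradient checkpointing in diffusers-style models is toggled on the model itself; a minimal sketch, assuming mindone.diffusers mirrors the usual enable_gradient_checkpointing() entry point:

```python
import mindspore as ms
from mindone.diffusers import QwenImageTransformer2DModel

transformer = QwenImageTransformer2DModel.from_pretrained(
    "Qwen/Qwen-Image", subfolder="transformer", dtype=ms.bfloat16
)
# Recompute activations in the backward pass to trade compute for memory.
transformer.enable_gradient_checkpointing()
```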

@gemini-code-assist (bot) left a comment:
Code Review

This pull request introduces a LoRA fine-tuning script for the Qwen-Image model and includes several bug fixes to support this functionality. The new example script is comprehensive, but contains some hardcoded values that could be parameterized for better flexibility. The fixes in the LoRA loading utilities and the Qwen-Image transformer model are valuable, particularly the addition of gradient checkpointing support, which is crucial for training large models. Overall, this is a solid contribution that enhances the usability of the Qwen-Image model.

```python
encoder_hidden_states=encoder_hidden_states,
encoder_hidden_states_mask=encoder_hidden_states_mask,
timestep=timestep / 1000,
img_shapes=[(1, 32, 32)],
```
Severity: high

The img_shapes argument is hardcoded to [(1, 32, 32)]. This is likely correct for the default 512x512 input size, but it becomes wrong if the height or width arguments change. It should be calculated dynamically from the input dimensions: the latent shape is typically image_dim // vae_scale_factor // patch_size, i.e. image_dim // 16 in this case (512 // 8 // 2 = 32, giving (1, 32, 32)).

Suggested change:

```diff
-img_shapes=[(1, 32, 32)],
+img_shapes=[(1, self.args.height // 16, self.args.width // 16)],
```

Comment on lines 139 to 140:

```python
train_indices = list(range(666))
eval_indices = list(range(666, 833))
```
Severity: medium

The indices for the train/eval split are hardcoded. This limits the script's flexibility. It would be better to define these as variables or derive them from a split ratio to make it easier to adapt for different datasets or experiments.

```python
eval_indices = list(range(666, 833))

def process_function(examples):
    image = Image.open(io.BytesIO(examples["image"]["bytes"])).convert("RGB").resize((512, 512))
```
Severity: medium

The image resize dimensions (512, 512) are hardcoded. It's better to use data_args.height and data_args.width to allow for easy configuration of the image size.

Suggested change:

```diff
-    image = Image.open(io.BytesIO(examples["image"]["bytes"])).convert("RGB").resize((512, 512))
+    image = Image.open(io.BytesIO(examples["image"]["bytes"])).convert("RGB").resize((data_args.width, data_args.height))
```

```python
height=height,
width=width,
dtype=encoder_hidden_states.dtype,
generator=np.random.Generator(np.random.PCG64(seed=42)),
```
Severity: medium

The seed for the random number generator is hardcoded to 42. For better reproducibility and control, consider using the main script's seed from args.seed.

Suggested change:

```diff
-generator=np.random.Generator(np.random.PCG64(seed=42)),
+generator=np.random.Generator(np.random.PCG64(seed=args.seed)),
```

@Fzilan (Collaborator) left a comment:
Could you please add a README file for the LoRA finetune?

@Dong1017 changed the title from "feat&fix(diffusers): add QwenImage lora finetune and fix bugs" to "feat&fix(diffusers): add QwenImage lora finetune, new pipes and fix bugs" on Oct 28, 2025
@Dong1017 changed the title to "feat&fix(diffusers): add QwenImage lora finetune, new models and pipes and fix bugs" on Oct 29, 2025
@vigo999 added this pull request to the merge queue on Dec 24, 2025
Merged via the queue into mindspore-lab:master with commit eab5745 Dec 24, 2025
3 checks passed