ZImageControlNetPipeline does not support guidance_scale>1 #13073

@christopher5106

Describe the bug

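Enabling true CFG (guidance_scale > 1) with the ControlNet pipeline raises a shape mismatch: the padding mask has 13440 entries, twice the 6720 rows of the tensor it indexes.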

  File "tt.py", line 48, in <module>
    image = pipe(
            ^^^^^
  File "venv/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 124, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "venv/lib/python3.11/site-packages/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py", line 646, in __call__
    controlnet_block_samples = self.controlnet(
                               ^^^^^^^^^^^^^^^^
  File "/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "venv/lib/python3.11/site-packages/diffusers/models/controlnets/controlnet_z_image.py", line 705, in forward
    control_context[torch.cat(x_inner_pad_mask)] = self.x_pad_token
    ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
IndexError: The shape of the mask [13440] at index 0 does not match the shape of the indexed tensor [6720, 3840] at index 0
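
The two sizes in the error can be reproduced with simple arithmetic, which suggests the mask is built for the CFG-doubled batch while the control context covers a single sample. The sketch below is only a sanity check under assumptions not stated in this issue: a VAE downscale factor of 8, a 2x2 patch size, and sequences padded to a multiple of 32. The constant and function names are hypothetical, not diffusers API.

```python
import math

# Assumed values (not confirmed in this issue): VAE downscale factor,
# transformer patch size, and the multiple that sequences are padded to.
VAE_SCALE = 8
PATCH_SIZE = 2
SEQ_MULTIPLE = 32


def padded_image_tokens(height: int, width: int) -> int:
    """Image token count for one sample, rounded up to SEQ_MULTIPLE."""
    stride = VAE_SCALE * PATCH_SIZE  # pixels covered by one token per axis
    tokens = (height // stride) * (width // stride)
    return math.ceil(tokens / SEQ_MULTIPLE) * SEQ_MULTIPLE


# Resolution used in the reproduction script.
single = padded_image_tokens(1728, 992)  # one sample
doubled = 2 * single  # positive + negative prompt under true CFG

print(single, doubled)  # prints "6720 13440"
```

Under these assumptions, 6720 matches the first dimension of the indexed tensor and 13440 matches the mask length, consistent with only one of the two being duplicated for the negative prompt.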

Reproduction

Take the example at the beginning of the file and set guidance_scale to any value above 1, for example 2.0:

import torch
from diffusers import ZImageControlNetPipeline
from diffusers import ZImageControlNetModel
from diffusers.utils import load_image
from huggingface_hub import hf_hub_download

controlnet = ZImageControlNetModel.from_single_file(
    hf_hub_download(
        "alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union",
        filename="Z-Image-Turbo-Fun-Controlnet-Union.safetensors",
    ),
    torch_dtype=torch.bfloat16,
)

# 2.1
# controlnet = ZImageControlNetModel.from_single_file(
#     hf_hub_download(
#         "alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.0",
#         filename="Z-Image-Turbo-Fun-Controlnet-Union-2.1.safetensors",
#     ),
#     torch_dtype=torch.bfloat16,
# )

# 2.0
# controlnet = ZImageControlNetModel.from_single_file(
#     hf_hub_download(
#         "alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.0",
#         filename="Z-Image-Turbo-Fun-Controlnet-Union-2.0.safetensors",
#     ),
#     torch_dtype=torch.bfloat16,
# )

pipe = ZImageControlNetPipeline.from_pretrained(
    "Tongyi-MAI/Z-Image-Turbo", controlnet=controlnet, torch_dtype=torch.bfloat16
)
pipe.to("cuda")

# Optionally, set the attention backend to flash-attn 2 or 3, default is SDPA in PyTorch.
# (1) Use flash attention 2
# pipe.transformer.set_attention_backend("flash")
# (2) Use flash attention 3
# pipe.transformer.set_attention_backend("_flash_3")

control_image = load_image(
    "https://huggingface.co/alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union/resolve/main/asset/pose.jpg?download=true"
)
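prompt = "A young woman stands on a sunlit coastline, her white dress fluttering gently in the sea breeze. She has long, vivid purple hair that dances lightly in the wind, tied with a delicate black bow that contrasts sharply with the soft azure sky behind her. Her features are fine and delicate, with a sweet, youthful air; her expression is gentle and slightly shy, her gaze resting quietly on the distant horizon, her hands naturally folded in front of her, as if lost in thought. Behind her lies the vast, boundless, shimmering sea, where sunlight falls on the water and casts a warm golden glow."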
image = pipe(
    prompt,
    control_image=control_image,
    controlnet_conditioning_scale=0.75,
    height=1728,
    width=992,
    num_inference_steps=9,
    guidance_scale=2.0,
    generator=torch.Generator("cuda").manual_seed(43),
).images[0]
image.save("zimage.png")

System Info

No need here

Who can help?

@yiyixuxu @sayakpaul @DN6 @asomoza

Labels: bug
