Cleanup HunyuanDit controlnets.
Use the ControlNetApply SD3 and HunyuanDiT node.
comfyanonymous committed Aug 9, 2024
1 parent 06eb9fb commit a475ec2
Showing 4 changed files with 60 additions and 194 deletions.
145 changes: 42 additions & 103 deletions comfy/controlnet.py
@@ -1,4 +1,24 @@
"""
This file is part of ComfyUI.
Copyright (C) 2024 Comfy
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""


import torch
from enum import Enum
import math
import os
import logging
@@ -33,6 +53,10 @@ def broadcast_image_to(tensor, target_batch_size, batched_number):
else:
return torch.cat([tensor] * batched_number, dim=0)

class StrengthType(Enum):
CONSTANT = 1
LINEAR_UP = 2

class ControlBase:
def __init__(self, device=None):
self.cond_hint_original = None
@@ -51,6 +75,8 @@ def __init__(self, device=None):
device = comfy.model_management.get_torch_device()
self.device = device
self.previous_controlnet = None
self.extra_conds = []
self.strength_type = StrengthType.CONSTANT

def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None):
self.cond_hint_original = cond_hint
@@ -93,6 +119,8 @@ def copy_to(self, c):
c.latent_format = self.latent_format
c.extra_args = self.extra_args.copy()
c.vae = self.vae
c.extra_conds = self.extra_conds.copy()
c.strength_type = self.strength_type

def inference_memory_requirements(self, dtype):
if self.previous_controlnet is not None:
@@ -113,7 +141,10 @@ def control_merge(self, control, control_prev, output_dtype):

if x not in applied_to: #memory saving strategy, allow shared tensors and only apply strength to shared tensors once
applied_to.add(x)
x *= self.strength
if self.strength_type == StrengthType.CONSTANT:
x *= self.strength
elif self.strength_type == StrengthType.LINEAR_UP:
x *= (self.strength ** float(len(control_output) - i))

if x.dtype != output_dtype:
x = x.to(output_dtype)
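The StrengthType branch above is the behavioural change in control_merge. A minimal standalone sketch of the two modes (values hypothetical, not from the diff):

    strength = 0.9
    num_outputs = 19  # the HunyuanDiT controlnet emits one tensor per block

    # CONSTANT: every layer output is scaled by the same factor
    constant = [strength] * num_outputs

    # LINEAR_UP: output i is scaled by strength ** (num_outputs - i), so the
    # first block gets 0.9**19 ≈ 0.135 and the last block gets 0.9**1 = 0.9
    linear_up = [strength ** float(num_outputs - i) for i in range(num_outputs)]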
@@ -142,7 +173,7 @@ def set_extra_arg(self, argument, value=None):


class ControlNet(ControlBase):
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None):
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=[], strength_type=StrengthType.CONSTANT):
super().__init__(device)
self.control_model = control_model
self.load_device = load_device
@@ -154,6 +185,8 @@ def __init__(self, control_model=None, global_average_pooling=False, compression
self.model_sampling_current = None
self.manual_cast_dtype = manual_cast_dtype
self.latent_format = latent_format
self.extra_conds += extra_conds
self.strength_type = strength_type

def get_control(self, x_noisy, t, cond, batched_number):
control_prev = None
@@ -192,7 +225,7 @@ def get_control(self, x_noisy, t, cond, batched_number):

context = cond.get('crossattn_controlnet', cond['c_crossattn'])
extra = self.extra_args.copy()
for c in ["y", "guidance"]: #TODO
for c in self.extra_conds:
temp = cond.get(c, None)
if temp is not None:
extra[c] = temp.to(dtype)
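The hardcoded ["y", "guidance"] list is replaced by the per-instance extra_conds attribute. Note that ControlBase initializes extra_conds to an empty list, so controlnets that relied on "y" being forwarded unconditionally (SDXL-class models, for instance) must now declare it explicitly; this is likely what the two comments at the bottom of this page run into. A minimal sketch, assuming the loader variables that surround such a call site:

    control = ControlNet(control_model,
                         load_device=load_device,
                         manual_cast_dtype=manual_cast_dtype,
                         extra_conds=["y", "guidance"])  # restores the old forwarding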
@@ -382,116 +415,22 @@ def load_controlnet_mmdit(sd):
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
return control

class ControlNetWarperHunyuanDiT(ControlNet):
def get_control(self, x_noisy, t, cond, batched_number):
control_prev = None
if self.previous_controlnet is not None:
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)

if self.timestep_range is not None:
if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
if control_prev is not None:
return control_prev
else:
return None
def load_controlnet_hunyuandit(controlnet_data):
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(controlnet_data)

dtype = self.control_model.dtype
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype

output_dtype = x_noisy.dtype
if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
if self.cond_hint is not None:
del self.cond_hint
self.cond_hint = None
compression_ratio = self.compression_ratio
if self.vae is not None:
compression_ratio *= self.vae.downscale_ratio
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
if self.vae is not None:
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
self.cond_hint = self.vae.encode(self.cond_hint.movedim(1, -1))
comfy.model_management.load_models_gpu(loaded_models)
if self.latent_format is not None:
self.cond_hint = self.latent_format.process_in(self.cond_hint)
self.cond_hint = self.cond_hint.to(device=self.device, dtype=dtype)
if x_noisy.shape[0] != self.cond_hint.shape[0]:
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)

def get_tensor(name):
if name in cond:
if isinstance(cond[name], torch.Tensor):
return cond[name].to(dtype)
else:
return cond[name]
else:
return None

encoder_hidden_states = get_tensor('c_crossattn')
text_embedding_mask = get_tensor('text_embedding_mask')
encoder_hidden_states_t5 = get_tensor('encoder_hidden_states_t5')
text_embedding_mask_t5 = get_tensor('text_embedding_mask_t5')
image_meta_size = get_tensor('image_meta_size')
style = get_tensor('style')
cos_cis_img = get_tensor('cos_cis_img')
sin_cis_img = get_tensor('sin_cis_img')

timestep = self.model_sampling_current.timestep(t)
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)

control = self.control_model(
x=x_noisy.to(dtype),
t=timestep.float(),
condition=self.cond_hint,
encoder_hidden_states=encoder_hidden_states,
text_embedding_mask=text_embedding_mask,
encoder_hidden_states_t5=encoder_hidden_states_t5,
text_embedding_mask_t5=text_embedding_mask_t5,
image_meta_size=image_meta_size,
style=style,
cos_cis_img=cos_cis_img,
sin_cis_img=sin_cis_img,
**self.extra_args
)
return self.control_merge(control, control_prev, output_dtype)

def copy(self):
c = ControlNetWarperHunyuanDiT(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
c.control_model = self.control_model
c.control_model_wrapped = self.control_model_wrapped
self.copy_to(c)
return c

def load_controlnet_hunyuandit(controlnet_data):

supported_inference_dtypes = [torch.float16, torch.float32]

unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
load_device = comfy.model_management.get_torch_device()
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
if manual_cast_dtype is not None:
operations = comfy.ops.manual_cast
else:
operations = comfy.ops.disable_weight_init

control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=load_device, dtype=unet_dtype)
missing, unexpected = control_model.load_state_dict(controlnet_data)

if len(missing) > 0:
logging.warning("missing controlnet keys: {}".format(missing))

if len(unexpected) > 0:
logging.debug("unexpected controlnet keys: {}".format(unexpected))
control_model = controlnet_load_state_dict(control_model, controlnet_data)

latent_format = comfy.latent_formats.SDXL()
control = ControlNetWarperHunyuanDiT(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
extra_conds = ['text_embedding_mask', 'encoder_hidden_states_t5', 'text_embedding_mask_t5', 'image_meta_size', 'style', 'cos_cis_img', 'sin_cis_img']
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.LINEAR_UP)
return control

def load_controlnet(ckpt_path, model=None):
controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT
return load_controlnet_hunyuandit(controlnet_data)

if "lora_controlnet" in controlnet_data:
return ControlLora(controlnet_data)
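HunyuanDiT checkpoints are detected by a characteristic state-dict key and routed through the generic ControlNet class rather than a dedicated wrapper. A usage sketch (the file path is hypothetical):

    import comfy.controlnet

    cn = comfy.controlnet.load_controlnet("models/controlnet/hunyuandit_canny.safetensors")
    # For a HunyuanDiT checkpoint this returns a ControlNet configured with
    # strength_type=StrengthType.LINEAR_UP and the seven extra_conds listed above.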

53 changes: 13 additions & 40 deletions comfy/ldm/hydit/controlnet.py
@@ -16,28 +16,11 @@
from .poolers import AttentionPool

import comfy.latent_formats
from .models import HunYuanDiTBlock
from .models import HunYuanDiTBlock, calc_rope

from .posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop


def zero_module(module):
for p in module.parameters():
nn.init.zeros_(p)
return module


def calc_rope(x, patch_size, head_size):
th = (x.shape[2] + (patch_size // 2)) // patch_size
tw = (x.shape[3] + (patch_size // 2)) // patch_size
base_size = 512 // 8 // patch_size
start, stop = get_fill_resize_and_crop((th, tw), base_size)
sub_args = [start, stop, (th, tw)]
# head_size = HUNYUAN_DIT_CONFIG['DiT-g/2']['hidden_size'] // HUNYUAN_DIT_CONFIG['DiT-g/2']['num_heads']
rope = get_2d_rotary_pos_embed(head_size, *sub_args)
return rope


class HunYuanControlNet(nn.Module):
"""
HunYuanDiT: Diffusion model with a Transformer backbone.
@@ -213,35 +196,32 @@ def __init__(
)

# Input zero linear for the first block
self.before_proj = zero_module(
nn.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
)
self.before_proj = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)


# Output zero linear for the every block
self.after_proj_list = nn.ModuleList(
[
zero_module(
nn.Linear(

operations.Linear(
self.hidden_size, self.hidden_size, dtype=dtype, device=device
)
)
for _ in range(len(self.blocks))
]
)

def forward(
self,
x: torch.Tensor,
t: torch.Tensor = None,
condition=None,
encoder_hidden_states: Optional[torch.Tensor] = None,
x,
hint,
timesteps,
context,#encoder_hidden_states=None,
text_embedding_mask=None,
encoder_hidden_states_t5=None,
text_embedding_mask_t5=None,
image_meta_size=None,
style=None,
control_weight=1.0,
transformer_options=None,
return_dict=False,
**kwarg,
):
"""
@@ -270,10 +250,11 @@ def forward(
return_dict: bool
Whether to return a dictionary.
"""
condition = hint
if condition.shape[0] == 1:
condition = torch.repeat_interleave(condition, x.shape[0], dim=0)

text_states = encoder_hidden_states # 2,77,1024
text_states = context # 2,77,1024
text_states_t5 = encoder_hidden_states_t5 # 2,256,2048
text_states_mask = text_embedding_mask.bool() # 2,77
text_states_t5_mask = text_embedding_mask_t5.bool() # 2,256
@@ -304,7 +285,7 @@ def forward(
) # (cos_cis_img, sin_cis_img)

# ========================= Build time and image embedding =========================
t = self.t_embedder(t, dtype=self.dtype)
t = self.t_embedder(timesteps, dtype=self.dtype)
x = self.x_embedder(x)

# ========================= Concatenate all extra vectors =========================
@@ -337,12 +318,4 @@ def forward(
x = block(x, c, text_states, freqs_cis_img)
controls.append(self.after_proj_list[layer](x)) # zero linear for output

control_weights = [1.0 * (control_weight ** float(19 - i)) for i in range(19)]
assert len(control_weights) == len(
controls
), "control_weights and controls should have the same length"
controls = [
control * weight for control, weight in zip(controls, control_weights)
]

return {"output": controls}
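The control_weights scaling deleted here is not lost: it moved into ControlBase.control_merge as StrengthType.LINEAR_UP (see comfy/controlnet.py above), with strength taking over the role of control_weight. A quick equivalence check in plain Python:

    weight = 0.9
    old = [1.0 * (weight ** float(19 - i)) for i in range(19)]      # removed code
    new = [weight ** float(len(old) - i) for i in range(len(old))]  # control_merge form
    assert old == new  # identical per-layer scaling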
51 changes: 0 additions & 51 deletions comfy_extras/nodes_hunyuan.py
@@ -19,58 +19,7 @@ def encode(self, clip, bert, mt5xl):
cond = output.pop("cond")
return ([[cond, output]], )


class ControlNetApplyAdvancedHunYuan:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"control_net": ("CONTROL_NET", ),
"image": ("IMAGE", ),
"vae": ("VAE", ),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"control_weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 2.0, "step": 0.001}),
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
}}

RETURN_TYPES = ("CONDITIONING","CONDITIONING")
RETURN_NAMES = ("positive", "negative")
FUNCTION = "apply_controlnet"

CATEGORY = "conditioning/controlnet"

def apply_controlnet(self, positive, negative, control_net, image, strength, control_weight, start_percent, end_percent, vae=None):
if strength == 0:
return (positive, negative)

control_hint = image.movedim(-1,1)
cnets = {}

out = []
for conditioning in [positive, negative]:
c = []
for t in conditioning:
d = t[1].copy()

prev_cnet = d.get('control', None)
if prev_cnet in cnets:
c_net = cnets[prev_cnet]
else:
c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent), vae)
c_net.set_extra_arg('control_weight', control_weight)

c_net.set_previous_controlnet(prev_cnet)
cnets[prev_cnet] = c_net

d['control'] = c_net
d['control_apply_to_uncond'] = False
n = [t[0], d]
c.append(n)
out.append(c)
return (out[0], out[1])

NODE_CLASS_MAPPINGS = {
"CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT,
"ControlNetApplyAdvancedHunYuan": ControlNetApplyAdvancedHunYuan,
}
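Workflows that used ControlNetApplyAdvancedHunYuan should switch to ControlNetApplySD3, displayed as "ControlNetApply SD3 and HunyuanDiT" (see the next file). The dedicated control_weight input is gone; since the old defaults applied strength * control_weight ** (19 - i) per layer and LINEAR_UP now applies strength ** (19 - i), an old setup maps onto the new node roughly as follows (input names taken from the removed node above, mapping assumed):

    # old node: strength=1.0, control_weight=0.8  ->  per-layer scale 0.8 ** (19 - i)
    # new node: strength=0.8                      ->  per-layer scale 0.8 ** (19 - i)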
5 changes: 5 additions & 0 deletions comfy_extras/nodes_sd3.py
@@ -100,3 +100,8 @@ def INPUT_TYPES(s):
"CLIPTextEncodeSD3": CLIPTextEncodeSD3,
"ControlNetApplySD3": ControlNetApplySD3,
}

NODE_DISPLAY_NAME_MAPPINGS = {
# Sampling
"ControlNetApplySD3": "ControlNetApply SD3 and HunyuanDiT",
}

2 comments on commit a475ec2

@YuanGYao commented on a475ec2 Aug 9, 2024


This commit causes SDXL ControlNet to stop working properly; checking out the previous commit makes ControlNet work again.

The console reported the following error message:

!!! Exception during processing!!! 'NoneType' object has no attribute 'shape'
Traceback (most recent call last):
  File "D:\sd\gui\ComfyUI_windows_portable\ComfyUI\execution.py", line 152, in recursive_execute
    output_data, output_ui = get_output_data(obj, input_data_all)
                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\sd\gui\ComfyUI_windows_portable\ComfyUI\execution.py", line 82, in get_output_data
    return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\sd\gui\ComfyUI_windows_portable\ComfyUI\execution.py", line 75, in map_node_over_list
    results.append(getattr(obj, func)(**slice_dict(input_data_all, i)))
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\sd\gui\ComfyUI_windows_portable\ComfyUI\nodes.py", line 1382, in sample
    return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\sd\gui\ComfyUI_windows_portable\ComfyUI\nodes.py", line 1352, in common_ksampler
    samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\sd\gui\ComfyUI_windows_portable\ComfyUI\comfy\sample.py", line 43, in sample
    samples = sampler.sample(noise, positive, negative, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\sd\gui\ComfyUI_windows_portable\ComfyUI\comfy\samplers.py", line 829, in sample
    return sample(self.model, noise, positive, negative, cfg, self.device, sampler, sigmas, self.model_options, latent_image=latent_image, denoise_mask=denoise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\sd\gui\ComfyUI_windows_portable\ComfyUI\comfy\samplers.py", line 729, in sample
    return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\sd\gui\ComfyUI_windows_portable\ComfyUI\comfy\samplers.py", line 716, in sample
    output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\sd\gui\ComfyUI_windows_portable\ComfyUI\comfy\samplers.py", line 695, in inner_sample
    samples = sampler.sample(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\sd\gui\ComfyUI_windows_portable\ComfyUI\comfy\samplers.py", line 600, in sample
    samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\sd\gui\ComfyUI_windows_portable\python_embeded\Lib\site-packages\torch\utils\_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "D:\sd\gui\ComfyUI_windows_portable\ComfyUI\comfy\k_diffusion\sampling.py", line 600, in sample_dpmpp_2m
    denoised = model(x, sigmas[i] * s_in, **extra_args)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\sd\gui\ComfyUI_windows_portable\ComfyUI\comfy\samplers.py", line 299, in __call__
    out = self.inner_model(x, sigma, model_options=model_options, seed=seed)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\sd\gui\ComfyUI_windows_portable\ComfyUI\comfy\samplers.py", line 682, in __call__
    return self.predict_noise(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\sd\gui\ComfyUI_windows_portable\ComfyUI\comfy\samplers.py", line 685, in predict_noise
    return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\sd\gui\ComfyUI_windows_portable\ComfyUI\comfy\samplers.py", line 279, in sampling_function
    out = calc_cond_batch(model, conds, x, timestep, model_options)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\sd\gui\ComfyUI_windows_portable\ComfyUI\comfy\samplers.py", line 202, in calc_cond_batch
    c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond))
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\sd\gui\ComfyUI_windows_portable\ComfyUI\comfy\controlnet.py", line 236, in get_control
    control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\sd\gui\ComfyUI_windows_portable\python_embeded\Lib\site-packages\torch\nn\modules\module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\sd\gui\ComfyUI_windows_portable\python_embeded\Lib\site-packages\torch\nn\modules\module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\sd\gui\ComfyUI_windows_portable\ComfyUI\comfy\cldm\cldm.py", line 420, in forward
    assert y.shape[0] == x.shape[0]
           ^^^^^^^
AttributeError: 'NoneType' object has no attribute 'shape'

@Michoko92


I second that; since this commit, I get the same errors as in the comment above.
