Skip to content

Commit

Permalink
Support loading controlnets with a different number of input latent channels.
Browse files: browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Sep 13, 2024
1 parent 6fb44c4 commit cf80d28
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
4 changes: 3 additions & 1 deletion comfy/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,9 @@ def load_controlnet_flux_instantx(sd):
if union_cnet in new_sd:
num_union_modes = new_sd[union_cnet].shape[0]

control_model = comfy.ldm.flux.controlnet.ControlNetFlux(latent_input=True, num_union_modes=num_union_modes, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
control_latent_channels = new_sd.get("pos_embed_input.weight").shape[1] // 4

control_model = comfy.ldm.flux.controlnet.ControlNetFlux(latent_input=True, num_union_modes=num_union_modes, control_latent_channels=control_latent_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
control_model = controlnet_load_state_dict(control_model, new_sd)

latent_format = comfy.latent_formats.Flux()
Expand Down
9 changes: 7 additions & 2 deletions comfy/ldm/flux/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def forward(self, x):


class ControlNetFlux(Flux):
def __init__(self, latent_input=False, num_union_modes=0, mistoline=False, image_model=None, dtype=None, device=None, operations=None, **kwargs):
def __init__(self, latent_input=False, num_union_modes=0, mistoline=False, control_latent_channels=None, image_model=None, dtype=None, device=None, operations=None, **kwargs):
super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)

self.main_model_double = 19
Expand Down Expand Up @@ -80,7 +80,12 @@ def __init__(self, latent_input=False, num_union_modes=0, mistoline=False, image

self.gradient_checkpointing = False
self.latent_input = latent_input
self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
if control_latent_channels is None:
control_latent_channels = self.in_channels
else:
control_latent_channels *= 2 * 2 #patch size

self.pos_embed_input = operations.Linear(control_latent_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
if not self.latent_input:
if self.mistoline:
self.input_cond_block = MistolineCondDownsamplBlock(dtype=dtype, device=device, operations=operations)
Expand Down

0 comments on commit cf80d28

Please sign in to comment.