Skip to content

Commit fd57da0

Browse files
feat(nodes): tidy uno reference impl
1 parent 7140e8e commit fd57da0

File tree

3 files changed

+15
-30
lines changed

3 files changed

+15
-30
lines changed

invokeai/app/invocations/fields.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -283,11 +283,7 @@ class FluxReduxConditioningField(BaseModel):
283283
class FluxUnoReferenceField(BaseModel):
284284
"""A FLUX Uno image list primitive value"""
285285

286-
image_names: list[str] | None = Field(
287-
default=None,
288-
description="The name of the image associated with this conditioning tensor. This is used to store the image "
289-
"in the context.",
290-
)
286+
images: list[ImageField] = Field(description="The images to use as reference for FLUX Uno.")
291287

292288

293289
class FluxFillConditioningField(BaseModel):

invokeai/app/invocations/flux_denoise.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -113,9 +113,9 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
113113
description="FLUX Redux conditioning tensor.",
114114
input=Input.Connection,
115115
)
116-
uno_reference: FluxUnoReferenceField | None = InputField(
116+
uno_ref: FluxUnoReferenceField | None = InputField(
117117
default=None,
118-
description="FLUX Redux conditioning tensor.",
118+
description="FLUX Uno reference.",
119119
input=Input.Connection,
120120
)
121121
fill_conditioning: FluxFillConditioningField | None = InputField(
@@ -293,10 +293,9 @@ def _run_diffusion(
293293

294294
img_ids = generate_img_ids(h=latent_h, w=latent_w, batch_size=b, device=x.device, dtype=x.dtype)
295295

296-
is_flux_uno = self.uno_reference is not None
297-
if is_flux_uno:
296+
if self.uno_ref is not None:
298297
# Encode reference images and prepare position ids
299-
uno_ref_imgs = self._prep_uno_reference_imgs(context)
298+
uno_ref_imgs = self._prep_uno_reference_imgs(context=context)
300299
uno_ref_imgs, uno_ref_ids = prepare_multi_ip(x, uno_ref_imgs)
301300
else:
302301
uno_ref_imgs = None
@@ -680,20 +679,20 @@ def _prep_controlnet_extensions(
680679

681680
def _prep_uno_reference_imgs(self, context: InvocationContext) -> list[torch.Tensor]:
682681
# Load the conditioning image and resize it to the target image size.
683-
assert self.controlnet_vae is not None, 'Controlnet Vae must be set for UNO encoding'
684-
vae_info = context.models.load(self.controlnet_vae.vae)
685682

686-
assert self.uno_reference is not None, "Needs reference images for UNO"
683+
assert self.uno_ref is not None, "uno_ref must be set when using UNO."
684+
ref_img_names = [i.image_name for i in self.uno_ref.images]
685+
686+
assert self.controlnet_vae is not None, "Controlnet Vae must be set for UNO encoding"
687+
vae_info = context.models.load(self.controlnet_vae.vae)
687688

688-
ref_img_names: list[str] = self.uno_reference.image_names
689689
ref_latents: list[torch.Tensor] = []
690690

691691
# TODO: Maybe move reference side to UNO Node as parameter
692692
ref_long_side = 512 if len(ref_img_names) <= 1 else 320
693693

694694
for img_name in ref_img_names:
695-
image_pil = context.images.get_pil(img_name)
696-
image_pil = image_pil.convert("RGB") # To correct resizing
695+
image_pil = context.images.get_pil(img_name, mode="RGB")
697696
image_pil = preprocess_ref(image_pil, ref_long_side) # resize and crop
698697

699698
image_tensor = (TVF.to_tensor(image_pil) * 2.0 - 1.0).unsqueeze(0).float()

invokeai/app/invocations/flux_uno.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from typing import Optional
2-
31
from PIL import Image
42

53
from invokeai.app.invocations.baseinvocation import (
@@ -52,7 +50,7 @@ def preprocess_ref(raw_image: Image.Image, long_size: int = 512) -> Image.Image:
5250
class FluxUnoOutput(BaseInvocationOutput):
5351
"""The conditioning output of a FLUX Redux invocation."""
5452

55-
uno_refs: FluxUnoReferenceField = OutputField(description="Reference images container", title="Reference images")
53+
uno_ref: FluxUnoReferenceField = OutputField(description="Reference images container", title="Reference images")
5654

5755

5856
@invocation(
@@ -66,16 +64,8 @@ class FluxUnoOutput(BaseInvocationOutput):
6664
class FluxUnoInvocation(BaseInvocation):
6765
"""Loads a FLUX UNO reference images."""
6866

69-
image: ImageField = InputField(description="The UNO reference image.")
70-
image2: Optional[ImageField] = InputField(default=None, description="2nd reference")
71-
image3: Optional[ImageField] = InputField(default=None, description="3rd reference")
72-
image4: Optional[ImageField] = InputField(default=None, description="4th reference")
67+
images: list[ImageField] | None = InputField(default=None, description="The UNO reference images.")
7368

7469
def invoke(self, context: InvocationContext) -> FluxUnoOutput:
75-
images: list[str] = []
76-
for image in [self.image, self.image2, self.image3, self.image4]:
77-
if image is not None:
78-
image_pil = context.images.get_pil(image.image_name)
79-
images.append(context.images.save(image=image_pil).image_name)
80-
81-
return FluxUnoOutput(uno_refs=FluxUnoReferenceField(image_names=images))
70+
uno_ref = FluxUnoReferenceField(images=self.images or [])
71+
return FluxUnoOutput(uno_ref=uno_ref)

0 commit comments

Comments
 (0)