@@ -113,9 +113,9 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
113
113
description = "FLUX Redux conditioning tensor." ,
114
114
input = Input .Connection ,
115
115
)
116
- uno_reference : FluxUnoReferenceField | None = InputField (
116
+ uno_ref : FluxUnoReferenceField | None = InputField (
117
117
default = None ,
118
- description = "FLUX Redux conditioning tensor ." ,
118
+ description = "FLUX Uno reference ." ,
119
119
input = Input .Connection ,
120
120
)
121
121
fill_conditioning : FluxFillConditioningField | None = InputField (
@@ -293,10 +293,9 @@ def _run_diffusion(
293
293
294
294
img_ids = generate_img_ids (h = latent_h , w = latent_w , batch_size = b , device = x .device , dtype = x .dtype )
295
295
296
- is_flux_uno = self .uno_reference is not None
297
- if is_flux_uno :
296
+ if self .uno_ref is not None :
298
297
# Encode reference images and prepare position ids
299
- uno_ref_imgs = self ._prep_uno_reference_imgs (context )
298
+ uno_ref_imgs = self ._prep_uno_reference_imgs (context = context )
300
299
uno_ref_imgs , uno_ref_ids = prepare_multi_ip (x , uno_ref_imgs )
301
300
else :
302
301
uno_ref_imgs = None
@@ -680,20 +679,20 @@ def _prep_controlnet_extensions(
680
679
681
680
def _prep_uno_reference_imgs (self , context : InvocationContext ) -> list [torch .Tensor ]:
682
681
# Load the conditioning image and resize it to the target image size.
683
- assert self .controlnet_vae is not None , 'Controlnet Vae must be set for UNO encoding'
684
- vae_info = context .models .load (self .controlnet_vae .vae )
685
682
686
- assert self .uno_reference is not None , "Needs reference images for UNO"
683
+ assert self .uno_ref is not None , "uno_ref must be set when using UNO."
684
+ ref_img_names = [i .image_name for i in self .uno_ref .images ]
685
+
686
+ assert self .controlnet_vae is not None , "Controlnet Vae must be set for UNO encoding"
687
+ vae_info = context .models .load (self .controlnet_vae .vae )
687
688
688
- ref_img_names : list [str ] = self .uno_reference .image_names
689
689
ref_latents : list [torch .Tensor ] = []
690
690
691
691
# TODO: Maybe move reference side to UNO Node as parameter
692
692
ref_long_side = 512 if len (ref_img_names ) <= 1 else 320
693
693
694
694
for img_name in ref_img_names :
695
- image_pil = context .images .get_pil (img_name )
696
- image_pil = image_pil .convert ("RGB" ) # To correct resizing
695
+ image_pil = context .images .get_pil (img_name , mode = "RGB" )
697
696
image_pil = preprocess_ref (image_pil , ref_long_side ) # resize and crop
698
697
699
698
image_tensor = (TVF .to_tensor (image_pil ) * 2.0 - 1.0 ).unsqueeze (0 ).float ()
0 commit comments