Skip to content

Commit fd2b820

Browse files
Add noise augmentation to hunyuan image refiner. (#9831)
This was missing and should help with colors being blown out.
1 parent d6b977b commit fd2b820

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

comfy/model_base.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1437,6 +1437,7 @@ class HunyuanImage21Refiner(HunyuanImage21):
14371437
def concat_cond(self, **kwargs):
14381438
noise = kwargs.get("noise", None)
14391439
image = kwargs.get("concat_latent_image", None)
1440+
noise_augmentation = kwargs.get("noise_augmentation", 0.0)
14401441
device = kwargs["device"]
14411442

14421443
if image is None:
@@ -1446,6 +1447,9 @@ def concat_cond(self, **kwargs):
14461447
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
14471448
image = self.process_latent_in(image)
14481449
image = utils.resize_to_batch_size(image, noise.shape[0])
1450+
if noise_augmentation > 0:
1451+
noise = torch.randn(image.shape, generator=torch.manual_seed(kwargs.get("seed", 0) - 10), dtype=image.dtype, device="cpu").to(image.device)
1452+
image = noise_augmentation * noise + (1.0 - noise_augmentation) * image
14491453
return image
14501454

14511455
def extra_conds(self, **kwargs):

comfy_extras/nodes_hunyuan.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -134,18 +134,18 @@ def INPUT_TYPES(s):
134134
return {"required": {"positive": ("CONDITIONING", ),
135135
"negative": ("CONDITIONING", ),
136136
"latent": ("LATENT", ),
137+
"noise_augmentation": ("FLOAT", {"default": 0.10, "min": 0.0, "max": 1.0, "step": 0.01}),
137138
}}
138139

139140
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
140141
RETURN_NAMES = ("positive", "negative", "latent")
141142

142143
FUNCTION = "execute"
143144

144-
def execute(self, positive, negative, latent):
145+
def execute(self, positive, negative, latent, noise_augmentation):
145146
latent = latent["samples"]
146-
147-
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": latent})
148-
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": latent})
147+
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": latent, "noise_augmentation": noise_augmentation})
148+
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": latent, "noise_augmentation": noise_augmentation})
149149
out_latent = {}
150150
out_latent["samples"] = torch.zeros([latent.shape[0], 32, latent.shape[-3], latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device())
151151
return (positive, negative, out_latent)

0 commit comments

Comments
 (0)