Skip to content

Commit 0515511

Browse files
KosinkadinkEricBCoding
authored andcommitted
[Reviving comfyanonymous#5709] Add strength input to Differential Diffusion (comfyanonymous#9957)
* Update nodes_differential_diffusion.py * Update nodes_differential_diffusion.py * Make strength optional to avoid validation errors when loading old workflows, adjust step --------- Co-authored-by: ThereforeGames <eric@sparknight.io>
1 parent c730faf commit 0515511

File tree

1 file changed

+26
-7
lines changed

1 file changed

+26
-7
lines changed

comfy_extras/nodes_differential_diffusion.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,30 @@
55
class DifferentialDiffusion():
66
@classmethod
77
def INPUT_TYPES(s):
8-
return {"required": {"model": ("MODEL", ),
9-
}}
8+
return {
9+
"required": {
10+
"model": ("MODEL", ),
11+
},
12+
"optional": {
13+
"strength": ("FLOAT", {
14+
"default": 1.0,
15+
"min": 0.0,
16+
"max": 1.0,
17+
"step": 0.01,
18+
}),
19+
}
20+
}
1021
RETURN_TYPES = ("MODEL",)
1122
FUNCTION = "apply"
1223
CATEGORY = "_for_testing"
1324
INIT = False
1425

15-
def apply(self, model):
26+
def apply(self, model, strength=1.0):
1627
model = model.clone()
17-
model.set_model_denoise_mask_function(self.forward)
18-
return (model,)
28+
model.set_model_denoise_mask_function(lambda *args, **kwargs: self.forward(*args, **kwargs, strength=strength))
29+
return (model, )
1930

20-
def forward(self, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options: dict):
31+
def forward(self, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options: dict, strength: float):
2132
model = extra_options["model"]
2233
step_sigmas = extra_options["sigmas"]
2334
sigma_to = model.inner_model.model_sampling.sigma_min
@@ -31,7 +42,15 @@ def forward(self, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options
3142

3243
threshold = (current_ts - ts_to) / (ts_from - ts_to)
3344

34-
return (denoise_mask >= threshold).to(denoise_mask.dtype)
45+
# Generate the binary mask based on the threshold
46+
binary_mask = (denoise_mask >= threshold).to(denoise_mask.dtype)
47+
48+
# Blend binary mask with the original denoise_mask using strength
49+
if strength and strength < 1:
50+
blended_mask = strength * binary_mask + (1 - strength) * denoise_mask
51+
return blended_mask
52+
else:
53+
return binary_mask
3554

3655

3756
NODE_CLASS_MAPPINGS = {

0 commit comments

Comments
 (0)