Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions comfy/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,8 @@ def __init__(self, model):
self.inner_model = model
def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, model_options={}, seed=None):
if denoise_mask is not None:
if "denoise_mask_function" in model_options:
denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask)
latent_mask = 1. - denoise_mask
x = x * denoise_mask + (self.latent_image + self.noise * sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1))) * latent_mask
out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, model_options=model_options, seed=seed)
Expand Down
97 changes: 97 additions & 0 deletions comfy_extras/nodes_differential_diffusion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# code adapted from https://github.com/exx8/differential-diffusion

import torch
import inspect

class DifferentialDiffusion():
@classmethod
def INPUT_TYPES(s):
return {"required": {"model": ("MODEL", ),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "apply"
CATEGORY = "_for_testing"
INIT = False

@classmethod
def IS_CHANGED(s, *args, **kwargs):
DifferentialDiffusion.INIT = s.INIT = True
return ""

def __init__(self) -> None:
DifferentialDiffusion.INIT = False
self.sigmas: torch.Tensor = None
self.thresholds: torch.Tensor = None
self.mask_i = None
self.valid_sigmas = False
self.varying_sigmas_samplers = ["dpmpp_2s", "dpmpp_sde", "dpm_2", "heun", "restart"]

def apply(self, model):
model = model.clone()
model.model_options["denoise_mask_function"] = self.forward
return (model,)

def init_sigmas(self, sigma: torch.Tensor, denoise_mask: torch.Tensor):
self.__init__()
self.sigmas, sampler = find_outer_instance("sigmas", callback=get_sigmas_and_sampler) or (None, "")
self.valid_sigmas = not ("sample_" not in sampler or any(s in sampler for s in self.varying_sigmas_samplers)) or "generic" in sampler
if self.sigmas is None:
self.sigmas = sigma[:1].repeat(2)
self.sigmas[-1].zero_()
self.sigmas_min = self.sigmas.min()
self.sigmas_max = self.sigmas.max()
self.thresholds = torch.linspace(1, 0, self.sigmas.shape[0], dtype=sigma.dtype, device=sigma.device)
self.thresholds_min_len = self.thresholds.shape[0] - 1
if self.valid_sigmas:
thresholds = self.thresholds[:-1].reshape(-1, 1, 1, 1, 1)
mask = denoise_mask.unsqueeze(0)
mask = (mask >= thresholds).to(denoise_mask.dtype)
self.mask_i = iter(mask)

def forward(self, sigma: torch.Tensor, denoise_mask: torch.Tensor):
if self.sigmas is None or DifferentialDiffusion.INIT:
self.init_sigmas(sigma, denoise_mask)
if self.valid_sigmas:
try:
return next(self.mask_i)
except StopIteration:
self.valid_sigmas = False
if self.thresholds_min_len > 1:
nearest_idx = (self.sigmas - sigma[0]).abs().argmin()
if not self.thresholds_min_len > nearest_idx:
nearest_idx = -2
threshold = self.thresholds[nearest_idx]
else:
threshold = (sigma[0] - self.sigmas_min) / (self.sigmas_max - self.sigmas_min)
return (denoise_mask >= threshold).to(denoise_mask.dtype)

def get_sigmas_and_sampler(frame, target):
found = frame.f_locals[target]
if isinstance(found, torch.Tensor) and found[-1] < 0.1:
return found, frame.f_code.co_name
return False

def find_outer_instance(target: str, target_type=None, callback=None):
frame = inspect.currentframe()
i = 0
while frame and i < 100:
if target in frame.f_locals:
if callback is not None:
res = callback(frame, target)
if res:
return res
else:
found = frame.f_locals[target]
if isinstance(found, target_type):
return found
frame = frame.f_back
i += 1
return None


NODE_CLASS_MAPPINGS = {
"DifferentialDiffusion": DifferentialDiffusion,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"DifferentialDiffusion": "Differential Diffusion",
}
1 change: 1 addition & 0 deletions nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1961,6 +1961,7 @@ def init_custom_nodes():
"nodes_photomaker.py",
"nodes_cond.py",
"nodes_stable_cascade.py",
"nodes_differential_diffusion.py",
]

for node_file in extras_files:
Expand Down