Skip to content

Commit 4136502

Browse files
implement APG guidance (Comfy-Org#8081)
* first pass at impementing AGP * rename, cleanup code * fix ruff * fix modified cond to match ref impl better, support different cond arity
1 parent 9ad287f commit 4136502

File tree

2 files changed

+77
-0
lines changed

2 files changed

+77
-0
lines changed

comfy_extras/nodes_apg.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
import torch
2+
3+
def project(v0, v1):
4+
v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3])
5+
v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1
6+
v0_orthogonal = v0 - v0_parallel
7+
return v0_parallel, v0_orthogonal
8+
9+
class APG:
10+
@classmethod
11+
def INPUT_TYPES(s):
12+
return {
13+
"required": {
14+
"model": ("MODEL",),
15+
"eta": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01, "tooltip": "Controls the scale of the parallel guidance vector. Default CFG behavior at a setting of 1."}),
16+
"norm_threshold": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 50.0, "step": 0.1, "tooltip": "Normalize guidance vector to this value, normalization disable at a setting of 0."}),
17+
"momentum": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip":"Controls a running average of guidance during diffusion, disabled at a setting of 0."}),
18+
}
19+
}
20+
RETURN_TYPES = ("MODEL",)
21+
FUNCTION = "patch"
22+
CATEGORY = "sampling/custom_sampling"
23+
24+
def patch(self, model, eta, norm_threshold, momentum):
25+
running_avg = 0
26+
prev_sigma = None
27+
28+
def pre_cfg_function(args):
29+
nonlocal running_avg, prev_sigma
30+
31+
if len(args["conds_out"]) == 1: return args["conds_out"]
32+
33+
cond = args["conds_out"][0]
34+
uncond = args["conds_out"][1]
35+
sigma = args["sigma"][0]
36+
cond_scale = args["cond_scale"]
37+
38+
if prev_sigma is not None and sigma > prev_sigma:
39+
running_avg = 0
40+
prev_sigma = sigma
41+
42+
guidance = cond - uncond
43+
44+
if momentum > 0:
45+
if not torch.is_tensor(running_avg):
46+
running_avg = guidance
47+
else:
48+
running_avg = momentum * running_avg + guidance
49+
guidance = running_avg
50+
51+
if norm_threshold > 0:
52+
guidance_norm = guidance.norm(p=2, dim=[-1, -2, -3], keepdim=True)
53+
scale = torch.minimum(
54+
torch.ones_like(guidance_norm),
55+
norm_threshold / guidance_norm
56+
)
57+
guidance = guidance * scale
58+
59+
guidance_parallel, guidance_orthogonal = project(guidance, cond)
60+
modified_guidance = guidance_orthogonal + eta * guidance_parallel
61+
62+
modified_cond = (uncond + modified_guidance) + (cond - uncond) / cond_scale
63+
64+
return [modified_cond, uncond] + args["conds_out"][2:]
65+
66+
m = model.clone()
67+
m.set_model_sampler_pre_cfg_function(pre_cfg_function)
68+
return (m,)
69+
70+
NODE_CLASS_MAPPINGS = {
71+
"APG": APG,
72+
}
73+
74+
NODE_DISPLAY_NAME_MAPPINGS = {
75+
"APG": "Adaptive Projected Guidance",
76+
}

nodes.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2261,6 +2261,7 @@ def init_builtin_extra_nodes():
22612261
"nodes_optimalsteps.py",
22622262
"nodes_hidream.py",
22632263
"nodes_fresca.py",
2264+
"nodes_apg.py",
22642265
"nodes_preview_any.py",
22652266
"nodes_ace.py",
22662267
"nodes_string.py",

0 commit comments

Comments
 (0)