|
| 1 | +import torch |
| 2 | + |
def project(v0, v1):
    """Decompose ``v0`` into components parallel and orthogonal to ``v1``.

    ``v1`` is first normalized to unit length over the last three
    dimensions (channel/height/width of a latent batch), so the parallel
    part is the projection of ``v0`` onto the direction of ``v1``.

    Returns:
        (parallel, orthogonal): tensors summing to ``v0``.
    """
    unit = torch.nn.functional.normalize(v1, dim=[-1, -2, -3])
    magnitude = (v0 * unit).sum(dim=[-1, -2, -3], keepdim=True)
    parallel = magnitude * unit
    return parallel, v0 - parallel
| 8 | + |
class APG:
    """Node implementing Adaptive Projected Guidance as a pre-CFG patch.

    The installed hook splits the CFG guidance vector (cond - uncond) into
    components parallel and orthogonal to the conditional prediction,
    rescales the parallel part by ``eta``, optionally clamps the guidance
    norm to ``norm_threshold`` and smooths guidance over steps with a
    ``momentum`` running average.
    """

    @classmethod
    def INPUT_TYPES(s):
        eta_spec = ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01, "tooltip": "Controls the scale of the parallel guidance vector. Default CFG behavior at a setting of 1."})
        threshold_spec = ("FLOAT", {"default": 5.0, "min": 0.0, "max": 50.0, "step": 0.1, "tooltip": "Normalize guidance vector to this value, normalization disable at a setting of 0."})
        momentum_spec = ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Controls a running average of guidance during diffusion, disabled at a setting of 0."})
        return {
            "required": {
                "model": ("MODEL",),
                "eta": eta_spec,
                "norm_threshold": threshold_spec,
                "momentum": momentum_spec,
            }
        }
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    CATEGORY = "sampling/custom_sampling"

    def patch(self, model, eta, norm_threshold, momentum):
        """Return a clone of ``model`` with the APG pre-CFG hook installed.

        Args:
            model: model object supporting ``clone()`` and
                ``set_model_sampler_pre_cfg_function`` (project type).
            eta: scale applied to the parallel guidance component.
            norm_threshold: clamp guidance L2 norm to this value (0 disables).
            momentum: running-average coefficient for guidance (0 disables).
        """
        # Mutable per-patch state shared by successive hook invocations:
        # the momentum accumulator and the previously seen sigma.
        state = {"avg": 0, "sigma": None}

        def pre_cfg_function(args):
            conds_out = args["conds_out"]
            # Nothing to guide against when only one cond is present.
            if len(conds_out) == 1:
                return conds_out

            cond, uncond = conds_out[0], conds_out[1]
            sigma = args["sigma"][0]
            cond_scale = args["cond_scale"]

            # A sigma increase means sampling restarted: drop the momentum
            # history so runs do not leak into each other.
            if state["sigma"] is not None and sigma > state["sigma"]:
                state["avg"] = 0
            state["sigma"] = sigma

            guidance = cond - uncond

            if momentum > 0:
                if torch.is_tensor(state["avg"]):
                    state["avg"] = momentum * state["avg"] + guidance
                else:
                    # First step after a reset: seed the average directly.
                    state["avg"] = guidance
                guidance = state["avg"]

            if norm_threshold > 0:
                norm = guidance.norm(p=2, dim=[-1, -2, -3], keepdim=True)
                # Shrink only when the norm exceeds the threshold; a zero
                # norm yields inf here, which minimum() clamps back to 1.
                guidance = guidance * torch.minimum(torch.ones_like(norm), norm_threshold / norm)

            # Project guidance onto the (normalized) conditional prediction
            # and rescale only the parallel component by eta.
            unit = torch.nn.functional.normalize(cond, dim=[-1, -2, -3])
            parallel = (guidance * unit).sum(dim=[-1, -2, -3], keepdim=True) * unit
            orthogonal = guidance - parallel
            shaped_guidance = orthogonal + eta * parallel

            # Recombine so the downstream CFG mix reproduces APG output.
            modified_cond = (uncond + shaped_guidance) + (cond - uncond) / cond_scale

            return [modified_cond, uncond] + conds_out[2:]

        m = model.clone()
        m.set_model_sampler_pre_cfg_function(pre_cfg_function)
        return (m,)
| 69 | + |
# Registration table mapping the internal node key "APG" to its class —
# presumably consumed by the host framework's node registry (TODO confirm).
NODE_CLASS_MAPPINGS = {
    "APG": APG,
}

# Maps the same node key to the human-readable title shown in the UI.
NODE_DISPLAY_NAME_MAPPINGS = {
    "APG": "Adaptive Projected Guidance",
}
0 commit comments