|
3 | 3 |
|
4 | 4 | #My modified one here is more basic but has less chances of breaking with ComfyUI updates. |
5 | 5 |
|
| 6 | +from typing_extensions import override |
| 7 | + |
6 | 8 | import comfy.model_patcher |
7 | 9 | import comfy.samplers |
| 10 | +from comfy_api.latest import ComfyExtension, io |
8 | 11 |
|
9 | | -class PerturbedAttentionGuidance: |
10 | | - @classmethod |
11 | | - def INPUT_TYPES(s): |
12 | | - return { |
13 | | - "required": { |
14 | | - "model": ("MODEL",), |
15 | | - "scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": 0.01}), |
16 | | - } |
17 | | - } |
18 | | - |
19 | | - RETURN_TYPES = ("MODEL",) |
20 | | - FUNCTION = "patch" |
21 | 12 |
|
22 | | - CATEGORY = "model_patches/unet" |
| 13 | +class PerturbedAttentionGuidance(io.ComfyNode): |
| 14 | + @classmethod |
| 15 | + def define_schema(cls): |
| 16 | + return io.Schema( |
| 17 | + node_id="PerturbedAttentionGuidance", |
| 18 | + category="model_patches/unet", |
| 19 | + inputs=[ |
| 20 | + io.Model.Input("model"), |
| 21 | + io.Float.Input("scale", default=3.0, min=0.0, max=100.0, step=0.01, round=0.01), |
| 22 | + ], |
| 23 | + outputs=[ |
| 24 | + io.Model.Output(), |
| 25 | + ], |
| 26 | + ) |
23 | 27 |
|
24 | | - def patch(self, model, scale): |
| 28 | + @classmethod |
| 29 | + def execute(cls, model, scale) -> io.NodeOutput: |
25 | 30 | unet_block = "middle" |
26 | 31 | unet_block_id = 0 |
27 | 32 | m = model.clone() |
@@ -49,8 +54,16 @@ def post_cfg_function(args): |
49 | 54 |
|
50 | 55 | m.set_model_sampler_post_cfg_function(post_cfg_function) |
51 | 56 |
|
52 | | - return (m,) |
| 57 | + return io.NodeOutput(m) |
| 58 | + |
| 59 | + |
| 60 | +class PAGExtension(ComfyExtension): |
| 61 | + @override |
| 62 | + async def get_node_list(self) -> list[type[io.ComfyNode]]: |
| 63 | + return [ |
| 64 | + PerturbedAttentionGuidance, |
| 65 | + ] |
| 66 | + |
53 | 67 |
|
54 | | -NODE_CLASS_MAPPINGS = { |
55 | | - "PerturbedAttentionGuidance": PerturbedAttentionGuidance, |
56 | | -} |
| 68 | +async def comfy_entrypoint() -> PAGExtension: |
| 69 | + return PAGExtension() |
0 commit comments