|
| 1 | +from typing_extensions import override |
| 2 | +from comfy_api.latest import ComfyExtension, io |
1 | 3 | from comfy_api.torch_helpers import set_torch_compile_wrapper |
2 | 4 |
|
3 | 5 |
|
4 | | -class TorchCompileModel: |
| 6 | +class TorchCompileModel(io.ComfyNode): |
5 | 7 | @classmethod |
6 | | - def INPUT_TYPES(s): |
7 | | - return {"required": { "model": ("MODEL",), |
8 | | - "backend": (["inductor", "cudagraphs"],), |
9 | | - }} |
10 | | - RETURN_TYPES = ("MODEL",) |
11 | | - FUNCTION = "patch" |
| 8 | + def define_schema(cls) -> io.Schema: |
| 9 | + return io.Schema( |
| 10 | + node_id="TorchCompileModel", |
| 11 | + category="_for_testing", |
| 12 | + inputs=[ |
| 13 | + io.Model.Input("model"), |
| 14 | + io.Combo.Input( |
| 15 | + "backend", |
| 16 | + options=["inductor", "cudagraphs"], |
| 17 | + ), |
| 18 | + ], |
| 19 | + outputs=[io.Model.Output()], |
| 20 | + is_experimental=True, |
| 21 | + ) |
12 | 22 |
|
13 | | - CATEGORY = "_for_testing" |
14 | | - EXPERIMENTAL = True |
15 | | - |
16 | | - def patch(self, model, backend): |
| 23 | + @classmethod |
| 24 | + def execute(cls, model, backend) -> io.NodeOutput: |
17 | 25 | m = model.clone() |
18 | 26 | set_torch_compile_wrapper(model=m, backend=backend) |
19 | | - return (m, ) |
| 27 | + return io.NodeOutput(m) |
| 28 | + |
| 29 | + |
| 30 | +class TorchCompileExtension(ComfyExtension): |
| 31 | + @override |
| 32 | + async def get_node_list(self) -> list[type[io.ComfyNode]]: |
| 33 | + return [ |
| 34 | + TorchCompileModel, |
| 35 | + ] |
| 36 | + |
20 | 37 |
|
21 | | -NODE_CLASS_MAPPINGS = { |
22 | | - "TorchCompileModel": TorchCompileModel, |
23 | | -} |
| 38 | +async def comfy_entrypoint() -> TorchCompileExtension: |
| 39 | + return TorchCompileExtension() |
0 commit comments