44import torch
55import comfy .utils
66import folder_paths
7+ from typing_extensions import override
8+ from comfy_api .latest import ComfyExtension , io
79
810try :
911 from spandrel_extra_arches import EXTRA_REGISTRY
1315except :
1416 pass
1517
16- class UpscaleModelLoader :
18+ class UpscaleModelLoader ( io . ComfyNode ) :
1719 @classmethod
18- def INPUT_TYPES (s ):
19- return {"required" : { "model_name" : (folder_paths .get_filename_list ("upscale_models" ), ),
20- }}
21- RETURN_TYPES = ("UPSCALE_MODEL" ,)
22- FUNCTION = "load_model"
20+ def define_schema (cls ):
21+ return io .Schema (
22+ node_id = "UpscaleModelLoader" ,
23+ display_name = "Load Upscale Model" ,
24+ category = "loaders" ,
25+ inputs = [
26+ io .Combo .Input ("model_name" , options = folder_paths .get_filename_list ("upscale_models" )),
27+ ],
28+ outputs = [
29+ io .UpscaleModel .Output (),
30+ ],
31+ )
2332
24- CATEGORY = "loaders"
25-
26- def load_model (self , model_name ):
33+ @classmethod
34+ def execute (cls , model_name ) -> io .NodeOutput :
2735 model_path = folder_paths .get_full_path_or_raise ("upscale_models" , model_name )
2836 sd = comfy .utils .load_torch_file (model_path , safe_load = True )
2937 if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd :
@@ -33,21 +41,27 @@ def load_model(self, model_name):
3341 if not isinstance (out , ImageModelDescriptor ):
3442 raise Exception ("Upscale model must be a single-image model." )
3543
36- return (out , )
44+ return io . NodeOutput (out )
3745
3846
39- class ImageUpscaleWithModel :
47+ class ImageUpscaleWithModel ( io . ComfyNode ) :
4048 @classmethod
41- def INPUT_TYPES (s ):
42- return {"required" : { "upscale_model" : ("UPSCALE_MODEL" ,),
43- "image" : ("IMAGE" ,),
44- }}
45- RETURN_TYPES = ("IMAGE" ,)
46- FUNCTION = "upscale"
49+ def define_schema (cls ):
50+ return io .Schema (
51+ node_id = "ImageUpscaleWithModel" ,
52+ display_name = "Upscale Image (using Model)" ,
53+ category = "image/upscaling" ,
54+ inputs = [
55+ io .UpscaleModel .Input ("upscale_model" ),
56+ io .Image .Input ("image" ),
57+ ],
58+ outputs = [
59+ io .Image .Output (),
60+ ],
61+ )
4762
48- CATEGORY = "image/upscaling"
49-
50- def upscale (self , upscale_model , image ):
63+ @classmethod
64+ def execute (cls , upscale_model , image ) -> io .NodeOutput :
5165 device = model_management .get_torch_device ()
5266
5367 memory_required = model_management .module_size (upscale_model .model )
@@ -75,9 +89,17 @@ def upscale(self, upscale_model, image):
7589
7690 upscale_model .to ("cpu" )
7791 s = torch .clamp (s .movedim (- 3 ,- 1 ), min = 0 , max = 1.0 )
78- return (s ,)
92+ return io .NodeOutput (s )
93+
94+
95+ class UpscaleModelExtension (ComfyExtension ):
96+ @override
97+ async def get_node_list (self ) -> list [type [io .ComfyNode ]]:
98+ return [
99+ UpscaleModelLoader ,
100+ ImageUpscaleWithModel ,
101+ ]
102+
79103
80- NODE_CLASS_MAPPINGS = {
81- "UpscaleModelLoader" : UpscaleModelLoader ,
82- "ImageUpscaleWithModel" : ImageUpscaleWithModel
83- }
104+ async def comfy_entrypoint () -> UpscaleModelExtension :
105+ return UpscaleModelExtension ()
0 commit comments