@@ -110,7 +110,9 @@ class TensorRTLoader:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": {"unet_name": (folder_paths.get_filename_list("tensorrt"), ),
-                             "model_type": (["sdxl_base", "sdxl_refiner", "sd1.x", "sd2.x-768v", "svd", "sd3", "auraflow"], ),
+                             "model_type": (["sdxl_base", "sdxl_refiner",
+                                             "sd1.x", "sd2.x-768v", "svd",
+                                             "sd3", "auraflow", "cosxl"], ),
                              }}
     RETURN_TYPES = ("MODEL",)
     FUNCTION = "load_unet"
@@ -150,6 +152,10 @@ def load_unet(self, unet_name, model_type):
             conf = comfy.supported_models.AuraFlow({})
             conf.unet_config["disable_unet_model_creation"] = True
             model = conf.get_model({})
+        elif model_type == "cosxl":
+            conf = comfy.supported_models.SDXL({"adm_in_channels": 2816})
+            conf.unet_config["disable_unet_model_creation"] = True
+            model = comfy.model_base.SDXL(conf, model_type=comfy.model_base.ModelType.V_PREDICTION_EDM)
         model.diffusion_model = unet
         model.memory_required = lambda *args, **kwargs: 0 #always pass inputs batched up as much as possible, our TRT code will handle batch splitting
 
@@ -159,4 +165,4 @@ def load_unet(self, unet_name, model_type):
 
 NODE_CLASS_MAPPINGS = {
     "TensorRTLoader": TensorRTLoader,
-}
\ No newline at end of file
+}