@@ -110,7 +110,7 @@ class TensorRTLoader:
110110 @classmethod
111111 def INPUT_TYPES (s ):
112112 return {"required" : {"unet_name" : (folder_paths .get_filename_list ("tensorrt" ), ),
113- "model_type" : (["sdxl_base" , "sdxl_refiner" , "sd1.x" , "sd2.x-768v" , "svd" , "sd3" ], ),
113+ "model_type" : (["sdxl_base" , "sdxl_refiner" , "sd1.x" , "sd2.x-768v" , "svd" , "sd3" , "cosxl" ], ),
114114 }}
115115 RETURN_TYPES = ("MODEL" ,)
116116 FUNCTION = "load_unet"
@@ -146,6 +146,10 @@ def load_unet(self, unet_name, model_type):
146146 conf = comfy .supported_models .SD3 ({})
147147 conf .unet_config ["disable_unet_model_creation" ] = True
148148 model = conf .get_model ({})
149+ elif model_type == "cosxl" :
150+ conf = comfy .supported_models .SDXL ({"adm_in_channels" : 2816 })
151+ conf .unet_config ["disable_unet_model_creation" ] = True
152+ model = comfy .model_base .SDXL (conf , model_type = comfy .model_base .ModelType .V_PREDICTION_EDM )
149153 model .diffusion_model = unet
150154 model .memory_required = lambda * args , ** kwargs : 0 #always pass inputs batched up as much as possible, our TRT code will handle batch splitting
151155
@@ -155,4 +159,4 @@ def load_unet(self, unet_name, model_type):
155159
156160NODE_CLASS_MAPPINGS = {
157161 "TensorRTLoader" : TensorRTLoader ,
158- }
162+ }