Skip to content

Commit c356238

Browse files
committed
Support loading CosXL engines.
1 parent 0a01ee8 commit c356238

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

tensorrt_loader.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ class TensorRTLoader:
110110
@classmethod
111111
def INPUT_TYPES(s):
112112
return {"required": {"unet_name": (folder_paths.get_filename_list("tensorrt"), ),
113-
"model_type": (["sdxl_base", "sdxl_refiner", "sd1.x", "sd2.x-768v", "svd", "sd3"], ),
113+
"model_type": (["sdxl_base", "sdxl_refiner", "sd1.x", "sd2.x-768v", "svd", "sd3", "cosxl"], ),
114114
}}
115115
RETURN_TYPES = ("MODEL",)
116116
FUNCTION = "load_unet"
@@ -146,6 +146,10 @@ def load_unet(self, unet_name, model_type):
146146
conf = comfy.supported_models.SD3({})
147147
conf.unet_config["disable_unet_model_creation"] = True
148148
model = conf.get_model({})
149+
elif model_type == "cosxl":
150+
conf = comfy.supported_models.SDXL({"adm_in_channels": 2816})
151+
conf.unet_config["disable_unet_model_creation"] = True
152+
model = comfy.model_base.SDXL(conf, model_type=comfy.model_base.ModelType.V_PREDICTION_EDM)
149153
model.diffusion_model = unet
150154
model.memory_required = lambda *args, **kwargs: 0 #always pass inputs batched up as much as possible, our TRT code will handle batch splitting
151155

@@ -155,4 +159,4 @@ def load_unet(self, unet_name, model_type):
155159

156160
NODE_CLASS_MAPPINGS = {
157161
"TensorRTLoader": TensorRTLoader,
158-
}
162+
}

0 commit comments

Comments
 (0)