Skip to content

Commit 83bf8b7

Browse files
committed
Support loading CosXL engines.
1 parent 81432b6 commit 83bf8b7

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

tensorrt_loader.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,9 @@ class TensorRTLoader:
110110
@classmethod
111111
def INPUT_TYPES(s):
112112
return {"required": {"unet_name": (folder_paths.get_filename_list("tensorrt"), ),
113-
"model_type": (["sdxl_base", "sdxl_refiner", "sd1.x", "sd2.x-768v", "svd", "sd3", "auraflow"], ),
113+
"model_type": (["sdxl_base", "sdxl_refiner",
114+
"sd1.x", "sd2.x-768v", "svd",
115+
"sd3", "auraflow", "cosxl"], ),
114116
}}
115117
RETURN_TYPES = ("MODEL",)
116118
FUNCTION = "load_unet"
@@ -150,6 +152,10 @@ def load_unet(self, unet_name, model_type):
150152
conf = comfy.supported_models.AuraFlow({})
151153
conf.unet_config["disable_unet_model_creation"] = True
152154
model = conf.get_model({})
155+
elif model_type == "cosxl":
156+
conf = comfy.supported_models.SDXL({"adm_in_channels": 2816})
157+
conf.unet_config["disable_unet_model_creation"] = True
158+
model = comfy.model_base.SDXL(conf, model_type=comfy.model_base.ModelType.V_PREDICTION_EDM)
153159
model.diffusion_model = unet
154160
model.memory_required = lambda *args, **kwargs: 0 #always pass inputs batched up as much as possible, our TRT code will handle batch splitting
155161

@@ -159,4 +165,4 @@ def load_unet(self, unet_name, model_type):
159165

160166
NODE_CLASS_MAPPINGS = {
161167
"TensorRTLoader": TensorRTLoader,
162-
}
168+
}

0 commit comments

Comments
 (0)