Skip to content

Commit e5db48f

Browse files
committed
Support loading CosXL engines.
1 parent 7504509 commit e5db48f

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

tensorrt_loader.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ class TensorRTLoader:
110110
@classmethod
111111
def INPUT_TYPES(s):
112112
return {"required": {"unet_name": (folder_paths.get_filename_list("tensorrt"), ),
113-
"model_type": (["sdxl_base", "sdxl_refiner", "sd1.x", "sd2.x-768v", "svd"], ),
113+
"model_type": (["sdxl_base", "sdxl_refiner", "sd1.x", "sd2.x-768v", "svd", "cosxl"], ),
114114
}}
115115
RETURN_TYPES = ("MODEL",)
116116
FUNCTION = "load_unet"
@@ -142,6 +142,10 @@ def load_unet(self, unet_name, model_type):
142142
conf = comfy.supported_models.SVD_img2vid({})
143143
conf.unet_config["disable_unet_model_creation"] = True
144144
model = conf.get_model({})
145+
elif model_type == "cosxl":
146+
conf = comfy.supported_models.SDXL({"adm_in_channels": 2816})
147+
conf.unet_config["disable_unet_model_creation"] = True
148+
model = comfy.model_base.SDXL(conf, model_type=comfy.model_base.ModelType.V_PREDICTION_EDM)
145149
model.diffusion_model = unet
146150
model.memory_required = lambda *args, **kwargs: 0 #always pass inputs batched up as much as possible, our TRT code will handle batch splitting
147151

0 commit comments

Comments
 (0)