2 changes: 1 addition & 1 deletion .github/workflows/publish.yml
@@ -17,5 +17,5 @@ jobs:
- name: Publish Custom Node
uses: Comfy-Org/publish-node-action@main
with:
## Add your own personal access token to your Github Repository secrets and reference it here.
## Add your own personal access token to your GitHub Repository secrets and reference it here.
personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }}
72 changes: 36 additions & 36 deletions README.md
@@ -1,7 +1,7 @@
# TensorRT Node for ComfyUI

This node enables the best performance on NVIDIA RTX™ Graphics Cards
 (GPUs) for Stable Diffusion by leveraging NVIDIA TensorRT.
This node enables the best performance on NVIDIA RTX™ Graphics Cards (GPUs) for Stable Diffusion by leveraging NVIDIA
TensorRT.

Supports:

@@ -11,8 +11,7 @@ Supports:
- SDXL
- SDXL Turbo
- Stable Video Diffusion
- Stable Video Diffusion-XT
- AuraFlow
- Flux

Requirements:
@@ -31,7 +30,8 @@ Requirements:
The recommended way to install these nodes is to use the [ComfyUI Manager](https://github.com/ltdrdata/ComfyUI-Manager)
to easily install them to your ComfyUI instance.

You can also manually install them by git cloning the repo to your ComfyUI/custom_nodes folder and installing the requirements like:
You can also manually install them by git cloning the repo to your ComfyUI/custom_nodes folder and installing the
requirements like:

```
cd custom_nodes
```
@@ -62,7 +62,7 @@ You have the option to build either dynamic or static TensorRT engines:
Note: Most users will prefer dynamic engines, but static engines can be
useful if you use a specific resolution + batch size combination most of
the time. Static engines also require less VRAM; the wider the dynamic
range, the more VRAM will be consumed.
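
For intuition, here is a minimal sketch of how this choice maps onto TensorRT optimization profiles; the input name "x" and the SD-style latent shapes are illustrative assumptions, not taken from this repo's converter:

```
import tensorrt as trt

builder = trt.Builder(trt.Logger(trt.Logger.WARNING))
profile = builder.create_optimization_profile()

# Dynamic engine: each dimension gets a min/opt/max range. The wider the
# range, the more VRAM the built engine consumes.
profile.set_shape(
    "x",                  # latent input, NCHW (illustrative name)
    min=(1, 4, 64, 64),   # e.g. batch 1, 512x512 latents
    opt=(1, 4, 96, 96),   # the shape TensorRT tunes hardest for
    max=(4, 4, 128, 128), # e.g. batch 4, 1024x1024 latents
)

# Static engine: min == opt == max, so exactly one shape is compiled.
# profile.set_shape("x", (1, 4, 64, 64), (1, 4, 64, 64), (1, 4, 64, 64))
```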

## Instructions

@@ -71,19 +71,19 @@ These .json files can be loaded in ComfyUI.

### Building A TensorRT Engine From a Checkpoint

1. Add a Load Checkpoint Node
2. Add either a Static Model TensorRT Conversion node or a Dynamic
   Model TensorRT Conversion node to ComfyUI
3. ![](readme_images/image3.png)
4. Connect the Load Checkpoint Model output to the TensorRT Conversion
   Node Model input.
5. ![](readme_images/image5.png)
6. ![](readme_images/image2.png)
7. To help identify the converted TensorRT model, provide a meaningful
   filename prefix and add this filename after “tensorrt/”
8. ![](readme_images/image9.png)
9. Click on Queue Prompt to start building the TensorRT Engines
10. ![](readme_images/image7.png)

![](readme_images/image11.png)
@@ -96,9 +96,9 @@ the console.

![](readme_images/image4.png)

The first time generating an engine for a checkpoint will take awhile.
The first time generating an engine for a checkpoint will take a while.
Additional engines generated thereafter for the same checkpoint will be
much faster. Generating engines can take anywhere from 3-10 minutes for
the image generation models and 10-25 minutes for SVD. SVD-XT is an
extremely large model, so engine build times may take up to an hour.

@@ -115,33 +115,33 @@ TensorRT Engines are loaded using the TensorRT Loader node.
ComfyUI TensorRT engines are not yet compatible with ControlNets or
LoRAs. Compatibility will be enabled in a future update.

1. Add a TensorRT Loader node
2. Note: if a TensorRT Engine has been created during a ComfyUI
   session, it will not show up in the TensorRT Loader until the
   ComfyUI interface has been refreshed (F5 to refresh browser).
3. ![](readme_images/image6.png)
4. Select a TensorRT Engine from the unet_name dropdown
5. Dynamic Engines will use a filename format of (see the naming sketch after this list):

   1. dyn-b-min-max-opt-h-min-max-opt-w-min-max-opt
   2. dyn=dynamic, b=batch size, h=height, w=width

6. Static Engines will use a filename format of:

   1. stat-b-opt-h-opt-w-opt
   2. stat=static, b=batch size, h=height, w=width

7. ![](readme_images/image8.png)
8. The model_type must match the model type of the TensorRT engine.
9. ![](readme_images/image10.png)
10. The CLIP and VAE for the workflow must come from the
    original model checkpoint; the MODEL output from the TensorRT Loader
    connects to the Sampler.
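
As a quick illustration of the naming above (the helper functions and the "SD15" prefix are hypothetical, not this repo's implementation):

```
def dynamic_engine_name(prefix, b, h, w):
    # dyn-b-min-max-opt-h-min-max-opt-w-min-max-opt
    return (
        f"{prefix}_dyn-b-{b[0]}-{b[1]}-{b[2]}"
        f"-h-{h[0]}-{h[1]}-{h[2]}-w-{w[0]}-{w[1]}-{w[2]}"
    )


def static_engine_name(prefix, b_opt, h_opt, w_opt):
    # stat-b-opt-h-opt-w-opt
    return f"{prefix}_stat-b-{b_opt}-h-{h_opt}-w-{w_opt}"


print(dynamic_engine_name("SD15", (1, 4, 1), (256, 1024, 512), (256, 1024, 512)))
# SD15_dyn-b-1-4-1-h-256-1024-512-w-256-1024-512
print(static_engine_name("SD15", 1, 512, 512))
# SD15_stat-b-1-h-512-w-512
```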
16 changes: 7 additions & 9 deletions __init__.py
@@ -1,11 +1,9 @@
from .tensorrt_convert import DYNAMIC_TRT_MODEL_CONVERSION
from .tensorrt_convert import STATIC_TRT_MODEL_CONVERSION
from .tensorrt_loader import TrTUnet
from .tensorrt_loader import TensorRTLoader
from .onnx_nodes import NODE_CLASS_MAPPING as ONNX_CLASS_MAP
from .onnx_nodes import NODE_DISPLAY_NAME_MAPPINGS as ONNX_NAME_MAP
from .tensorrt_nodes import NODE_CLASS_MAPPINGS as TRT_CLASS_MAP
from .tensorrt_nodes import NODE_DISPLAY_NAME_MAPPINGS as TRT_NAME_MAP

NODE_CLASS_MAPPINGS = { "DYNAMIC_TRT_MODEL_CONVERSION": DYNAMIC_TRT_MODEL_CONVERSION, "STATIC_TRT_MODEL_CONVERSION": STATIC_TRT_MODEL_CONVERSION, "TensorRTLoader": TensorRTLoader }
NODE_CLASS_MAPPINGS = TRT_CLASS_MAP | ONNX_CLASS_MAP
NODE_DISPLAY_NAME_MAPPINGS = TRT_NAME_MAP | ONNX_NAME_MAP


NODE_DISPLAY_NAME_MAPPINGS = { "DYNAMIC_TRT_MODEL_CONVERSION": "DYNAMIC TRT_MODEL CONVERSION", "STATIC TRT_MODEL CONVERSION": STATIC_TRT_MODEL_CONVERSION, "TensorRTLoader": "TensorRT Loader" }

__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']
__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
9 changes: 9 additions & 0 deletions models/__init__.py
@@ -0,0 +1,9 @@
from .baseline import TRTModelUtil
from .supported_models import (
supported_models,
unsupported_models,
detect_version_from_model,
get_helper_from_version,
get_helper_from_model,
get_model_from_version,
)
23 changes: 23 additions & 0 deletions models/auraflow.py
@@ -0,0 +1,23 @@
from .baseline import TRTModelUtil


class AuraFlow_TRT(TRTModelUtil):
    def __init__(
        self, context_dim=2048, input_channels=4, context_len=256, **kwargs
    ) -> None:
        super().__init__(
            context_dim=context_dim,
            input_channels=input_channels,
            context_len=context_len,
            **kwargs,
        )
        self.is_conditional = True

    @classmethod
    def from_model(cls, model, **kwargs):
        return cls(
            context_dim=model.model.model_config.unet_config["cond_seq_dim"],
            input_channels=model.model.diffusion_model.out_channels,
            use_control=False,
            **kwargs,
        )
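
A hypothetical usage sketch, assuming the package is importable from the repo root; the class is constructed with its defaults here rather than from a loaded checkpoint:

```
from models.auraflow import AuraFlow_TRT

# In the real flow this would come from AuraFlow_TRT.from_model(model),
# where `model` is a loaded ComfyUI checkpoint; defaults are used here.
helper = AuraFlow_TRT(context_dim=2048, input_channels=4, context_len=256)
print(helper.get_input_shapes(batch_size=1, height=1024, width=1024))
# {'x': (1, 4, 128, 128), 'timesteps': (1,), 'context': (1, 256, 2048)}
```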
105 changes: 105 additions & 0 deletions models/baseline.py
@@ -0,0 +1,105 @@
import torch


class TRTModelUtil:
    def __init__(
        self,
        context_dim: int,
        input_channels: int,
        context_len: int,
        use_control: bool = False,
        *args,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)
        self.context_dim = context_dim
        self.input_channels = input_channels
        self.context_len = context_len
        self.use_control = use_control
        self.is_conditional = False

        # Shape templates: string entries like "{height}//8" are resolved at
        # runtime (and mark that axis as dynamic); ints are fixed dimensions.
        self.input_config = {
            "x": {
                "batch": "{batch_size}",
                "input_channels": self.input_channels,
                "height": "{height}//8",
                "width": "{width}//8",
            },
            "timesteps": {
                "batch": "{batch_size}",
            },
            "context": {
                "batch": "{batch_size}",
                "context_len": "{context_len}",
                "context_dim": self.context_dim,
            },
        }

        self.output_config = {
            "h": {
                "batch": "{batch_size}",
                "input_channels": self.input_channels,
                "height": "{height}//8",
                "width": "{width}//8",
            }
        }

    def to_dict(self):
        return {
            "context_dim": self.context_dim,
            "input_channels": self.input_channels,
            "context_len": self.context_len,
            "use_control": self.use_control,
        }

    def get_input_names(self) -> list[str]:
        return list(self.input_config.keys())

    def get_output_names(self) -> list[str]:
        return list(self.output_config.keys())

    def get_dtype(self) -> torch.dtype:
        return torch.float16

    def get_input_shapes(self, **kwargs) -> dict:
        inputs_shapes = {}
        for io_name, io_config in self.input_config.items():
            _inp = self._eval_shape(io_config, **kwargs)
            inputs_shapes[io_name] = _inp

        return inputs_shapes

    def get_input_shapes_by_key(self, key: str, **kwargs) -> tuple[int, ...]:
        return self._eval_shape(self.input_config[key], **kwargs)

    def get_dynamic_axes(self, config: dict = {}) -> dict:
        dynamic_axes = {}

        if config == {}:
            config = self.input_config | self.output_config
        for k, v in config.items():
            # An axis is dynamic iff its template entry is a string.
            dyn = {i: ax for i, (ax, s) in enumerate(v.items()) if isinstance(s, str)}
            dynamic_axes[k] = dyn

        return dynamic_axes

    def _eval_shape(self, inp, **kwargs) -> tuple[int, ...]:
        if "context_len" not in kwargs:
            kwargs["context_len"] = self.context_len
        shape = []
        for _, v in inp.items():
            _s = v
            if isinstance(v, str):
                # Format in the runtime values, then evaluate the arithmetic,
                # e.g. "{height}//8" with height=512 -> 64.
                _s = int(eval(v.format(**kwargs)))
            shape.append(_s)
        return tuple(shape)

    def get_control(self, *args, **kwargs) -> dict:
        raise NotImplementedError

    @classmethod
    def from_model(cls, model, **kwargs):
        raise NotImplementedError

    def model_attributes(self, **kwargs) -> dict:
        return {}
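
To see the shape templating in action, a minimal sketch with assumed SD1.5-style values (768-dim CLIP context, 77 tokens), again assuming the package is importable:

```
from models.baseline import TRTModelUtil

util = TRTModelUtil(context_dim=768, input_channels=4, context_len=77)

print(util.get_input_shapes(batch_size=2, height=512, width=512))
# {'x': (2, 4, 64, 64), 'timesteps': (2,), 'context': (2, 77, 768)}

# Axes defined as string templates are reported as dynamic, e.g. for ONNX export:
print(util.get_dynamic_axes())
# {'x': {0: 'batch', 2: 'height', 3: 'width'}, 'timesteps': {0: 'batch'},
#  'context': {0: 'batch', 1: 'context_len'}, 'h': {0: 'batch', 2: 'height', 3: 'width'}}
```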