Commit c99569e

feat: Automatically generating converters for QDP plugins
Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
1 parent 302d0b8 commit c99569e

File tree: 9 files changed, +802 -488 lines changed

.pre-commit-config.yaml (+69 -69)

@@ -1,71 +1,71 @@
 exclude: ^.github/actions/assigner/dist
 repos:
-  - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.5.0
-    hooks:
-      - id: check-yaml
-      - id: trailing-whitespace
-        exclude: ^docs
-      - id: check-added-large-files
-        args:
-          - --maxkb=1000
-      - id: check-vcs-permalinks
-      - id: check-merge-conflict
-      - id: mixed-line-ending
-        args:
-          - --fix=lf
-        exclude: ^docs
-  - repo: https://github.com/pre-commit/mirrors-clang-format
-    rev: v14.0.6
-    hooks:
-      - id: clang-format
-        types_or: [c++, c, cuda]
-  - repo: https://github.com/keith/pre-commit-buildifier
-    rev: 6.4.0
-    hooks:
-      - id: buildifier
-        args:
-          - --warnings=all
-      - id: buildifier-lint
-  - repo: https://github.com/abravalheri/validate-pyproject
-    rev: v0.16
-    hooks:
-      - id: validate-pyproject
-  - repo: https://github.com/pycqa/isort
-    rev: 5.13.2
-    hooks:
-      - id: isort
-        name: isort (python)
-  - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: 'v1.9.0'
-    hooks:
-      - id: mypy
-        exclude: "^py/torch_tensorrt/fx|^examples|^tests|^py/torch_tensorrt/dynamo/_experimental|^tools|^docs|noxfile.py|setup.py|versions.py"
-  - repo: https://github.com/astral-sh/ruff-pre-commit
-    # Ruff version.
-    rev: v0.3.3
-    hooks:
-      - id: ruff
-  - repo: https://github.com/psf/black
-    rev: 24.3.0
-    hooks:
-      - id: black
-        exclude: ^examples/custom_converters/elu_converter/setup.py|^docs
-  - repo: https://github.com/crate-ci/typos
-    rev: v1.22.9
-    hooks:
-      - id: typos
-  - repo: https://github.com/astral-sh/uv-pre-commit
-    # uv version.
-    rev: 0.4.10
-    hooks:
-      # Update the uv lockfile
-      - id: uv-lock
-  - repo: local
-    hooks:
-      - id: dont-commit-upstream
-        name: NVIDIA-INTERNAL check
-        entry: "!NVIDIA-INTERNAL"
-        exclude: "^.pre-commit-config.yaml"
-        language: pygrep
-        types: [text]
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v4.5.0
+    hooks:
+      - id: check-yaml
+      - id: trailing-whitespace
+        exclude: ^docs
+      - id: check-added-large-files
+        args:
+          - --maxkb=1000
+      - id: check-vcs-permalinks
+      - id: check-merge-conflict
+      - id: mixed-line-ending
+        args:
+          - --fix=lf
+        exclude: ^docs
+  - repo: https://github.com/pre-commit/mirrors-clang-format
+    rev: v14.0.6
+    hooks:
+      - id: clang-format
+        types_or: [c++, c, cuda]
+  - repo: https://github.com/keith/pre-commit-buildifier
+    rev: 6.4.0
+    hooks:
+      - id: buildifier
+        args:
+          - --warnings=all
+      - id: buildifier-lint
+  - repo: https://github.com/abravalheri/validate-pyproject
+    rev: v0.23
+    hooks:
+      - id: validate-pyproject
+  - repo: https://github.com/pycqa/isort
+    rev: 5.13.2
+    hooks:
+      - id: isort
+        name: isort (python)
+  - repo: https://github.com/pre-commit/mirrors-mypy
+    rev: "v1.9.0"
+    hooks:
+      - id: mypy
+        exclude: "^py/torch_tensorrt/fx|^examples|^tests|^py/torch_tensorrt/dynamo/_experimental|^tools|^docs|noxfile.py|setup.py|versions.py"
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    # Ruff version.
+    rev: v0.3.3
+    hooks:
+      - id: ruff
+  - repo: https://github.com/psf/black
+    rev: 24.3.0
+    hooks:
+      - id: black
+        exclude: ^examples/custom_converters/elu_converter/setup.py|^docs
+  - repo: https://github.com/crate-ci/typos
+    rev: v1.22.9
+    hooks:
+      - id: typos
+  - repo: https://github.com/astral-sh/uv-pre-commit
+    # uv version.
+    rev: 0.5.5
+    hooks:
+      # Update the uv lockfile
+      - id: uv-lock
+  - repo: local
+    hooks:
+      - id: dont-commit-upstream
+        name: NVIDIA-INTERNAL check
+        entry: "!NVIDIA-INTERNAL"
+        exclude: "^.pre-commit-config.yaml"
+        language: pygrep
+        types: [text]

MODULE.bazel (+9 -11)

@@ -36,13 +36,13 @@ new_local_repository = use_repo_rule("@bazel_tools//tools/build_defs/repo:local.
 new_local_repository(
     name = "cuda",
     build_file = "@//third_party/cuda:BUILD",
-    path = "/usr/local/cuda-12.4/",
+    path = "/usr/local/cuda-12.6/",
 )

 new_local_repository(
     name = "cuda_win",
     build_file = "@//third_party/cuda:BUILD",
-    path = "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v12.4/",
+    path = "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v12.6/",
 )

 http_archive = use_repo_rule("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
@@ -55,21 +55,21 @@ http_archive(
     name = "libtorch",
     build_file = "@//third_party/libtorch:BUILD",
     strip_prefix = "libtorch",
-    urls = ["https://download.pytorch.org/libtorch/nightly/cu124/libtorch-cxx11-abi-shared-with-deps-latest.zip"],
+    urls = ["https://download.pytorch.org/libtorch/nightly/cu126/libtorch-cxx11-abi-shared-with-deps-latest.zip"],
 )

 http_archive(
     name = "libtorch_pre_cxx11_abi",
     build_file = "@//third_party/libtorch:BUILD",
     strip_prefix = "libtorch",
-    urls = ["https://download.pytorch.org/libtorch/nightly/cu124/libtorch-shared-with-deps-latest.zip"],
+    urls = ["https://download.pytorch.org/libtorch/nightly/cu126/libtorch-shared-with-deps-latest.zip"],
 )

 http_archive(
     name = "libtorch_win",
     build_file = "@//third_party/libtorch:BUILD",
     strip_prefix = "libtorch",
-    urls = ["https://download.pytorch.org/libtorch/nightly/cu124/libtorch-win-shared-with-deps-latest.zip"],
+    urls = ["https://download.pytorch.org/libtorch/nightly/cu126/libtorch-win-shared-with-deps-latest.zip"],
 )

 # Download these tarballs manually from the NVIDIA website
@@ -79,20 +79,18 @@ http_archive(
 http_archive(
     name = "tensorrt",
     build_file = "@//third_party/tensorrt/archive:BUILD",
-    sha256 = "33d3c2f3f4c84dc7991a4337a6fde9ed33f5c8e5c4f03ac2eb6b994a382b03a0",
-    strip_prefix = "TensorRT-10.6.0.26",
+    strip_prefix = "TensorRT-10.7.0.23",
     urls = [
-        "https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.6.0/tars/TensorRT-10.6.0.26.Linux.x86_64-gnu.cuda-12.6.tar.gz",
+        "https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.7.0/tars/TensorRT-10.7.0.23.Linux.x86_64-gnu.cuda-12.6.tar.gz",
     ],
 )

 http_archive(
     name = "tensorrt_win",
     build_file = "@//third_party/tensorrt/archive:BUILD",
-    sha256 = "6c6d92c108a1b3368423e8f69f08d31269830f1e4c9da43b37ba34a176797254",
-    strip_prefix = "TensorRT-10.6.0.26",
+    strip_prefix = "TensorRT-10.7.0.23",
     urls = [
-        "https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.6.0/zip/TensorRT-10.6.0.26.Windows.win10.cuda-12.6.zip",
+        "https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.7.0/zip/TensorRT-10.7.0.23.Windows.win10.cuda-12.6.zip",
     ],
 )


New file (+184)

@@ -0,0 +1,184 @@
"""
.. _auto_generate_converters:

Automatically Generate a Converter for a Custom Kernel
===================================================================

We are going to demonstrate how to automatically generate a converter for a custom kernel with Torch-TensorRT,
using the new Python-based plugin system in TensorRT 10.7.

Torch-TensorRT supports falling back to PyTorch implementations of operations in the case that Torch-TensorRT
does not know how to compile them in TensorRT. However, this comes at the cost of a graph break and will reduce the performance of the model.
The easiest way to fix a lack of support for an op is by adding a decomposition (see:
`Writing lowering passes for the Dynamo frontend <https://pytorch.org/TensorRT/contributors/writing_dynamo_aten_lowering_passes.html>`_) - which defines the operator
in terms of PyTorch ops that are supported in Torch-TensorRT - or a converter (see:
`Writing converters for the Dynamo frontend <https://pytorch.org/TensorRT/contributors/dynamo_converters.html>`_) - which defines the operator in terms of TensorRT operators.

In some cases there isn't a great way to do either of these, perhaps because the operator is a custom kernel that is not part of standard PyTorch or
because TensorRT cannot support it natively.

For these cases, it is possible to use a TensorRT plugin to replace the operator **inside** the TensorRT engine, thereby avoiding
the performance and resource overhead of a graph break.

Previously this involved a complex process, not only building a performant kernel but also setting it up to run in TensorRT (see: `Using Custom Kernels within TensorRT Engines with Torch-TensorRT <https://pytorch.org/TensorRT/tutorials/_rendered_examples/dynamo/custom_kernel_plugins.html>`_).
With TensorRT 10.7, there is a new Python-native plugin system which greatly streamlines this process. This
plugin system also allows Torch-TensorRT to automatically generate the necessary conversion code to convert the
operation in PyTorch to TensorRT.
"""

# %%
# Writing Custom Operators in PyTorch
# -----------------------------------------
#
# Previous tutorials already cover creating custom operators in PyTorch which later get used with Torch-TensorRT.
# Here we define a simple elementwise multiplication operator in Triton. This operator is then registered as a custom op in PyTorch,
# with its host launch code as well as a "meta-kernel". A meta-kernel is a function that describes the shape and data type
# transformations that the operator will perform. This meta-kernel is used by Dynamo and Torch-TensorRT, so it
# is necessary to define one.
#

from typing import Tuple

import tensorrt_bindings.plugin as trtp
import torch
import torch_tensorrt
import triton
import triton.language as tl


@triton.jit
def elementwise_mul_kernel(X, Y, Z, BLOCK_SIZE: tl.constexpr):
    # Program ID determines the block of data each program instance will process
    pid = tl.program_id(0)
    # Compute the range of elements that this block will work on
    block_start = pid * BLOCK_SIZE
    # Range of indices this block will handle
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Load elements from the X and Y tensors
    x_vals = tl.load(X + offsets)
    y_vals = tl.load(Y + offsets)
    # Perform the element-wise multiplication
    z_vals = x_vals * y_vals
    # Store the result in Z
    tl.store(Z + offsets, z_vals)


@torch.library.custom_op("torchtrt_ex::elementwise_mul", mutates_args=())  # type: ignore[misc]
def elementwise_mul(
    X: torch.Tensor, Y: torch.Tensor, b: float = 0.2, a: int = 2
) -> torch.Tensor:
    # Ensure the tensors are on the GPU
    assert X.is_cuda and Y.is_cuda, "Tensors must be on CUDA device."
    assert X.shape == Y.shape, "Tensors must have the same shape."

    # Create output tensor
    Z = torch.empty_like(X)

    # Define block size
    BLOCK_SIZE = 1024

    # Grid of programs
    grid = lambda meta: (X.numel() // meta["BLOCK_SIZE"],)

    # Launch the kernel
    elementwise_mul_kernel[grid](X, Y, Z, BLOCK_SIZE=BLOCK_SIZE)

    return Z


# %%
# The meta kernel for an elementwise operation just returns the shape and dtype of one of the inputs, since we do not change the shape
# in the course of the operation.


@torch.library.register_fake("torchtrt_ex::elementwise_mul")
def _(x: torch.Tensor, y: torch.Tensor, b: float = 0.2, a: int = 2) -> torch.Tensor:
    return x


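# %%
# As a quick sanity check (a minimal sketch, not required for conversion), the registered op can
# be called eagerly and compared against plain elementwise multiplication. The shape is chosen so
# that the number of elements is a multiple of BLOCK_SIZE, since the kernel above does not mask
# out-of-range offsets.

sample_x = torch.full((64, 64), 2.0, device="cuda")
sample_y = torch.full((64, 64), 3.0, device="cuda")
assert torch.allclose(
    torch.ops.torchtrt_ex.elementwise_mul(sample_x, sample_y), sample_x * sample_y
)
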
# %%
# Writing Plugins for TensorRT using the Quick Deploy Plugin system
# -------------------------------------------------------------------
# The quick deployment plugin system in TensorRT 10.7 allows for the creation of custom plugins in Python with significantly
# less boilerplate. It uses a system similar to PyTorch's, where you define a function that describes the shape and data type transformations
# that the operator will perform, and then define the code to launch the kernel given GPU memory handles.
#


# %%
# Just like the PyTorch meta kernel, there is no transformation in shape or data type between the input and output, so
# we can just tell TensorRT to expect the same shape as the input.
#
@trtp.register("torchtrt_ex::elementwise_mul")
def _(
    x: trtp.TensorDesc, y: trtp.TensorDesc, b: float, a: int
) -> Tuple[trtp.TensorDesc]:
    return x.like()


# %%
# Here we reuse host launch code similar to the PyTorch implementation, but we need to convert the TensorRT tensors into PyTorch tensors prior to launching the kernel.
# These operations are also in-place, so the result must be put in the output tensors provided by TensorRT.
@trtp.impl("torchtrt_ex::elementwise_mul")
def _(
    x: trtp.Tensor,
    y: trtp.Tensor,
    b: float,
    a: int,
    outputs: Tuple[trtp.Tensor],
    stream: int,
):
    # Define block size
    BLOCK_SIZE = 1024

    # Grid of programs
    grid = lambda meta: (x.numel() // meta["BLOCK_SIZE"],)

    x_t = torch.as_tensor(x, device="cuda")
    y_t = torch.as_tensor(y, device="cuda")
    z_t = torch.as_tensor(outputs[0], device="cuda")
    # Launch the kernel
    elementwise_mul_kernel[grid](x_t, y_t, z_t, BLOCK_SIZE=BLOCK_SIZE)


# %%
# Generating the Converter
# -------------------------------------------------------------------
# Given that we have defined the custom operator in PyTorch and TensorRT, we can now generate the converter for the operation.
# As long as the namespace and names match, the following call will automatically generate the converter for the operation.
torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter(
    "torchtrt_ex::elementwise_mul", supports_dynamic_shapes=True
)


# %%
# Using our converter with a model
# -------------------------------------------------------------------
#
# Now we can use our custom operator in a model and compile it with Torch-TensorRT.
# We can see that the custom operator is used as one of the operations in the forward pass of the model.
# The process of compiling the model at this point is identical to standard Torch-TensorRT usage.
class MyModel(torch.nn.Module):  # type: ignore[misc]
    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        z = torch.add(x, y)
        res = torch.ops.torchtrt_ex.elementwise_mul.default(x, z, a=1)

        return res


my_model = MyModel().to("cuda")
m = torch.full((64, 64), 2, device="cuda", dtype=torch.float)
n = torch.full((64, 64), 3, device="cuda", dtype=torch.float)

with torch_tensorrt.logging.errors():
    model_trt = torch_tensorrt.compile(
        my_model, inputs=[m, n], debug=True, min_block_size=1
    )
    for i in range(300):
        res = model_trt(m, n)
        assert torch.allclose(res, my_model(m, n))

print("Ran with custom plugin!")

py/torch_tensorrt/dynamo/conversion/__init__.py (+1 -1)

@@ -1,4 +1,4 @@
-from . import aten_ops_converters, ops_evaluators, prims_ops_converters
+from . import aten_ops_converters, ops_evaluators, plugins, prims_ops_converters
 from ._conversion import convert_module, interpret_module_to_result
 from ._ConversionContext import ConversionContext
 from ._ConverterRegistry import *  # noqa: F403
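
Re-exporting plugins here makes the new module importable as torch_tensorrt.dynamo.conversion.plugins, which is the entry point used by the example above. A minimal usage sketch, reusing the op name registered in that example:

import torch_tensorrt

# Generate a TensorRT converter for the QDP plugin registered under this op name
torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter(
    "torchtrt_ex::elementwise_mul", supports_dynamic_shapes=True
)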

0 commit comments
