
Commit 9615b1c

Update usage of PyTorch's custom op API
Hi, I maintain the custom ops story in PyTorch. This PR updates the usage of PyTorch's private custom op API to a newer API. The new API is still private, but it is closer to what we want it to be.

Test Plan:
- wait for CI
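For reviewers skimming the diffs below, the new registration pattern condenses to the following (a minimal sketch built from the einsum changes in this commit; the function body is elided with "..." just as in the original placeholder code, so this only illustrates the registration mechanics):

import torch._custom_ops as library

# Register the operator schema under the tensorrt namespace. Per the comments
# in maxpool1d.py, the namespace makes the op reachable as
# torch.ops.tensorrt.einsum once an implementation is attached.
library.custom_op(
    "tensorrt::einsum",
    "(str equation, Tensor[] tensors) -> Tensor",
)

# Concrete and abstract (meta) implementations are now attached by qualified
# name instead of via methods on a decorated placeholder function, so the old
# placeholder def and the separate "cpu"/"cuda" registrations go away.
@library.impl("tensorrt::einsum")  # type: ignore[misc]
@library.impl_abstract("tensorrt::einsum")  # type: ignore[misc]
def einsum_generic(*args, **kwargs):
    ...  # generic implementation, unchanged by this commit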
1 parent c3a65ef commit 9615b1c

2 files changed: +12 -22 lines changed


py/torch_tensorrt/dynamo/lowering/substitutions/einsum.py

+6 -11
@@ -1,26 +1,21 @@
 from typing import Any, Dict, Optional, Sequence, Tuple
 
 import torch
-from torch._custom_op.impl import custom_op
+import torch._custom_ops as library
 from torch.fx.node import Argument, Target
 from torch_tensorrt.dynamo.lowering._pre_aot_lowering import register_substitution
 from torch_tensorrt.fx.converter_registry import tensorrt_converter
 from torch_tensorrt.fx.converters.converter_utils import set_layer_name
 from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
 
 
-@custom_op(
-    qualname="tensorrt::einsum",
-    manual_schema="(str equation, Tensor[] tensors) -> Tensor",
+library.custom_op(
+    "tensorrt::einsum",
+    "(str equation, Tensor[] tensors) -> Tensor",
 )
-def einsum(equation, tensors):  # type: ignore[no-untyped-def]
-    # Defines operator schema, name, namespace, and function header
-    ...
 
-
-@einsum.impl("cpu")  # type: ignore[misc]
-@einsum.impl("cuda")  # type: ignore[misc]
-@einsum.impl_abstract()  # type: ignore[misc]
+@library.impl("tensorrt::einsum")  # type: ignore[misc]
+@library.impl_abstract("tensorrt::einsum")  # type: ignore[misc]
 def einsum_generic(
     *args: Any,
     **kwargs: Any,

py/torch_tensorrt/dynamo/lowering/substitutions/maxpool1d.py

+6 -11
@@ -1,7 +1,7 @@
 from typing import Any, Dict, Optional, Tuple
 
 import torch
-from torch._custom_op.impl import custom_op
+import torch._custom_ops as library
 from torch.fx.node import Argument, Target
 from torch_tensorrt.dynamo.lowering._pre_aot_lowering import register_substitution
 from torch_tensorrt.fx.converter_registry import tensorrt_converter
@@ -20,14 +20,10 @@
 # types. The namespace, such as tensorrt, will cause the op to be registered as torch.ops.tensorrt.your_op
 # Then, create a placeholder function with no operations, but having the same schema and naming as that
 # used in the decorator
-@custom_op(
-    qualname="tensorrt::maxpool1d",
-    manual_schema="(Tensor x, int[1] kernel_size, int[1] stride, int[1] padding, int[1] dilation, bool ceil_mode) -> Tensor",
+library.custom_op(
+    "tensorrt::maxpool1d",
+    "(Tensor x, int[1] kernel_size, int[1] stride, int[1] padding, int[1] dilation, bool ceil_mode) -> Tensor"
 )
-def maxpool1d(x, kernel_size, stride, padding, dilation, ceil_mode):  # type: ignore[no-untyped-def]
-    # Defines operator schema, name, namespace, and function header
-    ...
-
 
 # 2. The Generic Implementation
 #
@@ -36,9 +32,8 @@ def maxpool1d(x, kernel_size, stride, padding, dilation, ceil_mode):  # type: ig
 # is desirable. If the operator to replace is a custom module you've written, then add its Torch
 # implementation here. Note that the function header to the generic function can have specific arguments
 # as in the above placeholder
-@maxpool1d.impl("cpu")  # type: ignore[misc]
-@maxpool1d.impl("cuda")  # type: ignore[misc]
-@maxpool1d.impl_abstract()  # type: ignore[misc]
+@library.impl("tensorrt::maxpool1d")  # type: ignore[misc]
+@library.impl_abstract("tensorrt::maxpool1d")  # type: ignore[misc]
 def maxpool1d_generic(
     *args: Any,
     **kwargs: Any,
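Once both substitution modules are imported, the ops stay reachable under the same torch.ops qualified names as before, so a quick spot check of the kind the "wait for CI" test plan exercises could look like this (a hypothetical snippet: the import path is inferred from the files touched here, and the actual outputs depend on the generic implementations, which this diff does not change):

import torch

# Importing the substitutions package runs the library.custom_op / library.impl
# registrations shown above (import path assumed from this repo's layout).
import torch_tensorrt.dynamo.lowering.substitutions  # noqa: F401

a, b = torch.randn(2, 3), torch.randn(3, 4)
# Matches the schema "(str equation, Tensor[] tensors) -> Tensor"
out = torch.ops.tensorrt.einsum("ij,jk->ik", [a, b])

x = torch.randn(1, 1, 8)
# Matches the maxpool1d schema: kernel_size, stride, padding, dilation, ceil_mode
y = torch.ops.tensorrt.maxpool1d(x, [2], [2], [0], [1], False)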
