
Commit 6a69c6a

feat: Add _to_copy, operator.get and clone ATen converters (#2161)
1 parent 06e544e commit 6a69c6a

File tree

14 files changed: +352 −40 lines

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py (+7 −3)

@@ -27,6 +27,10 @@
 ] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")


+class UnsupportedOperatorException(RuntimeError):
+    pass
+
+
 class TRTInterpreterResult(NamedTuple):
     engine: Any
     input_names: Sequence[str]
@@ -301,7 +305,7 @@ def call_module(
         converter = CONVERTERS.get(self._cur_node)

         if not converter:
-            raise RuntimeError(
+            raise UnsupportedOperatorException(
                 f"Conversion of module of type {submod_type} not currently supported!"
             )
@@ -312,7 +316,7 @@ def call_function(self, target: str, args: Any, kwargs: Any) -> Any:
         # TODO: Why is this stateful? We should be able to take in the inputs
         converter = CONVERTERS.get(self._cur_node)
         if not converter:
-            raise RuntimeError(
+            raise UnsupportedOperatorException(
                 f"Conversion of function {torch.typename(target)} not currently supported!"
             )
@@ -324,7 +328,7 @@ def call_method(self, target: str, args: Any, kwargs: Any) -> Any:
         converter = CONVERTERS.get(self._cur_node)

         if not converter:
-            raise RuntimeError(
+            raise UnsupportedOperatorException(
                 f"Conversion of method {target} not currently supported!"
             )
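Since UnsupportedOperatorException subclasses RuntimeError, existing except RuntimeError handlers keep working, while callers can now single out unsupported-operator failures. A minimal sketch of how a caller might use it; the interpreter instance and its run() entry point are assumed context here, not part of this commit:

import logging

from torch_tensorrt.dynamo.conversion import UnsupportedOperatorException

logger = logging.getLogger(__name__)

try:
    result = interpreter.run()  # assumes a constructed TRTInterpreter
except UnsupportedOperatorException as exc:
    # An unsupported op was encountered; a caller could fall back to eager PyTorch here
    logger.warning(f"TRT conversion failed, falling back: {exc}")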

py/torch_tensorrt/dynamo/conversion/__init__.py (+1)

@@ -1,4 +1,5 @@
 from ._TRTInterpreter import *  # noqa: F403
 from .aten_ops_converters import *  # noqa: F403
 from .conversion import *  # noqa: F403
+from .op_evaluators import *  # noqa: F403
 from .truncate_long_and_double import repair_long_or_double_inputs

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py (+60 −5)

@@ -1,6 +1,7 @@
 import logging
 from typing import Any, Dict, Optional, Sequence, Tuple, Union

+import tensorrt as trt
 import torch
 from torch.fx.node import Argument, Node, Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
@@ -12,8 +13,6 @@
 from torch_tensorrt.fx.converters import acc_ops_converters
 from torch_tensorrt.fx.types import TRTNetwork, TRTTensor

-import tensorrt as trt
-
 from .converter_registry import dynamo_tensorrt_converter

 _LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -76,13 +75,13 @@ def aten_ops_div(
         kwargs_new["input"].dtype == trt.int8 or kwargs_new["input"].dtype == trt.int32
     ):
         kwargs_new["input"] = cast_trt_tensor(
-            network, kwargs_new["input"], trt.float32, name
+            network, kwargs_new["input"], trt.float32, name, target
         )
     elif isinstance(args[1], TRTTensor) and (
         kwargs_new["other"].dtype == trt.int8 or kwargs_new["other"].dtype == trt.int32
     ):
         kwargs_new["other"] = cast_trt_tensor(
-            network, kwargs_new["other"], trt.float32, name
+            network, kwargs_new["other"], trt.float32, name, target
         )
     rounding_mode = kwargs.get("rounding_mode")
     if rounding_mode is None:
@@ -101,7 +100,7 @@ def aten_ops_div(
     )


-def embedding_param_validator(embedding_node: Node):
+def embedding_param_validator(embedding_node: Node) -> bool:
     scale_grad_by_freq = args_bounds_check(embedding_node.args, 3)
     sparse = args_bounds_check(embedding_node.args, 4)

@@ -365,3 +364,59 @@ def aten_ops_permute(
         args[0],
         args[1],
     )
+
+
+def to_copy_dtype_validator(to_copy_node: Node) -> bool:
+    allowed_casts = {torch.float, torch.int32, torch.bool, torch.int8, torch.float16}
+
+    # Validate input node has convertible kwargs
+    if "dtype" in to_copy_node.kwargs:
+        if to_copy_node.kwargs["dtype"] in allowed_casts:
+            return True
+        else:
+            _LOGGER.debug(
+                f"_to_copy converter rejected node {to_copy_node} with dtype {to_copy_node.kwargs['dtype']}"
+            )
+            return False
+    else:
+        _LOGGER.debug(
+            f"_to_copy converter rejected node {to_copy_node} with kwargs {to_copy_node.kwargs}"
+        )
+        return False
+
+
+@dynamo_tensorrt_converter(
+    torch.ops.aten._to_copy.default, capability_validator=to_copy_dtype_validator
+)
+def aten_ops_to_copy_dtype(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.cast.to_copy(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        kwargs["dtype"],
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.clone.default)
+def aten_ops_clone(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.cast.clone(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+    )
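Together these registrations let graphs that cast or clone tensors stay inside a single TRT engine instead of being partitioned around those ops; casts to dtypes outside allowed_casts (e.g. torch.float64) are rejected by the validator and left to the fallback path. A minimal sketch exercising both converters; the compile invocation is an assumption for illustration, and the exact entry point or kwargs may differ by release:

import torch
import torch_tensorrt


class CastClone(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # .to() lowers to aten._to_copy.default, .clone() to aten.clone.default
        return x.to(torch.float16).clone() + 1


model = CastClone().eval().cuda()
inputs = [torch.randn(2, 4, device="cuda")]

trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)  # assumed entry point
print(trt_model(*inputs))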

py/torch_tensorrt/dynamo/conversion/converter_registry.py (+1 −1)

@@ -66,7 +66,7 @@ def dynamo_tensorrt_converter(
     enabled: bool = True,
     capability_validator: Optional[Callable[[Node], bool]] = None,
     priority: ConverterPriority = ConverterPriority.STANDARD,
-) -> Callable[[Any], Any]:
+) -> Callable[[Any], Union[TRTTensor, Sequence[TRTTensor]]]:
     """Decorator for Dynamo TensorRT Converter

     Registers the decorated function in the DYNAMO_ATEN_CONVERTERS registry

py/torch_tensorrt/dynamo/conversion/converter_utils.py (+15 −3)

@@ -1,15 +1,19 @@
 import logging
 import re
-from typing import List
+from typing import List, Optional

 import tensorrt as trt
 import torch
+from torch.fx.node import Target
 from torch_tensorrt.fx.converters.converter_utils import (
     Frameworks,
     unified_dtype_converter,
 )
 from torch_tensorrt.fx.types import TRTDataType, TRTNetwork, TRTTensor

+from .._SourceIR import SourceIR
+from .converter_registry import ConverterRegistry
+
 _LOGGER: logging.Logger = logging.getLogger(__name__)


@@ -71,24 +75,32 @@ def cast_trt_tensor(
     input_val: TRTTensor,
     dtype: TRTDataType,
     name: str,
+    target: Target = "",
+    source_ir: Optional[SourceIR] = None,
 ) -> TRTTensor:
     """
     Given a TRT Tensor, convert that Tensor to the specified dtype
     Adds an Identity layer to the network which performs the conversion
     Args:
         network (TRTNetwork): A TensorRT network
         input_val (TRTTensor): A TRT Tensor to cast to a new data type
-        dtype (TRTDataType): The TRTDataType to cast the input Tensor to
+        dtype (TRTDataType, torch.dtype, np.dtype): The data type to cast the input Tensor to
         name (str): Name of the calling layer
+        target (Target): Target of calling node
+        source_ir (SourceIR): SourceIR of calling converter
     Returns:
         A TensorRT ITensor which has been casted to the specified dtype
     """
     trt_dtype = unified_dtype_converter(dtype, Frameworks.TRT)

     if input_val.dtype != trt_dtype:
+        source_ir = source_ir if source_ir is not None else SourceIR.UNKNOWN
+        target_str = ConverterRegistry.qualified_name_or_str(target)
+        target_name = f"{source_ir}_ops{('.' + target_str) if target_str else ''}"
+
         identity_layer = network.add_identity(input_val)
         identity_layer.set_output_type(0, trt_dtype)
-        identity_layer.name = f"Cast ITensor {input_val.name} from {input_val.dtype} to {trt_dtype} - {name}"
+        identity_layer.name = f"Cast ITensor {input_val.name} from {input_val.dtype} to {trt_dtype} - [{target_name}]-[{name}]"
         return identity_layer.get_output(0)
     else:
         return input_val
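The two new parameters affect only layer naming: when a cast is materialized, the Identity layer's name now embeds the source IR and the FX node target, so casts found in engine logs can be traced back to the op that requested them. A sketch of a call as a converter might make it, assuming network, input_val, name, and target are in scope as usual inside a converter:

import tensorrt as trt

from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor

casted = cast_trt_tensor(network, input_val, trt.float32, name, target, SourceIR.ATEN)
# The inserted layer name now resembles:
#   Cast ITensor x from DataType.INT32 to DataType.FLOAT - [aten_ops.<target>]-[<name>]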

py/torch_tensorrt/dynamo/conversion/impl/__init__.py (+1)

@@ -2,6 +2,7 @@

 from . import (
     activation,
+    cast,
     condition,
     elementwise,
     embedding,
py/torch_tensorrt/dynamo/conversion/impl/cast.py (new file, +43)

@@ -0,0 +1,43 @@
+import logging
+from typing import Optional
+
+from torch.fx.node import Target
+from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor
+from torch_tensorrt.fx.types import TRTDataType, TRTNetwork, TRTTensor
+
+LOGGER: logging.Logger = logging.getLogger(__name__)
+
+
+def to_copy(
+    network: TRTNetwork,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    dtype: TRTDataType,
+) -> TRTTensor:
+    if not isinstance(input, TRTTensor):
+        raise RuntimeError(
+            f"to_copy received input {input} that is not a TensorRT ITensor"
+        )
+
+    casted_tensor = cast_trt_tensor(network, input, dtype, name, target, source_ir)
+    return casted_tensor
+
+
+def clone(
+    network: TRTNetwork,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+) -> TRTTensor:
+    if not isinstance(input, TRTTensor):
+        raise RuntimeError(
+            f"clone received input {input} that is not a TensorRT ITensor"
+        )
+
+    LOGGER.debug(f"Evaluating clone on object with name: {name}")
+
+    return input
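Note that clone deliberately adds no layers: during network construction a TRT ITensor is never mutated in place, so a copy would have no observable effect, and the input can be returned unchanged while the debug log records the elided op. A behavior sketch under those assumptions (illustrative names, assumes a converter context with a built ITensor):

out = clone(network, target, SourceIR.ATEN, "clone_0", input_tensor)
assert out is input_tensor  # no TRT layers were added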

py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py (+7 −4)

@@ -2,6 +2,7 @@
 import warnings
 from typing import Any, Callable, Optional, Union

+import tensorrt as trt
 import torch
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
@@ -15,8 +16,6 @@
 from torch_tensorrt.fx.types import TRTElementWiseOp, TRTNetwork, TRTTensor
 from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter

-import tensorrt as trt
-

 def get_python_op_from_trt_elementwise_op(
     trt_op: TRTElementWiseOp,
@@ -132,9 +131,13 @@ def convert_binary_elementwise(
     trt_promoted_type = unified_dtype_converter(promoted_type, Frameworks.TRT)

     if trt_promoted_type != lhs_val.dtype:
-        lhs_val = cast_trt_tensor(network, lhs_val, trt_promoted_type, name)
+        lhs_val = cast_trt_tensor(
+            network, lhs_val, trt_promoted_type, name, target, source_ir
+        )
     if trt_promoted_type != rhs_val.dtype:
-        rhs_val = cast_trt_tensor(network, rhs_val, trt_promoted_type, name)
+        rhs_val = cast_trt_tensor(
+            network, rhs_val, trt_promoted_type, name, target, source_ir
+        )

     # Check the limitation in the doc string.
     if network.has_implicit_batch_dimension:
py/torch_tensorrt/dynamo/conversion/op_evaluators.py (new file, +32)

@@ -0,0 +1,32 @@
+import logging
+import operator
+from typing import Dict, Sequence, Tuple, Union
+
+from torch.fx.node import Argument, Node, Target
+from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
+
+from .converter_registry import ConverterRegistry, dynamo_tensorrt_converter
+
+_LOGGER: logging.Logger = logging.getLogger(__name__)
+
+
+def getitem_validator(getitem_node: Node) -> bool:
+    from torch_tensorrt.dynamo.conversion.converter_registry import DYNAMO_CONVERTERS
+
+    # Getitem nodes can only be converted if their parent node also can
+    return getitem_node.args[0] in DYNAMO_CONVERTERS
+
+
+# TODO: Subsequent evaluators should be registered here with their own validators
+@dynamo_tensorrt_converter(operator.getitem, capability_validator=getitem_validator)
+def generic_evaluator(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    _LOGGER.debug(
+        f"Evaluating {ConverterRegistry.qualified_name_or_str(target)} on object with name: {name}"
+    )
+    return target(*args)
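Evaluators differ from converters in that they emit no TRT layers; they execute the Python op directly on the already-converted arguments at conversion time. operator.getitem nodes appear wherever an upstream node returns multiple outputs, and the validator only claims them when the producer node is itself convertible. A behavior sketch of what target(*args) amounts to for getitem, with strings standing in for converter outputs:

import operator

parent_outputs = ("values_itensor", "indices_itensor")  # stand-ins for TRT tensors
assert operator.getitem(parent_outputs, 0) == "values_itensor"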

py/torch_tensorrt/dynamo/utils.py (+24)

@@ -5,9 +5,11 @@
 from typing import Any, Callable, Dict, Optional, Sequence

 import torch
+import torch_tensorrt
 from torch_tensorrt._Device import Device
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo import CompilationSettings
+from torch_tensorrt.dynamo._defaults import PRECISION

 from packaging import version

@@ -161,6 +163,28 @@ def parse_dynamo_kwargs(kwargs: Any) -> CompilationSettings:
     if settings.debug:
         logger.setLevel(logging.DEBUG)

+    # TODO: Remove once Dynamo precisions refactoring is complete
+    if "enabled_precisions" in kwargs:
+        enabled_precisions = kwargs["enabled_precisions"]
+
+        if (
+            torch.float16 in enabled_precisions
+            or torch_tensorrt.dtype.half in enabled_precisions
+        ):
+            settings.precision = torch.float16
+        elif (
+            torch.float32 in enabled_precisions
+            or torch_tensorrt.dtype.float in enabled_precisions
+        ):
+            settings.precision = torch.float32
+        elif len(enabled_precisions) == 0:
+            logger.info(f"No precision specified, defaulting to {PRECISION}")
+            settings.precision = PRECISION
+        else:
+            raise ValueError(
+                f"Precision {enabled_precisions} not supported in the Dynamo Path"
+            )
+
     # Parse input runtime specification
     settings.use_python_runtime = use_python_runtime_parser(settings.use_python_runtime)
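Because the float16 branch is tested first, fp16 wins whenever both precisions are enabled, and an empty set falls back to the default with an info log. The mapping implied by the branch order, as a comment sketch:

import torch

# enabled_precisions                  -> settings.precision
# {torch.float16, torch.float32}      -> torch.float16  (fp16 branch checked first)
# {torch.float32}                     -> torch.float32
# set()                               -> PRECISION (default, logged)
# {torch.float64}                     -> ValueError (unsupported in the Dynamo path)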
