Skip to content

Commit 267befe

Browse files
authored
feat: support aten.trunc dynamo converter (#2543)
1 parent 088900d commit 267befe

File tree

3 files changed

+102
-1
lines changed

3 files changed

+102
-1
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

+22
Original file line numberDiff line numberDiff line change
@@ -2525,3 +2525,25 @@ def aten_ops_sort(
25252525
dim=args_bounds_check(args, 1, -1),
25262526
descending=args_bounds_check(args, 2, False),
25272527
)
2528+
2529+
2530+
@dynamo_tensorrt_converter(torch.ops.aten.trunc.default)
2531+
@enforce_tensor_types(
2532+
{
2533+
0: (TRTTensor,),
2534+
}
2535+
)
2536+
def aten_ops_trunc(
2537+
ctx: ConversionContext,
2538+
target: Target,
2539+
args: Tuple[Argument, ...],
2540+
kwargs: Dict[str, Argument],
2541+
name: str,
2542+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
2543+
return impl.unary.trunc(
2544+
ctx,
2545+
target,
2546+
SourceIR.ATEN,
2547+
name,
2548+
args[0],
2549+
)

py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py

+28-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@
55
from torch.fx.node import Target
66
from torch_tensorrt.dynamo._SourceIR import SourceIR
77
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
8-
from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor
8+
from torch_tensorrt.dynamo.conversion.converter_utils import (
9+
cast_trt_tensor,
10+
get_trt_tensor,
11+
)
912
from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary
1013
from torch_tensorrt.fx.types import TRTTensor
1114

@@ -432,3 +435,27 @@ def erf(
432435
return convert_unary(
433436
ctx, target, source_ir, name, trt.UnaryOperation.ERF, input_val
434437
)
438+
439+
440+
def trunc(
441+
ctx: ConversionContext,
442+
target: Target,
443+
source_ir: Optional[SourceIR],
444+
name: str,
445+
input_val: TRTTensor,
446+
) -> TRTTensor:
447+
if input_val.dtype not in (trt.float16, trt.float32):
448+
return impl.cast.to_copy(
449+
ctx,
450+
target,
451+
source_ir,
452+
f"{name}_copy",
453+
input_val,
454+
input_val.dtype,
455+
force_layer=True,
456+
)
457+
458+
dividend = get_trt_tensor(ctx, 1, f"{name}_dividend")
459+
return impl.elementwise.trunc_div(
460+
ctx, target, source_ir, f"{name}_trunc", input_val, dividend
461+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import torch
2+
import torch.nn as nn
3+
from parameterized import parameterized
4+
from torch.testing._internal.common_utils import run_tests
5+
6+
from .harness import DispatchTestCase
7+
8+
9+
class TestTruncConverter(DispatchTestCase):
10+
@parameterized.expand(
11+
[
12+
((10,),),
13+
((1, 20),),
14+
((2, 3, 4),),
15+
((2, 3, 4, 5),),
16+
]
17+
)
18+
def test_trunc_float(self, shape):
19+
class Trunc(nn.Module):
20+
def forward(self, input):
21+
return torch.ops.aten.trunc.default(input)
22+
23+
inputs = [torch.randn(shape)]
24+
self.run_test(
25+
Trunc(),
26+
inputs,
27+
enable_passes=True,
28+
)
29+
30+
@parameterized.expand(
31+
[
32+
((10,),),
33+
((1, 20),),
34+
((2, 3, 4),),
35+
((2, 3, 4, 5),),
36+
]
37+
)
38+
def test_trunc_int(self, shape):
39+
class Trunc(nn.Module):
40+
def forward(self, input):
41+
return torch.ops.aten.trunc.default(input)
42+
43+
inputs = [torch.randint(-10, 10, shape, dtype=torch.int32)]
44+
self.run_test(
45+
Trunc(),
46+
inputs,
47+
enable_passes=True,
48+
)
49+
50+
51+
if __name__ == "__main__":
52+
run_tests()

0 commit comments

Comments
 (0)