
Commit de8faa2

Reorg for converters leaky_relu (FX Converter Refactor [6/N]) <Target: converter_reorg_proto> (#1902)

1 parent ed20f09 commit de8faa2
File tree: 5 files changed, +126 −11 lines changed


py/torch_tensorrt/fx/converters/acc_ops_converters.py

Lines changed: 3 additions & 11 deletions

```diff
@@ -1026,17 +1026,9 @@ def acc_ops_leaky_relu(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-    negative_slope = kwargs["negative_slope"]
-    operation_type = trt.ActivationType.LEAKY_RELU
-    return activation.convert_activation(
-        network,
-        target,
-        SourceIR.ACC,
-        name,
-        operation_type,
-        input_val,
-        alpha=negative_slope,
+
+    return activation.leaky_relu(
+        network, target, SourceIR.ACC, name, kwargs["input"], kwargs["negative_slope"]
     )
```

py/torch_tensorrt/fx/converters/aten_ops_converters.py

Lines changed: 27 additions & 0 deletions

```diff
@@ -215,6 +215,33 @@ def aten_ops_hardtanh(
     )


+@tensorrt_converter(torch.ops.aten.fmod.Tensor)
+def aten_ops_fmod(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    kwargs_new = {
+        "input": args[0],
+        "other": args[1],
+    }
+    return acc_ops_converters.acc_ops_fmod(network, target, None, kwargs_new, name)
+
+
+@tensorrt_converter(torch.ops.aten.leaky_relu.default)
+def aten_ops_leaky_relu(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+
+    return activation.leaky_relu(network, target, SourceIR.ATEN, name, args[0], args[1])
+
+
 @tensorrt_converter(torch.ops.aten.linear)
 def aten_ops_linear(
     network: TRTNetwork,
```
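The new `aten_ops_leaky_relu` reads its operands positionally. A minimal sketch (not part of the commit; the tensor shape and slope value are illustrative) of how the ATen op's positional arguments line up with the converter's `args` tuple:

```python
import torch

# torch.ops.aten.leaky_relu.default(input, negative_slope) -> Tensor
# In aten_ops_leaky_relu above, args[0] is the input tensor and args[1] is the
# negative_slope forwarded to activation.leaky_relu.
x = torch.randn(1, 10)
y = torch.ops.aten.leaky_relu.default(x, 0.05)
assert torch.allclose(y, torch.nn.functional.leaky_relu(x, 0.05))
```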

py/torch_tensorrt/fx/converters/impl/activation.py

Lines changed: 27 additions & 0 deletions

```diff
@@ -175,3 +175,30 @@ def tanh_fn(x):
         input_val,
         dyn_range_fn=tanh_dyn_range_fn,
     )
+
+
+def leaky_relu(
+    network: TRTNetwork,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_val: TRTTensor,
+    alpha: Optional[Any],
+):
+    operation_type = trt.ActivationType.LEAKY_RELU
+
+    def leaky_relu_dyn_range_fn(dyn_range):
+        return (max(0, dyn_range[0]) + alpha * min(0, dyn_range[0])), (
+            max(0, dyn_range[1]) + alpha * min(0, dyn_range[1])
+        )
+
+    return convert_activation(
+        network,
+        target,
+        source_ir,
+        name,
+        operation_type,
+        input_val,
+        alpha,
+        dyn_range_fn=leaky_relu_dyn_range_fn,
+    )
```
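The only op-specific piece here is the dynamic-range closure. A small standalone sketch of its arithmetic (the `alpha` value and input range below are made-up numbers, not taken from the commit):

```python
# leaky_relu(x) = max(0, x) + alpha * min(0, x), applied to each bound of the
# incoming dynamic range, mirroring leaky_relu_dyn_range_fn above.
alpha = 0.05

def leaky_relu_dyn_range_fn(dyn_range):
    return (
        max(0, dyn_range[0]) + alpha * min(0, dyn_range[0]),
        max(0, dyn_range[1]) + alpha * min(0, dyn_range[1]),
    )

# An input range of (-2.0, 3.0) maps to roughly (-0.1, 3.0): the negative bound
# is scaled by alpha, the positive bound passes through unchanged.
print(leaky_relu_dyn_range_fn((-2.0, 3.0)))
```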

py/torch_tensorrt/fx/converters/nn_ops_converters.py

Lines changed: 16 additions & 0 deletions

```diff
@@ -66,3 +66,19 @@ def tanh(network, submod, args, kwargs, layer_name):
         name=layer_name,
         input_val=kwargs["input"],
     )
+
+
+@tensorrt_converter(torch.nn.functional.leaky_relu)
+@tensorrt_converter(torch.nn.modules.activation.LeakyReLU)
+def leaky_relu(network, submod, args, kwargs, layer_name):
+    # args/kwargs should have already been normalized to kwargs
+    assert len(args) == 0
+
+    return activation.leaky_relu(
+        network=network,
+        target="torch.nn.functional.leaky_relu",
+        source_ir=SourceIR.NN,
+        name=layer_name,
+        input_val=kwargs["input"],
+        alpha=kwargs["negative_slope"],
+    )
```
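Both registrations route to the same implementation, which matches the fact that the functional and module spellings of the op compute the same thing. A quick illustrative check (not from the commit):

```python
import torch
import torch.nn as nn

# The functional call and the LeakyReLU module produce identical results,
# so one converter can serve both registrations above.
x = torch.randn(2, 4)
functional_out = nn.functional.leaky_relu(x, negative_slope=0.05)
module_out = nn.LeakyReLU(negative_slope=0.05)(x)
assert torch.equal(functional_out, module_out)
```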
New test file

Lines changed: 53 additions & 0 deletions

```diff
@@ -0,0 +1,53 @@
+import torch
+import torch.nn as nn
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+
+
+class TestLeakyReLUConverter(DispatchTestCase):
+    def test_leaky_relu(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.leaky_relu(x, negative_slope=0.05)
+
+        inputs = [torch.randn(1, 10)]
+        self.run_test(
+            TestModule(), inputs, expected_ops={torch.ops.aten.leaky_relu.default}
+        )
+
+    def test_leaky_relu_with_dynamic_shape(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.leaky_relu(x, negative_slope=0.05)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))],
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.leaky_relu.default}
+        )
+
+    def test_leaky_relu_with_dynamic_shape_four_dimensions(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.leaky_relu(x, negative_slope=0.05)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.leaky_relu.default}
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
```
