
Commit 60169bb

feat: support activation dynamo converters
lint test file
fix bugs: circular import
delete gelu
change function calls from nn.Module to nn.functional
1 parent 0e5a497 commit 60169bb
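
Note on the last item of the commit message: "change function calls from nn.Module to nn.functional" refers to the updated test file (not included in this excerpt), where test modules call the functional form of an activation directly instead of instantiating a layer object. A minimal, purely illustrative example of that pattern:

import torch
import torch.nn as nn


class ActivationTestModule(nn.Module):  # hypothetical test module, not from this commit
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # before: self.relu = nn.ReLU(); return self.relu(x)
        return nn.functional.leaky_relu(x, negative_slope=0.01)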

File tree

9 files changed: +713 −81 lines


py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

+128-18
@@ -152,15 +152,15 @@ def aten_ops_fmod(
     return impl.elementwise.fmod(network, target, SourceIR.ATEN, name, args[0], args[1])


-@dynamo_tensorrt_converter(torch.ops.aten.gelu.default)  # type: ignore[misc]
-def aten_ops_gelu(
+@dynamo_tensorrt_converter(torch.ops.aten.relu.default)
+def aten_ops_relu(
     network: TRTNetwork,
     target: Target,
     args: Tuple[Argument, ...],
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.activation.gelu(
+    return impl.activation.relu(
         network,
         target,
         SourceIR.ATEN,
@@ -169,55 +169,165 @@ def aten_ops_gelu(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.matmul)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.mm.default)  # type: ignore[misc]
-def aten_ops_matmul(
+@dynamo_tensorrt_converter(torch.ops.aten.sigmoid.default)
+def aten_ops_sigmoid(
     network: TRTNetwork,
     target: Target,
     args: Tuple[Argument, ...],
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.matmul.matrix_multiply(
-        network, target, SourceIR.ATEN, name, args[0], args[1]
+    return impl.activation.sigmoid(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
     )


-@dynamo_tensorrt_converter(torch.ops.aten.layer_norm.default)  # type: ignore[misc]
-def aten_ops_layernorm(
+@dynamo_tensorrt_converter(torch.ops.aten.tanh.default)
+def aten_ops_tanh(
     network: TRTNetwork,
     target: Target,
     args: Tuple[Argument, ...],
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.normalization.layer_norm(
+    return impl.activation.tanh(
         network,
         target,
         SourceIR.ATEN,
         name,
         args[0],
-        args[1],
-        args[2],
-        args[3],
-        args[4],
     )


-@dynamo_tensorrt_converter(torch.ops.aten.relu.default)  # type: ignore[misc]
-def aten_ops_relu(
+@dynamo_tensorrt_converter(torch.ops.aten.leaky_relu.default)
+def aten_ops_leaky_relu(
     network: TRTNetwork,
     target: Target,
     args: Tuple[Argument, ...],
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.activation.relu(
+    return impl.activation.leaky_relu(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args_bounds_check(args, 1, 0.01),
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.elu.default)
+def aten_ops_elu(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.activation.elu(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        alpha=args_bounds_check(args, 1, 1.0),
+        beta=args_bounds_check(args, 2, None),
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.softplus.default)
+def aten_ops_softplus(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.activation.softplus(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        beta=args_bounds_check(args, 1, 1),
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.clip.default)
+def aten_ops_clip(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.activation.clip(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        alpha=args_bounds_check(args, 1),
+        beta=args_bounds_check(args, 2),
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.hardsigmoid.default)
+def aten_ops_hard_sigmoid(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.activation.hard_sigmoid(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        alpha=args_bounds_check(args, 1, 1 / 6),
+        beta=args_bounds_check(args, 2, 1 / 2),
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.matmul)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.mm.default)  # type: ignore[misc]
+def aten_ops_matmul(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.matmul.matrix_multiply(
+        network, target, SourceIR.ATEN, name, args[0], args[1]
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.layer_norm.default)  # type: ignore[misc]
+def aten_ops_layernorm(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.normalization.layer_norm(
         network,
         target,
         SourceIR.ATEN,
         name,
         args[0],
+        args[1],
+        args[2],
+        args[3],
+        args[4],
     )


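Note: the converters above use args_bounds_check to supply defaults for optional ATen arguments (for example, leaky_relu's negative_slope of 0.01). Its definition is imported from elsewhere and is not part of this diff; the sketch below only illustrates the assumed behavior of such a helper.

from typing import Any, Optional, Sequence


def args_bounds_check(
    args: Sequence[Any], i: int, replacement: Optional[Any] = None
) -> Any:
    # Assumed behavior: return the i-th positional argument when it was
    # supplied, otherwise fall back to the given default.
    return args[i] if len(args) > i else replacement
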
py/torch_tensorrt/dynamo/conversion/impl/activation.py

-63
This file was deleted.

New file (1 line):
@@ -0,0 +1 @@
+from .ops import *

New file (42 lines):
@@ -0,0 +1,42 @@
+from typing import Any, Callable, Optional
+
+import tensorrt as trt
+from torch.fx.node import Target
+from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.fx.converters.converter_utils import (
+    mark_as_int8_layer,
+    set_layer_name,
+)
+from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
+
+
+def convert_activation(
+    network: TRTNetwork,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    operation_type: trt.ActivationType,
+    input_val: TRTTensor,
+    alpha: Optional[Any] = None,
+    beta: Optional[Any] = None,
+    dyn_range_fn: Optional[Callable[[float, float], Any]] = None,
+) -> TRTTensor:
+    """
+    Add a TensorRT Activation layer to `network`.
+    """
+    if not isinstance(input_val, TRTTensor):
+        raise RuntimeError(
+            f"{operation_type} received input {input_val} that is not part "
+            "of the TensorRT region!"
+        )
+    layer = network.add_activation(input_val, operation_type)
+    if alpha is not None:
+        layer.alpha = alpha
+    if beta is not None:
+        layer.beta = beta
+    set_layer_name(layer, target, name, source_ir)
+
+    if input_val.dynamic_range is not None:
+        dyn_range = dyn_range_fn(input_val.dynamic_range)
+        mark_as_int8_layer(layer, dyn_range)
+    return layer.get_output(0)
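
For context, an individual converter in the new activation package would presumably wrap convert_activation with a concrete trt.ActivationType and, for int8, a function that maps the input dynamic range through the activation. The sketch below is not part of this commit; the import path .base, the function name leaky_relu, and the dynamic-range handling are assumptions.

from typing import Any, Optional

import tensorrt as trt
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor

from .base import convert_activation  # hypothetical module name


def leaky_relu(
    network: TRTNetwork,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input_val: TRTTensor,
    alpha: Optional[Any] = 0.01,
) -> TRTTensor:
    def leaky_relu_dyn_range_fn(dyn_range):
        # LeakyReLU(x) = x for x > 0 and alpha * x otherwise; apply it to
        # both ends of the incoming dynamic range.
        def scale(x):
            return x if x > 0 else x * alpha

        low, high = dyn_range
        return scale(low), scale(high)

    return convert_activation(
        network,
        target,
        source_ir,
        name,
        trt.ActivationType.LEAKY_RELU,
        input_val,
        alpha=alpha,
        dyn_range_fn=leaky_relu_dyn_range_fn,
    )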
