Skip to content

Commit 0047b3d

Browse files
committed
add converter registration
1 parent a243274 commit 0047b3d

File tree

1 file changed

+22
-3
lines changed

1 file changed

+22
-3
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

+22-3
Original file line numberDiff line numberDiff line change
@@ -357,14 +357,14 @@ def aten_ops_softmax(
357357

358358
@dynamo_tensorrt_converter(
359359
torch.ops.aten.split.Tensor, capability_validator=dynamic_unsupported_with_args([1])
360-
)
360+
) # type: ignore[misc]
361361
@dynamo_tensorrt_converter(
362362
torch.ops.aten.split.sizes, capability_validator=dynamic_unsupported_with_args([1])
363-
)
363+
) # type: ignore[misc]
364364
@dynamo_tensorrt_converter(
365365
torch.ops.aten.split_with_sizes.default,
366366
capability_validator=dynamic_unsupported_with_args([1]),
367-
)
367+
) # type: ignore[misc]
368368
def aten_ops_split(
369369
network: TRTNetwork,
370370
target: Target,
@@ -1378,3 +1378,22 @@ def aten_ops_linear(
13781378
weight=args[1],
13791379
bias=args_bounds_check(args, 2, None),
13801380
)
1381+
1382+
1383+
@dynamo_tensorrt_converter(torch.ops.aten.argmax.default) # type: ignore[misc]
1384+
def aten_ops_argmax(
1385+
network: TRTNetwork,
1386+
target: Target,
1387+
args: Tuple[Argument, ...],
1388+
kwargs: Dict[str, Argument],
1389+
name: str,
1390+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
1391+
return impl.argmax.argmax(
1392+
network,
1393+
target,
1394+
SourceIR.ATEN,
1395+
name,
1396+
input=args[0],
1397+
dim=args_bounds_check(args, 1),
1398+
keep_dim=args_bounds_check(args, 2),
1399+
)

0 commit comments

Comments
 (0)