feat: support many elementwise dynamo converters #2263

Merged
merged 1 commit on Sep 8, 2023
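As a quick illustration of what the new converters enable (a minimal sketch, not taken from this PR: the module, shapes, and compile settings below are assumptions), a model composed of these elementwise ops can now be lowered through the dynamo frontend:

import torch
import torch_tensorrt

class Elementwise(torch.nn.Module):
    def forward(self, x, y):
        # Traces to aten.add.Tensor (with alpha), aten.mul.Tensor,
        # aten.div.Tensor_mode, and aten.gt.Tensor -- all handled by
        # converters added in this PR.
        a = torch.add(x, y, alpha=2)
        b = a * x
        c = torch.div(b, y, rounding_mode="floor")
        return c > x

model = Elementwise().eval().cuda()
inputs = [torch.randn(2, 3).cuda(), torch.randn(2, 3).cuda()]
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
print(trt_model(*inputs))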
373 changes: 315 additions & 58 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -1,16 +1,10 @@
import logging
from typing import Any, Dict, Optional, Sequence, Tuple, Union

import tensorrt as trt
import torch
from torch.fx.node import Argument, Node, Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion.converter_utils import (
cast_int_int_div_trt_tensor,
cast_trt_tensor,
)
from torch_tensorrt.fx.converters import acc_ops_converters
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor

from .converter_registry import dynamo_tensorrt_converter
@@ -48,58 +42,6 @@ def aten_ops_batch_norm(
)


@dynamo_tensorrt_converter(torch.ops.aten.div.default) # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor_mode) # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor) # type: ignore[misc]
def aten_ops_div(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
kwargs_new = {
"input": args[0],
"other": args[1],
}
# If both are TRTTensor, both are cast to float32
if isinstance(args[0], TRTTensor) and isinstance(args[1], TRTTensor):
kwargs_new["input"], kwargs_new["other"] = cast_int_int_div_trt_tensor(
network,
kwargs_new["input"],
kwargs_new["other"],
name,
)
# If one is TRTTensor, it is cast to float32
elif isinstance(args[0], TRTTensor) and (
kwargs_new["input"].dtype == trt.int8 or kwargs_new["input"].dtype == trt.int32
):
kwargs_new["input"] = cast_trt_tensor(
network, kwargs_new["input"], trt.float32, name, target
)
elif isinstance(args[1], TRTTensor) and (
kwargs_new["other"].dtype == trt.int8 or kwargs_new["other"].dtype == trt.int32
):
kwargs_new["other"] = cast_trt_tensor(
network, kwargs_new["other"], trt.float32, name, target
)
rounding_mode = kwargs.get("rounding_mode")
if rounding_mode is None:
return acc_ops_converters.acc_ops_div(network, target, None, kwargs_new, name)
elif rounding_mode == "floor":
return acc_ops_converters.acc_ops_floor_div(
network, target, None, kwargs_new, name
)
elif rounding_mode == "trunc":
return impl.elementwise.trunc_div(
network, target, SourceIR.ATEN, name, args[0], args[1]
)
else:
raise RuntimeError(
f"Target {target} does not support rounding mode {rounding_mode}"
)


def embedding_param_validator(embedding_node: Node) -> bool:
scale_grad_by_freq = args_bounds_check(embedding_node.args, 3)
sparse = args_bounds_check(embedding_node.args, 4)
@@ -1004,6 +946,321 @@ def aten_ops_isinf(
)


@dynamo_tensorrt_converter(torch.ops.aten.add.Tensor)
@dynamo_tensorrt_converter(torch.ops.aten.add.Scalar)
def aten_ops_add(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
other = args[1]
alpha = kwargs.get("alpha", 1)

if alpha != 1:
other = impl.elementwise.mul(
network,
target,
SourceIR.ATEN,
name,
other,
alpha,
)

return impl.elementwise.add(
network,
target,
SourceIR.ATEN,
name,
args[0],
other,
)


@dynamo_tensorrt_converter(torch.ops.aten.mul.Tensor)
@dynamo_tensorrt_converter(torch.ops.aten.mul.Scalar)
def aten_ops_mul(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.elementwise.mul(
network,
target,
SourceIR.ATEN,
name,
args[0],
args[1],
)


@dynamo_tensorrt_converter(torch.ops.aten.maximum.default)
def aten_ops_max(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.elementwise.max(
network,
target,
SourceIR.ATEN,
name,
args[0],
args[1],
)


@dynamo_tensorrt_converter(torch.ops.aten.minimum.default)
def aten_ops_min(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.elementwise.min(
network,
target,
SourceIR.ATEN,
name,
args[0],
args[1],
)


@dynamo_tensorrt_converter(torch.ops.aten.sub.Tensor)
@dynamo_tensorrt_converter(torch.ops.aten.sub.Scalar)
def aten_ops_sub(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
other = args[1]
alpha = kwargs.get("alpha", 1)

if alpha != 1:
other = impl.elementwise.mul(
network,
target,
SourceIR.ATEN,
name,
other,
alpha,
)

return impl.elementwise.sub(
network,
target,
SourceIR.ATEN,
name,
args[0],
other,
)


@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor)
@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor_mode)
@dynamo_tensorrt_converter(torch.ops.aten.div.Scalar)
@dynamo_tensorrt_converter(torch.ops.aten.div.Scalar_mode)
def aten_ops_div(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
rounding_mode = kwargs.get("rounding_mode")

if rounding_mode is None:
return impl.elementwise.div(
network,
target,
SourceIR.ATEN,
name,
args[0],
args[1],
)
elif rounding_mode == "floor":
return impl.elementwise.floor_divide(
network,
target,
SourceIR.ATEN,
name,
args[0],
args[1],
)
elif rounding_mode == "trunc":
return impl.elementwise.trunc_div(
network,
target,
SourceIR.ATEN,
name,
args[0],
args[1],
)
else:
raise RuntimeError(
f"Target {target} does not support rounding mode {rounding_mode}"
)


@dynamo_tensorrt_converter(torch.ops.aten.pow.Tensor_Tensor)
@dynamo_tensorrt_converter(torch.ops.aten.pow.Scalar)
@dynamo_tensorrt_converter(torch.ops.aten.pow.Tensor_Scalar)
def aten_ops_pow(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.elementwise.pow(
network,
target,
SourceIR.ATEN,
name,
args[0],
args[1],
)


@dynamo_tensorrt_converter(torch.ops.aten.floor_divide.default)
@dynamo_tensorrt_converter(torch.ops.aten.floor_divide.Scalar)
def aten_ops_floor_div(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.elementwise.floor_divide(
network,
target,
SourceIR.ATEN,
name,
args[0],
args[1],
)


@dynamo_tensorrt_converter(torch.ops.aten.logical_and.default)
def aten_ops_logical_and(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.elementwise.logical_and(
network,
target,
SourceIR.ATEN,
name,
args[0],
args[1],
)


@dynamo_tensorrt_converter(torch.ops.aten.logical_or.default)
def aten_ops_logical_or(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.elementwise.logical_or(
network,
target,
SourceIR.ATEN,
name,
args[0],
args[1],
)


@dynamo_tensorrt_converter(torch.ops.aten.logical_xor.default)
def aten_ops_logical_xor(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.elementwise.logical_xor(
network,
target,
SourceIR.ATEN,
name,
args[0],
args[1],
)


@dynamo_tensorrt_converter(torch.ops.aten.eq.Tensor)
@dynamo_tensorrt_converter(torch.ops.aten.eq.Scalar)
def aten_ops_equal(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.elementwise.eq(
network,
target,
SourceIR.ATEN,
name,
args[0],
args[1],
)


@dynamo_tensorrt_converter(torch.ops.aten.gt.Tensor)
@dynamo_tensorrt_converter(torch.ops.aten.gt.Scalar)
def aten_ops_greater(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.elementwise.gt(
network,
target,
SourceIR.ATEN,
name,
args[0],
args[1],
)


@dynamo_tensorrt_converter(torch.ops.aten.lt.Tensor)
@dynamo_tensorrt_converter(torch.ops.aten.lt.Scalar)
def aten_ops_less(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.elementwise.lt(
network,
target,
SourceIR.ATEN,
name,
args[0],
args[1],
)


def conv_param_validator(conv_node: Node) -> bool:
return (not conv_node.args[6]) and (conv_node.args[7] in ([0], [0, 0], [0, 0, 0]))

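For reference, the three rounding_mode branches in aten_ops_div follow PyTorch's eager semantics for torch.div, which a plain PyTorch check illustrates (nothing TensorRT-specific; the tensor values are arbitrary):

import torch

a = torch.tensor([7.0, -7.0])
b = torch.tensor([2.0, 2.0])

print(torch.div(a, b))                         # true division -> [ 3.5, -3.5]
print(torch.div(a, b, rounding_mode="floor"))  # floor_divide  -> [ 3., -4.]
print(torch.div(a, b, rounding_mode="trunc"))  # trunc_div     -> [ 3., -3.]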