Skip to content

Moving elementwise core to impl - rsqrt (FX Converter Refactor [9/N]) <Target: converter_reorg_elementwise> #1905

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions py/torch_tensorrt/fx/converters/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils
from torch_tensorrt.fx.converters.impl import activation
from torch_tensorrt.fx.converters.impl.elementwise import trunc_div
from torch_tensorrt.fx.converters.impl.elementwise import rsqrt

_LOGGER: logging.Logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -300,6 +301,42 @@ def aten_ops_relu(
)


@tensorrt_converter(torch.ops.aten.relu.default)
def aten_ops_relu(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:

return activation.relu(
network,
target,
SourceIR.ATEN,
name,
args[0],
)


@tensorrt_converter(torch.ops.aten.rsqrt.default)
def aten_ops_rsqrt(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:

return rsqrt(
network,
target,
SourceIR.ATEN,
name,
args[0],
)


@tensorrt_converter(torch.ops.aten.sub.Tensor)
def aten_ops_sub(
network: TRTNetwork,
Expand Down
30 changes: 30 additions & 0 deletions py/torch_tensorrt/fx/converters/impl/elementwise/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,33 @@ def trunc_div(
)

return output


def rsqrt(
network: TRTNetwork,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
) -> TRTTensor:

sqrt_trt_output = convert_unary(
network,
target,
source_ir,
f"{name}_sqrt",
trt.UnaryOperation.SQRT,
input,
)

output = convert_binary_elementwise(
network,
target,
source_ir,
f"{name}_output",
trt.ElementWiseOperation.DIV,
1,
sqrt_trt_output,
)

return output
29 changes: 29 additions & 0 deletions py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import torch
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec


class TestRSqrtConverter(DispatchTestCase):
@parameterized.expand(
[
("2d_dim_alpha", (2, 1), 2),
("3d_dim_alpha", (2, 1, 2), 2),
]
)
def test_rsqrt(self, _, x, alpha):
class rsqrt(nn.Module):
def forward(self, input):
return torch.rsqrt(input)

inputs = [torch.randn(x) + 1]
self.run_test(
rsqrt(),
inputs,
expected_ops={torch.ops.aten.rsqrt.default},
)


if __name__ == "__main__":
run_tests()