-
Notifications
You must be signed in to change notification settings - Fork 365
Moving normalization core to impl - softmax (FX Converter Refactor [12/N]) <Target: converter_reorg_elementwise> #1909
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are some changes that do not conform to Python style guidelines:
--- py/torch_tensorrt/fx/converters/aten_ops_converters.py 2023-05-10 06:06:19.166184 +0000
+++ py/torch_tensorrt/fx/converters/aten_ops_converters.py 2023-05-10 06:06:48.425711 +0000
@@ -359,17 +359,11 @@
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
- return softmax(
- network,
- target,
- SourceIR.ATEN,
- name,
- kwargs["input"],
- kwargs["dim"])
+ return softmax(network, target, SourceIR.ATEN, name, kwargs["input"], kwargs["dim"])
@tensorrt_converter(torch.ops.aten.cat.default)
def aten_ops_cat(
network: TRTNetwork,
--- py/torch_tensorrt/fx/converters/impl/normalization/__init__.py 2023-05-10 06:06:19.166184 +0000
+++ py/torch_tensorrt/fx/converters/impl/normalization/__init__.py 2023-05-10 06:06:48.631521 +0000
@@ -1 +1 @@
-from .ops import *
\ No newline at end of file
+from .ops import *
--- py/torch_tensorrt/fx/converters/impl/normalization/ops.py 2023-05-10 06:06:19.166184 +0000
+++ py/torch_tensorrt/fx/converters/impl/normalization/ops.py 2023-05-10 06:06:48.739746 +0000
@@ -12,20 +12,21 @@
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
from torch_tensorrt.fx.converters.converter_utils import (
SourceIR,
set_layer_name,
- get_positive_dim
+ get_positive_dim,
)
+
def softmax(
network: TRTNetwork,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
- dim: Optional[Any] = None
+ dim: Optional[Any] = None,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_ranks = len(input.shape) + (1 if network.has_implicit_batch_dimension else 0) # type: ignore[union-attr]
if not isinstance(input, TRTTensor):
raise RuntimeError(
@@ -53,6 +54,5 @@
layer = network.add_softmax(input)
layer.axes = 1 << dim
set_layer_name(layer, target, name)
return layer.get_output(0)
-
--- py/torch_tensorrt/fx/converters/acc_ops_converters.py 2023-05-10 06:06:19.166184 +0000
+++ py/torch_tensorrt/fx/converters/acc_ops_converters.py 2023-05-10 06:06:51.503546 +0000
@@ -857,18 +857,11 @@
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
- return softmax(
- network,
- target,
- SourceIR.ACC,
- name,
- kwargs["input"],
- kwargs["dim"]
- )
+ return softmax(network, target, SourceIR.ACC, name, kwargs["input"], kwargs["dim"])
@tensorrt_converter(acc_ops.tile)
def acc_ops_tile(
network: TRTNetwork,
--- py/torch_tensorrt/fx/test/converters/aten_op/test_softmax_aten.py 2023-05-10 06:06:19.174185 +0000
+++ py/torch_tensorrt/fx/test/converters/aten_op/test_softmax_aten.py 2023-05-10 06:06:51.919309 +0000
@@ -39,6 +39,6 @@
TestModule(), input_specs, expected_ops={torch.ops.aten._softmax.default}
)
if __name__ == "__main__":
- run_tests()
\ No newline at end of file
+ run_tests()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
0d4e6d7
to
9611d67
Compare
2769db2
to
2f23743
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
9611d67
to
93846ed
Compare
Signed-off-by: Naren Dasan <naren@narendasan.com> new file: ../converters/impl/unary/base.py
93846ed
to
83986cd
Compare
… <Target: converter_reorg_elementwise> (#1905)
2f23743
to
f423be1
Compare
softmax linting error fix
f423be1
to
75b1a2a
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
546f975
to
c8a9559
Compare
18e503b
to
2caac76
Compare
Closed with #2070 merged with main |
No description provided.