Skip to content

feat: support bmm converter in dynamo #2248

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 28, 2023
Merged

feat: support bmm converter in dynamo #2248

merged 1 commit into from
Sep 28, 2023

Conversation

bowang007
Copy link
Collaborator

Description

Support bmm converter for dynamo

  • New feature (non-breaking change which adds functionality)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR in so that relevant reviewers are notified

@bowang007 bowang007 requested a review from narendasan August 21, 2023 01:56
@github-actions github-actions bot added component: api [Python] Issues re: Python API component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths component: tests Issues re: Tests labels Aug 21, 2023
Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to C++ style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to C++ style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py	2023-08-21 01:56:33.820650+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py	2023-08-21 01:59:17.826690+00:00
@@ -180,21 +180,21 @@
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return impl.matmul.matrix_multiply(
        network, target, SourceIR.ATEN, name, args[0], args[1]
    )

+
@dynamo_tensorrt_converter(torch.ops.aten.bmm.default)
def aten_ops_bmm(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.matmul.bmm(
-        network, target, SourceIR.ATEN, name, args[0], args[1]
-    )
+    return impl.matmul.bmm(network, target, SourceIR.ATEN, name, args[0], args[1])
+

@dynamo_tensorrt_converter(torch.ops.aten.layer_norm.default)
def aten_ops_layernorm(
    network: TRTNetwork,
    target: Target,
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/matmul.py	2023-08-21 01:56:33.824650+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/matmul.py	2023-08-21 01:59:17.915372+00:00
@@ -47,10 +47,11 @@
    )
    layer = network.add_matrix_multiply(input, input_matrix_op, other, other_matrix_op)
    set_layer_name(layer, target, name)
    return layer.get_output(0)

+
def bmm(
    network: TRTNetwork,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
@@ -73,12 +74,12 @@
    if len(input.shape) != 3:
        raise RuntimeError(f"Expected 3-dimensional tensor, but got ")

    if len(other.shape) != 3:
        raise RuntimeError(f"Expected 3-dimensional tensor, but got")
-    
-    if (input.shape[0] != other.shape[0]):
+
+    if input.shape[0] != other.shape[0]:
        raise RuntimeError("expected input tensors to have same batch size.")
-    
+
    layer = network.add_matrix_multiply(input, input_matrix_op, other, other_matrix_op)
    set_layer_name(layer, target, name)
-    return layer.get_output(0)
\ No newline at end of file
+    return layer.get_output(0)
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/converters/test_bmm.py	2023-08-21 01:56:33.840650+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/converters/test_bmm.py	2023-08-21 01:59:21.411402+00:00
@@ -2,10 +2,11 @@
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt.dynamo.test_utils import DispatchTestCase
from torch_tensorrt import Input
+

class TestBmmConverter(DispatchTestCase):
    @parameterized.expand(
        [
            ("10_3_5", (10, 3, 4), (9, 4, 5)),
@@ -13,21 +14,20 @@
    )
    def test_bmm(self, _, input_shape, mat2_shape):
        class BMM(nn.Module):
            def __init__(self):
                super().__init__()
-                
+
            def forward(self, input, mat2):
                return torch.bmm(input, mat2)
-            
+
        inputs = [torch.randn(*input_shape), torch.randn(*mat2_shape)]

-
        self.run_test(
-            BMM(), 
+            BMM(),
            inputs,
            expected_ops={},
        )
-        
+

if __name__ == "__main__":
    run_tests()

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py	2023-08-21 01:56:46.017488+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py	2023-08-21 01:59:42.153437+00:00
@@ -180,21 +180,21 @@
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return impl.matmul.matrix_multiply(
        network, target, SourceIR.ATEN, name, args[0], args[1]
    )

+
@dynamo_tensorrt_converter(torch.ops.aten.bmm.default)
def aten_ops_bmm(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.matmul.bmm(
-        network, target, SourceIR.ATEN, name, args[0], args[1]
-    )
+    return impl.matmul.bmm(network, target, SourceIR.ATEN, name, args[0], args[1])
+

@dynamo_tensorrt_converter(torch.ops.aten.layer_norm.default)
def aten_ops_layernorm(
    network: TRTNetwork,
    target: Target,
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/matmul.py	2023-08-21 01:56:46.017488+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/matmul.py	2023-08-21 01:59:42.159876+00:00
@@ -47,10 +47,11 @@
    )
    layer = network.add_matrix_multiply(input, input_matrix_op, other, other_matrix_op)
    set_layer_name(layer, target, name)
    return layer.get_output(0)

+
def bmm(
    network: TRTNetwork,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
@@ -73,12 +74,12 @@
    if len(input.shape) != 3:
        raise RuntimeError(f"Expected 3-dimensional tensor, but got ")

    if len(other.shape) != 3:
        raise RuntimeError(f"Expected 3-dimensional tensor, but got")
-    
-    if (input.shape[0] != other.shape[0]):
+
+    if input.shape[0] != other.shape[0]:
        raise RuntimeError("expected input tensors to have same batch size.")
-    
+
    layer = network.add_matrix_multiply(input, input_matrix_op, other, other_matrix_op)
    set_layer_name(layer, target, name)
-    return layer.get_output(0)
\ No newline at end of file
+    return layer.get_output(0)
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/converters/test_bmm.py	2023-08-21 01:56:46.041488+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/converters/test_bmm.py	2023-08-21 01:59:47.252864+00:00
@@ -2,10 +2,11 @@
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt.dynamo.test_utils import DispatchTestCase
from torch_tensorrt import Input
+

class TestBmmConverter(DispatchTestCase):
    @parameterized.expand(
        [
            ("10_3_5", (10, 3, 4), (9, 4, 5)),
@@ -13,21 +14,20 @@
    )
    def test_bmm(self, _, input_shape, mat2_shape):
        class BMM(nn.Module):
            def __init__(self):
                super().__init__()
-                
+
            def forward(self, input, mat2):
                return torch.bmm(input, mat2)
-            
+
        inputs = [torch.randn(*input_shape), torch.randn(*mat2_shape)]

-
        self.run_test(
-            BMM(), 
+            BMM(),
            inputs,
            expected_ops={},
        )
-        
+

if __name__ == "__main__":
    run_tests()

@narendasan
Copy link
Collaborator

@bowang007 lint the PR

@gs-olive
Copy link
Collaborator

To remove:

@gs-olive
Copy link
Collaborator

Consider switching to True default:

disable_passes: bool = False,

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to C++ style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/matmul.py	2023-09-08 18:11:55.418367+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/matmul.py	2023-09-08 18:14:17.787932+00:00
@@ -47,10 +47,11 @@
    )
    layer = network.add_matrix_multiply(input, input_matrix_op, other, other_matrix_op)
    set_layer_name(layer, target, name, source_ir)
    return layer.get_output(0)

+
def bmm(
    network: TRTNetwork,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
@@ -73,12 +74,12 @@
    if len(input.shape) != 3:
        raise RuntimeError(f"Expected 3-dimensional tensor, but got ")

    if len(other.shape) != 3:
        raise RuntimeError(f"Expected 3-dimensional tensor, but got")
-    
-    if (input.shape[0] != other.shape[0]):
+
+    if input.shape[0] != other.shape[0]:
        raise RuntimeError("expected input tensors to have same batch size.")
-    
+
    layer = network.add_matrix_multiply(input, input_matrix_op, other, other_matrix_op)
    set_layer_name(layer, target, name)
-    return layer.get_output(0)
\ No newline at end of file
+    return layer.get_output(0)
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py	2023-09-08 18:11:55.418367+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py	2023-09-08 18:14:17.923968+00:00
@@ -313,21 +313,21 @@
        name,
        args[0],
        args[1],
    )

+
@dynamo_tensorrt_converter(torch.ops.aten.bmm.default)
def aten_ops_bmm(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.matmul.bmm(
-        network, target, SourceIR.ATEN, name, args[0], args[1]
-    )
+    return impl.matmul.bmm(network, target, SourceIR.ATEN, name, args[0], args[1])
+

@dynamo_tensorrt_converter(torch.ops.aten.layer_norm.default)  # type: ignore[misc]
def aten_ops_layernorm(
    network: TRTNetwork,
    target: Target,
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/converters/test_bmm.py	2023-09-08 18:11:55.434368+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/converters/test_bmm.py	2023-09-08 18:14:21.664385+00:00
@@ -2,10 +2,11 @@
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt.dynamo.test_utils import DispatchTestCase
from torch_tensorrt import Input
+

class TestBmmConverter(DispatchTestCase):
    @parameterized.expand(
        [
            ("10_3_5", (10, 3, 4), (9, 4, 5)),
@@ -13,21 +14,20 @@
    )
    def test_bmm(self, _, input_shape, mat2_shape):
        class BMM(nn.Module):
            def __init__(self):
                super().__init__()
-                
+
            def forward(self, input, mat2):
                return torch.bmm(input, mat2)
-            
+
        inputs = [torch.randn(*input_shape), torch.randn(*mat2_shape)]

-
        self.run_test(
-            BMM(), 
+            BMM(),
            inputs,
            expected_ops={},
        )
-        
+

if __name__ == "__main__":
    run_tests()

@bowang007 bowang007 changed the title converter: support bmm converter for dynamo feat: support bmm converter in dynamo Sep 22, 2023
@bowang007 bowang007 requested a review from gs-olive September 22, 2023 22:30
Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to C++ style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to Python style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to Python style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to C++ style guidelines

Copy link
Collaborator

@gs-olive gs-olive left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me, pending CI validation for the added test case!

@bowang007
Copy link
Collaborator Author

Looks good to me, pending CI validation for the added test case!

Is CI fixed now?

@gs-olive
Copy link
Collaborator

It should be working, but I see there is some sort of Torch version mismatch error. I'm not yet sure what's the cause - it seems to be happening across all the recent CI runs, since about 3pm today.



class TestBmmConverter(DispatchTestCase):
@parameterized.expand(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As such everything looks good. But can more test cases be added with varying dimensions?

Copy link
Collaborator

@gs-olive gs-olive left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CI issues resolved - should just need the additional test cases, and then can merge

Signed-off-by: Bo Wang <bowa@nvidia.com>
Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to Python style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to C++ style guidelines

Copy link
Collaborator

@apbose apbose left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good!

Copy link
Collaborator

@gs-olive gs-olive left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cla signed component: api [Python] Issues re: Python API component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths component: tests Issues re: Tests
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants