
Dynamo converter cat #2343


Merged

merged 4 commits into main from dynamo_converter_cat on Oct 6, 2023

Conversation

apbose
Collaborator

@apbose commented Sep 25, 2023

This PR moves the aten::cat implementation to the dynamo impl library.

@apbose requested a review from gs-olive on September 25, 2023 19:58
@github-actions bot added labels: component: api [Python] (Issues re: Python API), component: conversion (Issues re: Conversion stage), component: converters (Issues re: Specific op converters), component: dynamo (Issues relating to the `torch.compile` or `torch._dynamo.export` paths) on Sep 25, 2023
@github-actions bot requested a review from narendasan on September 25, 2023 19:58
@github-actions bot left a comment:

Code conforms to C++ style guidelines

@github-actions bot left a comment:

Code conforms to Python style guidelines

SourceIR.ATEN,
name,
tensors=args[0],
dim=args_bounds_check(args, 2, 1),
Collaborator commented:

Switch to dim=args_bounds_check(args, 1, 0), since the default should be 0 at the args[1] position, as here.
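For context, args_bounds_check is a converter-utils helper that falls back to a default when an optional schema argument is absent. A minimal sketch of that behavior (simplified for illustration, not the library's exact code):

```python
from typing import Any, Sequence

def args_bounds_check(args: Sequence[Any], i: int, replacement: Any = None) -> Any:
    # Return args[i] if the index exists, otherwise the replacement default.
    # This lets a converter tolerate schema-optional arguments such as the
    # `dim` of aten::cat, which call sites may omit.
    return args[i] if len(args) > i else replacement

# aten::cat(tensors, dim=0): dim sits at position 1 and defaults to 0
args = (["t0", "t1"],)           # call site omitted dim
dim = args_bounds_check(args, 1, 0)
print(dim)  # -> 0
```

With the default placed at args[1], an omitted dim resolves to 0, matching the aten::cat schema.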

Comment on lines 23 to 24
if dim < 0:
dim = len(input[0].shape) + dim
@gs-olive (Collaborator) commented Sep 25, 2023:

Switch to get_positive_dim, as here:

def get_positive_dim(
dim: Union[int, Sequence[int]], dim_size: int
) -> Union[int, Tuple[int, ...]]:
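For reference, a plausible implementation behind that signature (a sketch only; the real helper lives in torch_tensorrt's converter utilities and may differ in detail):

```python
from typing import Sequence, Tuple, Union

def get_positive_dim(
    dim: Union[int, Sequence[int]], dim_size: int
) -> Union[int, Tuple[int, ...]]:
    # Normalize possibly-negative dimension indices into [0, dim_size),
    # mirroring Python's negative-index semantics.
    def positive(d: int) -> int:
        return d % dim_size if d < 0 else d

    if isinstance(dim, int):
        return positive(dim)
    return tuple(positive(d) for d in dim)

print(get_positive_dim(-1, 4))       # -> 3
print(get_positive_dim((0, -2), 4))  # -> (0, 2)
```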

Collaborator commented:

See above comment

Comment on lines 18 to 21
if any(not isinstance(t, TRTTensor) for t in input): # type: ignore[union-attr]
raise RuntimeError(
f"cat received inputs {input} that is not part " "of the TensorRT region!"
)
Collaborator commented:

Instead of throwing an error, check the inputs and use get_trt_tensor for any inputs which are not already TRTTensor objects.
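The suggested pattern, sketched with stand-ins (TRTTensor and get_trt_tensor are hypothetical stubs here; in torch_tensorrt the real helper registers the value as a TensorRT constant and returns an ITensor handle):

```python
# Hypothetical stubs so the control flow is visible without TensorRT installed.
class TRTTensor:
    def __init__(self, name: str):
        self.name = name

def get_trt_tensor(ctx, value, name):
    # Stub: pretend to register `value` as a network constant.
    return TRTTensor(name)

def normalize_inputs(ctx, inputs, name):
    # Convert any non-TRTTensor input instead of raising an error.
    return [
        t if isinstance(t, TRTTensor) else get_trt_tensor(ctx, t, f"{name}_tensor_{i}")
        for i, t in enumerate(inputs)
    ]

mixed = [TRTTensor("a"), 3.0, TRTTensor("b")]
out = normalize_inputs(None, mixed, "cat")
print(all(isinstance(t, TRTTensor) for t in out))  # -> True
```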

Collaborator commented:

Alternatively, #2324 can help with this

Collaborator commented:

Consider using @enforce_tensor_types here
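The exact signature of @enforce_tensor_types is not shown in this thread; purely as an illustration of the idea, a hypothetical decorator that validates argument types before the converter body runs:

```python
import functools

def enforce_types(**expected):
    # Hypothetical: map keyword-argument names to required types and
    # raise early if a caller passes something else.
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            for arg_name, arg_type in expected.items():
                if arg_name in kwargs and not isinstance(kwargs[arg_name], arg_type):
                    raise TypeError(f"{arg_name} must be {arg_type.__name__}")
            return fn(*args, **kwargs)
        return wrapper
    return decorator

@enforce_types(dim=int)
def cat(*, tensors, dim=0):
    return (tuple(tensors), dim)

print(cat(tensors=["a", "b"], dim=1))  # -> (('a', 'b'), 1)
```

The real decorator in torch_tensorrt can additionally promote inputs to the required tensor types rather than only validating them.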

target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTNetwork,
Collaborator commented:

This should likely be something like Sequence[TRTTensor] or Sequence[Union[TRTTensor, torch.Tensor, np.ndarray]]

@apbose force-pushed the dynamo_converter_cat branch from afc9b5d to 79e7e13 on October 4, 2023 20:06

@github-actions bot left a comment:

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/cat.py	2023-10-04 20:10:10.328968+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/cat.py	2023-10-04 20:13:03.260979+00:00
@@ -16,11 +16,11 @@
    name: str,
    input: Union[TRTTensor, Sequence[TRTTensor]],
    dim: int,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    for each_input in input:
-        if(not isinstance(each_input, TRTTensor)):
+        if not isinstance(each_input, TRTTensor):
            each_input = get_trt_tensor(each_input)
    concat_layer = ctx.net.add_concatenation(input)
    if dim < 0:
        dim = len(input[0].shape) + dim

target: Target,
source_ir: Optional[SourceIR],
name: str,
input: Union[TRTTensor, Sequence[TRTTensor]],
Collaborator commented:

Switch to Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]]

Comment on lines 20 to 22
for each_input in input:
if(not isinstance(each_input, TRTTensor)):
each_input = get_trt_tensor(each_input)
@gs-olive (Collaborator) commented Oct 4, 2023:

Switch to:

trt_tensor_input = [get_trt_tensor(ctx, each_input, f"{name}_tensor{i}") for i, each_input in enumerate(input)]

Comment on lines 23 to 24
if dim < 0:
dim = len(input[0].shape) + dim
Collaborator commented:

See above comment

@github-actions bot left a comment:

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/cat.py	2023-10-04 22:36:19.760280+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/cat.py	2023-10-04 22:38:58.599050+00:00
@@ -21,11 +21,11 @@
    name: str,
    input: Sequence[Union[TRTTensor, torch.Tensor, np.ndarray]],
    dim: int,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    for each_input in input:
-        if(not isinstance(each_input, TRTTensor)):
+        if not isinstance(each_input, TRTTensor):
            each_input = get_trt_tensor(each_input)
    concat_layer = ctx.net.add_concatenation(input)
    if dim < 0:
        dim = get_positive_dim(dim, len(input[0].shape))

@apbose force-pushed the dynamo_converter_cat branch from 80300d3 to 94535d2 on October 4, 2023 22:42

@zewenli98 (Collaborator) left a comment:

The current change looks good overall. A small issue could be fixed, as I commented below.

Comment on lines 29 to 30
if dim < 0:
dim = get_positive_dim(dim, len(input[0].shape))
Collaborator commented:

Since get_positive_dim will internally check if dim is less than 0, the condition could be deleted.

Comment on lines 25 to 27
for each_input in input:
if not isinstance(each_input, TRTTensor):
each_input = get_trt_tensor(each_input)
Collaborator commented:

This loop will not save the get_trt_tensor objects anywhere, since the each_input object gets reset on every loop iteration. Consider replacing with something like:

    trt_inputs = []
    for i, each_input in enumerate(input):
        each_input = get_trt_tensor(ctx, each_input, f"{name}_tensor{i}")
        trt_inputs.append(each_input)
...
# Replace all uses of input with trt_inputs
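The underlying pitfall is plain Python semantics, independent of TensorRT: rebinding the loop variable never modifies the list being iterated. A minimal demonstration (str.upper stands in for the conversion call):

```python
items = ["a", "b", "c"]

# Buggy pattern: rebinding the loop variable leaves `items` unchanged.
for each in items:
    each = each.upper()
print(items)  # -> ['a', 'b', 'c']

# Fixed pattern: collect converted values into a new list and use it afterwards.
converted = []
for each in items:
    converted.append(each.upper())
print(converted)  # -> ['A', 'B', 'C']
```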

@apbose force-pushed the dynamo_converter_cat branch 2 times, most recently from 57544fc to 1a6e814 on October 5, 2023 17:29

@apbose apbose requested review from zewenli98 and gs-olive October 5, 2023 17:32

@gs-olive (Collaborator) left a comment:

Looks good to me!

@apbose force-pushed the dynamo_converter_cat branch from 1a6e814 to 83e4f3c on October 5, 2023 22:21

@zewenli98 (Collaborator) left a comment:

Thanks! LGTM.

@gs-olive (Collaborator) commented Oct 6, 2023:

Verified tests passing locally.

@gs-olive merged commit 83176fe into main on Oct 6, 2023
@gs-olive deleted the dynamo_converter_cat branch on October 6, 2023 23:19
Labels

cla signed
component: api [Python] (Issues re: Python API)
component: conversion (Issues re: Conversion stage)
component: converters (Issues re: Specific op converters)
component: dynamo (Issues relating to the `torch.compile` or `torch._dynamo.export` paths)
priority: high
4 participants