Dynamo converter cat #2343
Conversation
Code conforms to C++ style guidelines
Code conforms to Python style guidelines
SourceIR.ATEN,
name,
tensors=args[0],
dim=args_bounds_check(args, 2, 1),
Switch to dim=args_bounds_check(args, 1, 0), since dim sits at the args[1] position and its default should be 0, as here.
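For context, args_bounds_check is the Torch-TensorRT helper that returns args[i] when the caller supplied that position and a default otherwise; the body below is a sketch reconstructed from that described behavior, not the library source:

```python
from typing import Any, Sequence


def args_bounds_check(args: Sequence[Any], i: int, replacement: Any = None) -> Any:
    # Return args[i] if the position exists, otherwise the supplied default.
    return args[i] if len(args) > i else replacement


# aten::cat receives (tensors, dim=0): dim sits at position 1 with default 0.
print(args_bounds_check((([1, 2], [3, 4]),), 1, 0))    # dim omitted -> 0
print(args_bounds_check((([1, 2], [3, 4]), 1), 1, 0))  # dim provided -> 1
```

With the original args_bounds_check(args, 2, 1), a call that omitted dim would silently concatenate along axis 1 instead of the schema default 0.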
if dim < 0:
    dim = len(input[0].shape) + dim
Switch to get_positive_dim, as here:
TensorRT/py/torch_tensorrt/dynamo/conversion/converter_utils.py, lines 337 to 339 in 4cffd6e:
def get_positive_dim(
    dim: Union[int, Sequence[int]], dim_size: int
) -> Union[int, Tuple[int, ...]]:
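Based only on the signature quoted above, a plausible sketch of the helper's behavior (the body here is an assumption; wrapping negative dims into the range [0, dim_size) is the standard convention):

```python
from typing import Sequence, Tuple, Union


def get_positive_dim(
    dim: Union[int, Sequence[int]], dim_size: int
) -> Union[int, Tuple[int, ...]]:
    # Map each possibly-negative axis index into the range [0, dim_size).
    def positive(d: int) -> int:
        return d % dim_size if d < 0 else d

    if isinstance(dim, int):
        return positive(dim)
    return tuple(positive(d) for d in dim)


print(get_positive_dim(-1, 4))       # -> 3
print(get_positive_dim((0, -2), 4))  # -> (0, 2)
```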
See above comment
if any(not isinstance(t, TRTTensor) for t in input):  # type: ignore[union-attr]
    raise RuntimeError(
        f"cat received inputs {input} that is not part " "of the TensorRT region!"
    )
Instead of throwing an error, check the inputs and use get_trt_tensor for any inputs which are not already TRTTensor objects.
Alternatively, #2324 can help with this
Consider using @enforce_tensor_types here.
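enforce_tensor_types is a decorator in the Torch-TensorRT conversion utilities; the version below is a simplified, self-contained sketch of the idea (FakeTRTTensor, to_trt_tensor, and the positions argument are illustrative stand-ins, not the real API):

```python
import functools
from typing import Any, Callable, Tuple


class FakeTRTTensor:
    """Stand-in for TRTTensor; only illustrates the conversion flow."""

    def __init__(self, value: Any) -> None:
        self.value = value


def to_trt_tensor(value: Any) -> FakeTRTTensor:
    # Stand-in for get_trt_tensor: wrap anything not already a tensor.
    return value if isinstance(value, FakeTRTTensor) else FakeTRTTensor(value)


def enforce_tensor_types(positions: Tuple[int, ...]) -> Callable:
    """Decorator sketch: coerce the listed positional args to tensor objects."""

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            args = tuple(
                to_trt_tensor(a) if i in positions else a
                for i, a in enumerate(args)
            )
            return func(*args, **kwargs)

        return wrapper

    return decorator


@enforce_tensor_types((0,))
def cat_converter(inputs: FakeTRTTensor, dim: int) -> FakeTRTTensor:
    # The converter body can now assume its input is already a tensor.
    assert isinstance(inputs, FakeTRTTensor)
    return inputs


print(isinstance(cat_converter([1, 2, 3], 0), FakeTRTTensor))  # -> True
```

The design benefit is that the type-coercion boilerplate moves out of every converter body and into one reusable decorator.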
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTNetwork,
This should likely be something like Sequence[TRTTensor] or Sequence[Union[TRTTensor, torch.Tensor, np.ndarray]].
Force-pushed 31a2cba to afc9b5d.
Force-pushed afc9b5d to 79e7e13.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/cat.py 2023-10-04 20:10:10.328968+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/cat.py 2023-10-04 20:13:03.260979+00:00
@@ -16,11 +16,11 @@
name: str,
input: Union[TRTTensor, Sequence[TRTTensor]],
dim: int,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
for each_input in input:
- if(not isinstance(each_input, TRTTensor)):
+ if not isinstance(each_input, TRTTensor):
each_input = get_trt_tensor(each_input)
concat_layer = ctx.net.add_concatenation(input)
if dim < 0:
dim = len(input[0].shape) + dim
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: Union[TRTTensor, Sequence[TRTTensor]],
Switch to Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]]
for each_input in input:
    if(not isinstance(each_input, TRTTensor)):
        each_input = get_trt_tensor(each_input)
Switch to:
trt_tensor_input = [get_trt_tensor(ctx, each_input, name + f"_tensor{i}") for i, each_input in enumerate(input)]
if dim < 0:
    dim = len(input[0].shape) + dim
See above comment
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/cat.py 2023-10-04 22:36:19.760280+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/cat.py 2023-10-04 22:38:58.599050+00:00
@@ -21,11 +21,11 @@
name: str,
input: Sequence[Union[TRTTensor, torch.Tensor, np.ndarray]],
dim: int,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
for each_input in input:
- if(not isinstance(each_input, TRTTensor)):
+ if not isinstance(each_input, TRTTensor):
each_input = get_trt_tensor(each_input)
concat_layer = ctx.net.add_concatenation(input)
if dim < 0:
dim = get_positive_dim(dim, len(input[0].shape))
Force-pushed 80300d3 to 94535d2.
The current change looks good overall. A small issue could be fixed, as I commented below.
if dim < 0:
    dim = get_positive_dim(dim, len(input[0].shape))
Since get_positive_dim will internally check if dim is less than 0, the condition could be deleted.
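The call-site simplification described here can be sketched as follows, with an assumed one-line body standing in for get_positive_dim (the real helper also accepts sequences of dims):

```python
def get_positive_dim(dim: int, dim_size: int) -> int:
    # Assumed behavior: the helper itself normalizes negative dims.
    return dim % dim_size if dim < 0 else dim


shape = (2, 3, 4)

# Before: redundant guard at the call site.
dim = -1
if dim < 0:
    dim = get_positive_dim(dim, len(shape))

# After: unconditional call, same result.
dim_after = get_positive_dim(-1, len(shape))

print(dim, dim_after)  # -> 2 2
```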
for each_input in input:
    if not isinstance(each_input, TRTTensor):
        each_input = get_trt_tensor(each_input)
This loop will not save the get_trt_tensor objects anywhere, since the each_input object gets reset on every loop iteration. Consider replacing with something like:
trt_inputs = []
for i, each_input in enumerate(input):
    each_input = get_trt_tensor(ctx, each_input, name + f"_tensor{i}")
    trt_inputs.append(each_input)
...
# Replace all uses of input with trt_inputs
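A self-contained demonstration of the bug and the fix (FakeTRTTensor and get_trt_tensor are stand-ins for the TensorRT types; the point is pure Python semantics: rebinding a loop variable never mutates the list being iterated):

```python
class FakeTRTTensor:
    def __init__(self, value):
        self.value = value


def get_trt_tensor(value):
    # Stand-in converter: wrap plain values in a tensor object.
    return value if isinstance(value, FakeTRTTensor) else FakeTRTTensor(value)


inputs = [1, 2, FakeTRTTensor(3)]

# Buggy pattern: rebinding the loop variable never touches the list.
for each_input in inputs:
    each_input = get_trt_tensor(each_input)
print([type(x).__name__ for x in inputs])
# -> ['int', 'int', 'FakeTRTTensor']  (unchanged!)

# Fixed pattern: collect the converted tensors into a new list.
trt_inputs = [get_trt_tensor(each_input) for each_input in inputs]
print([type(x).__name__ for x in trt_inputs])
# -> ['FakeTRTTensor', 'FakeTRTTensor', 'FakeTRTTensor']
```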
Force-pushed 57544fc to 1a6e814.
Looks good to me!
Force-pushed 1a6e814 to 83e4f3c.
Thanks! LGTM.
Verified tests passing locally.
This PR moves the aten::cat implementation to the dynamo impl library.
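The converter's final control flow can be sketched end to end with stand-ins for the TensorRT types (FakeTRTTensor, the simplified helpers, and the return value are illustrative assumptions; the real code lives in py/torch_tensorrt/dynamo/conversion/impl/cat.py and hands the converted inputs to ctx.net.add_concatenation):

```python
from typing import Any, List, Sequence, Tuple


class FakeTRTTensor:
    def __init__(self, shape: Sequence[int]) -> None:
        self.shape = tuple(shape)


def get_trt_tensor(value: Any, name: str) -> FakeTRTTensor:
    # Stand-in: wrap anything that is not already a tensor.
    if isinstance(value, FakeTRTTensor):
        return value
    return FakeTRTTensor(getattr(value, "shape", (len(value),)))


def get_positive_dim(dim: int, dim_size: int) -> int:
    # Assumed helper: normalize a possibly-negative dim.
    return dim % dim_size if dim < 0 else dim


def cat(name: str, inputs: Sequence[Any], dim: int) -> Tuple[List[FakeTRTTensor], int]:
    # 1. Convert every input that is not already a TRT tensor.
    trt_inputs = [
        get_trt_tensor(each, f"{name}_tensor_{i}") for i, each in enumerate(inputs)
    ]
    # 2. Normalize the dim against the rank of the first input.
    dim = get_positive_dim(dim, len(trt_inputs[0].shape))
    # (The real converter would now call ctx.net.add_concatenation(trt_inputs)
    # and set concat_layer.axis = dim; here we just return the prepared pieces.)
    return trt_inputs, dim


tensors, axis = cat("cat", [[1, 2], FakeTRTTensor((2,))], -1)
print(axis)  # -> 0
```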