feat: support many elementwise dynamo converters #2263
Conversation
@gs-olive The error seems not to be related to this PR. Could you take a look? Thanks!
@zewenli98 - taking a look now! Also, could you rebase this one to …
(force-pushed from 8dfba80 to a187bb9)
@gs-olive Rebased! Thanks!
See the comments below to fix the Dynamo errors appearing for this PR.
Two additional changes are needed to fix this CI issue. In `convert_binary_elementwise`, we need to change `Frameworks.TORCH` to `Frameworks.NUMPY`:
`TensorRT/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py`, lines 77 to 82 in e49ef6d:

```python
if isinstance(lhs_val, TRTTensor):
    lhs_dtype = unified_dtype_converter(lhs_val.dtype, Frameworks.TORCH)
    is_lhs_trt_tensor = True
if isinstance(rhs_val, TRTTensor):
    rhs_dtype = unified_dtype_converter(rhs_val.dtype, Frameworks.TORCH)
    is_rhs_trt_tensor = True
```
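For context on why the framework target matters here: the scalar constants built later in this code path become numpy arrays, so the dtype handed to them must be a numpy dtype. Below is a minimal, hypothetical mock of such a converter; the real `unified_dtype_converter` and `Frameworks` live in torch_tensorrt's converter utilities and differ in detail.

```python
from enum import Enum, auto

import numpy as np

# Hypothetical stand-in for torch_tensorrt's Frameworks enum.
class Frameworks(Enum):
    NUMPY = auto()
    TORCH = auto()

# Illustrative mapping from dtype names to numpy dtypes; the real utility
# also handles TRT and torch dtype objects.
_TO_NUMPY = {
    "float32": np.float32,
    "int32": np.int32,
    "bool": np.bool_,
}

def unified_dtype_converter(dtype_name, framework):
    """Return dtype_name translated into the requested framework's dtype."""
    if framework is Frameworks.NUMPY:
        return _TO_NUMPY[dtype_name]
    raise NotImplementedError("only the NUMPY branch is sketched here")
```

With `Frameworks.NUMPY`, `lhs_dtype`/`rhs_dtype` come back as numpy dtypes, which is what numpy-based constant creation expects.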
As well as changing `torch.tensor` to `np.array`:

`TensorRT/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py`, lines 105 to 108 in e49ef6d:

```python
if is_lhs_trt_tensor and isinstance(rhs_val, (float, int)):
    rhs_val = torch.tensor([rhs_val], dtype=lhs_dtype)
if is_rhs_trt_tensor and isinstance(lhs_val, (float, int)):
    lhs_val = torch.tensor([lhs_val], dtype=rhs_dtype)
```
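Applied together with the `Frameworks.NUMPY` change, the scalar-promotion step would then build numpy constants rather than torch tensors. A small self-contained sketch of that behavior (the helper name `promote_scalar` is mine, not the codebase's):

```python
import numpy as np

def promote_scalar(val, dtype):
    # Promote a bare Python int/float to a 1-element numpy array so it can
    # be folded into the network as a constant with a known dtype; pass
    # anything else through unchanged.
    if isinstance(val, (float, int)):
        return np.array([val], dtype=dtype)
    return val

rhs = promote_scalar(7, np.float32)  # scalar operand against a float TRT tensor
print(rhs, rhs.dtype)                # [7.] float32
```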
The above are fixed in #2265, so let me try to merge that and rebase to resolve these failures.
(force-pushed from a187bb9 to da66673)
@gs-olive Thanks for the suggestions and detailed explanations! I updated and rebased!
Changes look great - added some comments on schemas, converter support, and the `div`, `add`, and `sub` operators, which have special cases.
@gs-olive Thanks for the review! Resolved issues above.
(force-pushed from a4075ca to b51c121)
@zewenli98 - Seeing the following error in both test failures for this PR:

```
[TRT] [E] 4: [network.cpp::inferOutputTypes::2063] Error Code 4: Internal Error (Output tensor output0 of type Int32 produced from output of incompatible type Float)
```

I assume it is because the forward function in those tests is like so:

```python
class Tensor0DInput(torch.nn.Module):
    def forward(self, x):
        return x * 7
```

The input tensor …
For instance - in TorchScript, we don't have float casts for …
New changes look great! See comments above on narrowing the usage of float casting.
@gs-olive Thanks for your review! I just found a similar error. For some ops, such as `eq`, we have to specify `output_dtypes=[torch.bool]` in tests; otherwise, we get this error:

```
[TRT] [E] 4: [network.cpp::inferOutputTypes::2063] Error Code 4: Internal Error (Output tensor output0 of type Int32 produced from output of incompatible type Bool)
```

However, the weird thing is that if we don't specify `output_dtypes=[torch.bool]` and instead just access `output.dtype` (doing nothing with it) before `return output` in `convert_binary_elementwise`, the error disappears and the test passes. So I was wondering if this is kind of like lazy mode, requiring running `output.dtype` to get the right type?
That is definitely strange - calling `output.dtype` shouldn't have an effect on its own, to my knowledge. The way we currently determine output data types in our `torch.compile` backend is to run the graph in Torch and observe the output types, then set the TRT engine outputs accordingly. This becomes problematic if TRT requires a float cast where Torch does not (for instance, on Int-Int adds or multiplies). For this reason, we do not need float casting in `add` or `mul`, as there is currently no float casting in our TorchScript path. On the other hand, we do need bool type specification for `eq`, as you observed, since the output could otherwise be Int32.
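That dtype-inference strategy can be sketched without torch: execute the graph under the framework's own type rules, read the output dtypes back, then configure the engine outputs to match (the function and names below are illustrative, not the backend's API):

```python
import numpy as np

def infer_output_dtypes(fn, *sample_inputs):
    # Run the "graph" (a plain function here; a torch.fx module in the real
    # backend) and record each output's dtype.
    outputs = fn(*sample_inputs)
    if not isinstance(outputs, tuple):
        outputs = (outputs,)
    return [np.asarray(o).dtype for o in outputs]

a = np.array([1, 2], dtype=np.int32)
b = np.array([1, 3], dtype=np.int32)
print(infer_output_dtypes(lambda x, y: x == y, a, b))  # [dtype('bool')]
print(infer_output_dtypes(lambda x, y: x + y, a, b))   # [dtype('int32')]
```

This shows why `eq` needs `output_dtypes=[torch.bool]`: running the graph yields Bool, while the engine would otherwise declare Int32.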
See the `cast_int_int_div_trt_tensor` function and cross-validate with the C++ implementations to verify where casts are needed for the elementwise operators. I think very few of these need `trt_cast_int_to_float`. Please let me know if there are any resources that indicate otherwise.
```python
if isinstance(lhs_val, TRTTensor):
    lhs_val = trt_cast_int_to_float(network, name, lhs_val)

if isinstance(rhs_val, TRTTensor):
    rhs_val = trt_cast_int_to_float(network, name, rhs_val)
```
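For `div`, the int-to-float cast is genuinely needed: true division of two integer tensors yields floats in Torch, while an integer elementwise divide would truncate. A numpy sketch of that rationale (the real `trt_cast_int_to_float` inserts a TRT cast layer instead):

```python
import numpy as np

def div_with_cast(lhs, rhs):
    # Mirror Torch's true-division semantics: cast integer operands to float
    # before dividing so 7 / 2 gives 3.5 instead of a truncated 3.
    if lhs.dtype.kind == "i":
        lhs = lhs.astype(np.float32)
    if rhs.dtype.kind == "i":
        rhs = rhs.astype(np.float32)
    return lhs / rhs

out = div_with_cast(np.array([7], dtype=np.int32), np.array([2], dtype=np.int32))
print(out, out.dtype)  # [3.5] float32
```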
This can be removed, unless there is reason to cast the `add` operator to a float.
```python
if isinstance(lhs_val, TRTTensor):
    lhs_val = trt_cast_int_to_float(network, name, lhs_val)

if isinstance(rhs_val, TRTTensor):
    rhs_val = trt_cast_int_to_float(network, name, rhs_val)
```
This can also be removed
yes! Thanks for the details! I just wanted to check if adding …
(force-pushed from 1e65a53 to 25c1ff4)
@gs-olive Let me give some explanations here. At first, I referred to this doc https://docs.nvidia.com/deeplearning/tensorrt/operators/docs/ElementWise.html#data-types where it says …
@zewenli98 Understood - thanks for this resource. In general, I would recommend using the C++ TRT documentation for input restrictions. This link for the C++ API suggests that …
Thanks for the changes - #2298 was just merged, so could you rebase to `main`?
(force-pushed from 8dd9a33 to 6175423)
Rebased! Thanks George!
Looks great to me! Well tested and very useful utilities + converters added. This will greatly improve our operator coverage and correctness - thanks @zewenli98!
@zewenli98 - Could you rebase to the latest …
Commits:
- add output_dtypes in test
- add util func and fix bugs
- add overloads, update tests, and fix a bug
- fix arg bug
- delete int2float conversion for some ops
- update type conversion
(force-pushed from 6175423 to 07d823b)
Description

Support many elementwise dynamo converters, including `add`, `mul`, `maximum`, `minimum`, `sub`, `div` (already implemented), `pow`, `floor_divide`, `logical_and`, `logical_or`, `logical_xor`, `eq`, `gt`, and `lt`.

Fixes #2208
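The semantics these converters implement can be illustrated with a plain numpy mapping (this dict is only a demonstration of the operator behaviors, not Torch-TensorRT's converter registry):

```python
import operator

import numpy as np

# Demonstration-only table: each elementwise op from the list above paired
# with a numpy/operator equivalent showing its expected semantics.
ELEMENTWISE_OPS = {
    "add": operator.add,
    "mul": operator.mul,
    "maximum": np.maximum,
    "minimum": np.minimum,
    "sub": operator.sub,
    "div": operator.truediv,
    "pow": operator.pow,
    "floor_divide": operator.floordiv,
    "logical_and": np.logical_and,
    "logical_or": np.logical_or,
    "logical_xor": np.logical_xor,
    "eq": operator.eq,
    "gt": operator.gt,
    "lt": operator.lt,
}

a = np.array([7, 2], dtype=np.int32)
b = np.array([2, 2], dtype=np.int32)
print(ELEMENTWISE_OPS["floor_divide"](a, b))  # [3 1]
print(ELEMENTWISE_OPS["eq"](a, b))            # [False  True]
```

Note how `div` produces floats and `eq` produces bools on integer inputs, matching the casting and `output_dtypes` discussion earlier in the review.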
Type of change
Checklist: