
feat: support many elementwise dynamo converters #2263


Merged
1 commit merged into pytorch:main from elementwise_dynamo_converters on Sep 8, 2023

Conversation

zewenli98
Collaborator

Description

Support many elementwise dynamo converters, including add, mul, maximum, minimum, sub, div (already implemented), pow, floor_divide, logical_and, logical_or, logical_xor, eq, gt, and lt.

Fixes #2208

Type of change

  • New feature (non-breaking change which adds functionality)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that relevant reviewers are notified

@github-actions github-actions bot added component: api [Python] Issues re: Python API component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths component: tests Issues re: Tests labels Aug 25, 2023
@github-actions github-actions bot requested a review from gs-olive August 25, 2023 00:38
@zewenli98
Collaborator Author

zewenli98 commented Aug 29, 2023

@gs-olive The error doesn't seem to be related to this PR. Could you take a look? Thanks!

@gs-olive
Collaborator

@zewenli98 - taking a look now! Also, could you rebase this one to main to resolve the merge conflicts?

@zewenli98 zewenli98 force-pushed the elementwise_dynamo_converters branch from 8dfba80 to a187bb9 on August 29, 2023 23:24
@zewenli98
Collaborator Author

zewenli98 commented Aug 29, 2023

@gs-olive Rebased! Thanks!

Collaborator

@gs-olive gs-olive left a comment


See the comments below to fix the Dynamo errors appearing for this PR.

Collaborator

@gs-olive gs-olive left a comment


Two additional changes are needed to fix this CI issue. In convert_binary_elementwise, we need to change Frameworks.TORCH to Frameworks.NUMPY:

    if isinstance(lhs_val, TRTTensor):
        lhs_dtype = unified_dtype_converter(lhs_val.dtype, Frameworks.TORCH)
        is_lhs_trt_tensor = True
    if isinstance(rhs_val, TRTTensor):
        rhs_dtype = unified_dtype_converter(rhs_val.dtype, Frameworks.TORCH)
        is_rhs_trt_tensor = True

As well as changing torch.tensor to np.array:
    if is_lhs_trt_tensor and isinstance(rhs_val, (float, int)):
        rhs_val = torch.tensor([rhs_val], dtype=lhs_dtype)
    if is_rhs_trt_tensor and isinstance(lhs_val, (float, int)):
        lhs_val = torch.tensor([lhs_val], dtype=rhs_dtype)
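
Putting both suggestions together, a minimal sketch of the corrected snippet (assuming numpy is imported as np, and that unified_dtype_converter and Frameworks come from the existing converter utilities):

    import numpy as np

    if isinstance(lhs_val, TRTTensor):
        lhs_dtype = unified_dtype_converter(lhs_val.dtype, Frameworks.NUMPY)
        is_lhs_trt_tensor = True
    if isinstance(rhs_val, TRTTensor):
        rhs_dtype = unified_dtype_converter(rhs_val.dtype, Frameworks.NUMPY)
        is_rhs_trt_tensor = True

    # Scalars become NumPy arrays so their dtype matches the
    # Frameworks.NUMPY conversion above.
    if is_lhs_trt_tensor and isinstance(rhs_val, (float, int)):
        rhs_val = np.array([rhs_val], dtype=lhs_dtype)
    if is_rhs_trt_tensor and isinstance(lhs_val, (float, int)):
        lhs_val = np.array([lhs_val], dtype=rhs_dtype)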

The above are fixed in #2265, so let me try to merge that and rebase to resolve these errors.

@zewenli98 zewenli98 force-pushed the elementwise_dynamo_converters branch from a187bb9 to da66673 on August 31, 2023 22:07
@zewenli98
Collaborator Author

@gs-olive Thanks for the suggestions and detailed explanations! I updated and rebased!

@zewenli98 zewenli98 requested a review from gs-olive August 31, 2023 22:10
Collaborator

@gs-olive gs-olive left a comment


Changes look great - added some comments on schemas, converter support, and the div, add, and sub operators, which have special cases

@zewenli98
Collaborator Author

@gs-olive Thanks for the review! Resolved issues above.

@zewenli98 zewenli98 force-pushed the elementwise_dynamo_converters branch from a4075ca to b51c121 on September 2, 2023 00:27
@zewenli98 zewenli98 requested a review from gs-olive September 5, 2023 22:21
@gs-olive
Collaborator

gs-olive commented Sep 5, 2023

@zewenli98 - Seeing the following error in both test failures for this PR:

[TRT] [E] 4: [network.cpp::inferOutputTypes::2063] Error Code 4: Internal Error (Output tensor output0 of type Int32 produced from output of incompatible type Float)

I assume it is because the forward function in those tests is like so:

        class Tensor0DInput(torch.nn.Module):
            def forward(self, x):
                return x * 7

The input tensor x is an integer Tensor, so Torch provides an Int32 output (meaning we also expect an Int32 output), but due to the new converters, we are casting all integers for these elementwise ops to floats. I know that some elementwise converters require float inputs, but do all of the converters require this? As in - could aten.mul work without this cast?
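
A standalone sketch of the mismatch, independent of the test harness:

    import torch

    class Tensor0DInput(torch.nn.Module):
        def forward(self, x):
            return x * 7

    # Torch keeps integer inputs integer for mul, so the compiled engine is
    # expected to produce Int32 here; an unconditional int-to-float cast in
    # the converter makes the TRT output Float and triggers the error above.
    x = torch.arange(4, dtype=torch.int32)
    print(Tensor0DInput()(x).dtype)  # torch.int32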

@gs-olive
Collaborator

gs-olive commented Sep 5, 2023

For instance - in TorchScript, we don't have float casts for mul or add, only div.

Collaborator

@gs-olive gs-olive left a comment


New changes look great! See comments above on narrowing the usage of float casting.

Collaborator Author


@gs-olive Thanks for your review! I just found a similar error. For some ops, such as eq, we have to specify output_dtypes=[torch.bool] in tests, otherwise we get the error: [TRT] [E] 4: [network.cpp::inferOutputTypes::2063] Error Code 4: Internal Error (Output tensor output0 of type Int32 produced from output of incompatible type Bool).

However, the weird part is that if we don't specify output_dtypes=[torch.bool] and instead just access output.dtype (doing nothing with it) before return output in convert_binary_elementwise, the error disappears and the test passes. So I was wondering if this is kind of like a lazy mode, requiring a call to output.dtype to get the right type?
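
For illustration, the workaround looks roughly like the sketch below (the harness class, run_test signature, and import path are assumed from context and may differ in the repository):

    import torch
    from harness import DispatchTestCase  # import path assumed

    class TestEqConverter(DispatchTestCase):
        def test_eq_int_inputs(self):
            class Eq(torch.nn.Module):
                def forward(self, lhs, rhs):
                    return torch.eq(lhs, rhs)

            inputs = [
                torch.randint(0, 3, (2, 2), dtype=torch.int32),
                torch.randint(0, 3, (2, 2), dtype=torch.int32),
            ]
            # Without output_dtypes, the engine output is inferred as Int32
            # while eq produces Bool, yielding the inferOutputTypes error.
            self.run_test(
                Eq(),
                inputs,
                expected_ops={torch.ops.aten.eq.Tensor},
                output_dtypes=[torch.bool],
            )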

Collaborator

@gs-olive gs-olive Sep 6, 2023


That is definitely strange; calling output.dtype shouldn't have an effect on its own, to my knowledge. The way we currently determine output data types in our torch.compile backend is to run the graph in Torch, observe the output types, and then set the TRT engine outputs accordingly. This becomes problematic if TRT requires a float cast where Torch does not (for instance, on Int-Int adds or multiplies). For this reason, we do not need float casting in add or mul, as there is currently no float casting in our TorchScript path. On the other hand, we do need the bool type specification for eq, as you observed, since the output could otherwise be Int32.
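
In other words, the backend's dtype inference amounts to something like the sketch below (illustrative only, not the actual backend code; the function name is made up):

    import torch

    def infer_output_dtypes(gm: torch.fx.GraphModule, sample_inputs):
        # Run the FX graph in Torch and record the output dtypes; the TRT
        # engine outputs are then bound to these types, so any converter-side
        # cast that changes an output type causes a mismatch.
        outputs = gm(*sample_inputs)
        if not isinstance(outputs, (tuple, list)):
            outputs = (outputs,)
        return [out.dtype for out in outputs]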

Collaborator

@gs-olive gs-olive left a comment


See the cast_int_int_div_trt_tensor function and cross-validate with the C++ implementations to verify where casts are needed for the elementwise operators. I think very few of these need trt_cast_int_to_float. Please let me know if there are any resources that indicate otherwise.

Comment on lines 256 to 260
    if isinstance(lhs_val, TRTTensor):
        lhs_val = trt_cast_int_to_float(network, name, lhs_val)

    if isinstance(rhs_val, TRTTensor):
        rhs_val = trt_cast_int_to_float(network, name, rhs_val)
Collaborator


This can be removed, unless there is reason to cast the add operator to a float.

Comment on lines 275 to 279
    if isinstance(lhs_val, TRTTensor):
        lhs_val = trt_cast_int_to_float(network, name, lhs_val)

    if isinstance(rhs_val, TRTTensor):
        rhs_val = trt_cast_int_to_float(network, name, rhs_val)
Collaborator


This can also be removed.
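
Following these comments, the casting can be gated on the operator. A hypothetical helper sketch (op_name is an assumed parameter; the div/pow restriction comes from the discussion below):

    def maybe_cast_int_to_float(network, name, lhs_val, rhs_val, op_name):
        # Only div and pow reject Int32 inputs in TRT (see below), so other
        # elementwise ops keep their integer inputs untouched.
        if op_name in ("div", "pow"):
            if isinstance(lhs_val, TRTTensor):
                lhs_val = trt_cast_int_to_float(network, name, lhs_val)
            if isinstance(rhs_val, TRTTensor):
                rhs_val = trt_cast_int_to_float(network, name, rhs_val)
        return lhs_val, rhs_val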

@zewenli98
Collaborator Author

Yes! Thanks for the details! I just wanted to check whether adding output.dtype is useful (as a controlled experiment) in this commit. I'll let you know after resolving all the issues!

@zewenli98 zewenli98 force-pushed the elementwise_dynamo_converters branch from 1e65a53 to 25c1ff4 on September 7, 2023 20:47
@zewenli98
Collaborator Author

@gs-olive Let me give some explanation here. At first, I referred to this doc https://docs.nvidia.com/deeplearning/tensorrt/operators/docs/ElementWise.html#data-types, which says int32 is not supported by all elementwise ops. However, after some tests, I found that only div and pow don't support int32. Accordingly, I made these changes in this commit. But it looks like there's an error about pip in the smoke test.

@gs-olive
Collaborator

gs-olive commented Sep 7, 2023

@zewenli98 Understood - thanks for this resource. In general, I would recommend using the C++ TRT documentation for input restrictions. This link for the C++ API suggests that POW is restricted to float, half, and bool, and that AND, OR, and XOR only support bool, whereas all other ops can also take int32 (though div is an outlier, as you mentioned). Regarding the pip error, it is resolved in #2298, which will be merged soon.
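
Distilled into a quick reference (a sketch based on this thread; exact support is TRT-version dependent, so verify against the IElementWiseLayer documentation):

    # Ops whose TRT elementwise layers do NOT accept Int32 inputs, per the
    # discussion above; the remaining ops (add, mul, sub, maximum, minimum,
    # floor_divide, eq, gt, lt) can take int32 as well.
    INT32_UNSUPPORTED_ELEMENTWISE = {
        "pow",          # float, half, bool only
        "logical_and",  # bool only
        "logical_or",   # bool only
        "logical_xor",  # bool only
        "div",          # outlier: rejects int32 in practice
    }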

@gs-olive
Collaborator

gs-olive commented Sep 7, 2023

Thanks for the changes - #2298 was just merged, so could you rebase to main?

@zewenli98 zewenli98 force-pushed the elementwise_dynamo_converters branch from 8dd9a33 to 6175423 on September 7, 2023 23:24
@zewenli98
Collaborator Author

Rebased! Thanks George!

@zewenli98 zewenli98 requested a review from gs-olive September 8, 2023 00:11
Collaborator

@gs-olive gs-olive left a comment


Looks great to me! Well tested and very useful utilities + converters added. This will greatly improve our operator coverage and correctness - thanks @zewenli98!

@gs-olive
Collaborator

gs-olive commented Sep 8, 2023

@zewenli98 - Could you rebase to the latest main for TorchScript testing?

Commits in this push:
  • add output_dtypes in test
  • add util func and fix bugs
  • add overloads, update tests, and fix a bug
  • fix arg bug
  • delete int2float conversion for some ops
  • update type conversion
@zewenli98 zewenli98 force-pushed the elementwise_dynamo_converters branch from 6175423 to 07d823b on September 8, 2023 18:11
@gs-olive gs-olive merged commit 40f8064 into pytorch:main Sep 8, 2023

Successfully merging this pull request may close these issues.

Expose IElementWiseLayer in dynamo.conversion.impl