fix: Add automatic type promotion for FX elementwise ops #2055

Closed · wants to merge 1 commit
py/torch_tensorrt/fx/converters/converter_utils.py (32 additions, 0 deletions)

@@ -280,6 +280,38 @@ def create_constant(
    return constant.get_output(0)


def cast_trt_tensor(
    network: TRTNetwork,
    input_val: TRTTensor,
    dtype: TRTDataType,
    name: str,
) -> TRTTensor:
    """
    Given a TRT Tensor, convert that Tensor to the specified dtype.

    Adds an Identity layer to the network which performs the conversion.

    Args:
        network (TRTNetwork): A TensorRT network
        input_val (TRTTensor): A TRT Tensor to cast to a new data type
        dtype (TRTDataType): The TRTDataType to cast the input Tensor to
        name (str): Name of the calling layer

    Returns:
        A TensorRT ITensor which has been cast to the specified dtype
    """
    if input_val.dtype != dtype:
        identity_layer = network.add_identity(input_val)
        identity_layer.set_output_type(0, dtype)
        identity_layer.name = (
            f"Cast ITensor {input_val.name} from {input_val.dtype} to {dtype} - {name}"
        )
        return identity_layer.get_output(0)
    else:
        return input_val


def get_trt_tensor(
    network: TRTNetwork,
    input_val: Any,
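
For context, the Identity-layer cast that cast_trt_tensor wraps can be sketched standalone against the TensorRT Python API (a minimal sketch, assuming a working tensorrt install; the input name and shape are illustrative):

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)

# An int32 network input we want to feed into a float op.
x = network.add_input("x", trt.int32, (1, 4))

# The cast itself: an Identity layer whose output type is overridden.
identity = network.add_identity(x)
identity.set_output_type(0, trt.float32)
x_fp32 = identity.get_output(0)  # same values, now typed float32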
py/torch_tensorrt/fx/converters/impl/elementwise/base.py (13 additions, 0 deletions)

@@ -17,6 +17,7 @@
    broadcast,
    squeeze_left,
    get_trt_tensor,
    cast_trt_tensor,
)


@@ -52,6 +53,7 @@ def convert_binary_elementwise(
    introduce constants via the .size() op. Other scenarios should be const
    folded first. If any operand is not a trt tensor, we make it a trt constant
    layer while preserving its dtype. Then we broadcast these two inputs to have
    the same number of dimensions. We also promote the two operands to a common
    dtype, matching PyTorch's type-promotion rules, to avoid dtype errors in TRT.

    Limitation:
        If we are using implicit batch dim mode, the operand that is not a trt
@@ -126,6 +128,17 @@ def convert_binary_elementwise(
    lhs_val = get_trt_tensor(network, lhs_val, f"{name}_lhs", lhs_dtype)
    rhs_val = get_trt_tensor(network, rhs_val, f"{name}_rhs", rhs_dtype)

    # Promote both operands to a common dtype, mirroring PyTorch's
    # type-promotion rules, so TRT does not reject mismatched inputs.
    promoted_type = torch.promote_types(
        unified_dtype_converter(lhs_val.dtype, Frameworks.TORCH),
        unified_dtype_converter(rhs_val.dtype, Frameworks.TORCH),
    )
    trt_promoted_type = unified_dtype_converter(promoted_type, Frameworks.TRT)

    if trt_promoted_type != lhs_val.dtype:
        lhs_val = cast_trt_tensor(network, lhs_val, trt_promoted_type, name)
    if trt_promoted_type != rhs_val.dtype:
        rhs_val = cast_trt_tensor(network, rhs_val, trt_promoted_type, name)

    # Check the limitation in the doc string.
    if network.has_implicit_batch_dimension:
        if is_lhs_trt_tensor and not is_rhs_trt_tensor:
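
For reference, torch.promote_types implements PyTorch's standard type-promotion lattice, so the casts above reproduce eager-mode behavior. A few concrete cases:

import torch

torch.promote_types(torch.int32, torch.float32)    # torch.float32
torch.promote_types(torch.float16, torch.float32)  # torch.float32
torch.promote_types(torch.int32, torch.int64)      # torch.int64
torch.promote_types(torch.bool, torch.int32)       # torch.int32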
py/torch_tensorrt/fx/test/converters/acc_op/test_binary_ops.py (17 additions, 0 deletions)

@@ -57,6 +57,23 @@ def forward(self, x):
        inputs = [torch.rand(1, 1) + 1]
        self.run_test(m, inputs, expected_ops={expected_op})

    @parameterized.expand([(op[1].__name__, op[0], op[1]) for op in elementwise_ops])
    def test_elementwise_ops_mismatched_dtypes(
        self, name, orig_op: Callable, expected_op
    ):
        class TestModule(nn.Module):
            def __init__(self, orig_op):
                super().__init__()
                self.orig_op = orig_op

            def forward(self, x):
                return self.orig_op(x.int(), x)

        m = TestModule(orig_op)
        # Avoid dividing by 0.
        inputs = [2 * torch.rand(1, 1, dtype=torch.float) + 1]
        self.run_test(m, inputs, expected_ops={expected_op})

    @parameterized.expand([(op[1].__name__, op[0], op[1]) for op in elementwise_ops])
    def test_elementwise_ops_with_one_constant(
        self, name, orig_op: Callable, expected_op
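
The pattern the new test exercises can be reproduced in eager PyTorch, where the int32/float32 mismatch is promoted to float32 automatically; the converter change makes TRT match this. A minimal sketch (the module name is illustrative):

import torch
import torch.nn as nn

class MismatchedAdd(nn.Module):
    def forward(self, x):
        # int32 + float32: eager PyTorch promotes the result to float32.
        return x.int() + x

m = MismatchedAdd()
out = m(2 * torch.rand(1, 1) + 1)
assert out.dtype == torch.float32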