Skip to content

fix: Repair broadcasting utility for aten.where #2228

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 14 additions & 16 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,25 +137,23 @@ def broadcastable(
"Check if two tensors are broadcastable according to torch rules"
a_shape = tuple(a.shape)
b_shape = tuple(b.shape)

# check from the trailing
diff = len(a_shape) - len(b_shape)
if diff == 0:

# Validate tensors have same rank and shape
if diff == 0 and all(a_shape[i] == b_shape[i] for i in range(len(a_shape))):
return True

# Left-pad the shorter dimension with ones
if diff > 0:
max = len(a_shape)
min = len(b_shape)
greater_tensor = a_shape
lesser_tensor = b_shape
elif diff < 0:
max = len(b_shape)
min = len(a_shape)
greater_tensor = b_shape
lesser_tensor = a_shape
j = min - 1
for i in range(max - 1, diff - 1, -1):
if not (
greater_tensor[i] != lesser_tensor[j]
and (greater_tensor[i] == 1 or lesser_tensor[i] == 1)
):
b_shape = (1,) * abs(diff) + b_shape
else:
a_shape = (1,) * abs(diff) + a_shape

# Validate one of the following conditions for broadcastability per-dimension
# 1. Equal number of dimensions or 2. Dimension has shape 1
for i in range(len(a_shape)):
if not (a_shape[i] == b_shape[i] or a_shape[i] == 1 or b_shape[i] == 1):
return False
return True
16 changes: 15 additions & 1 deletion tests/py/dynamo/converters/test_where_aten.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import torch
import torch.nn as nn
from harness import DispatchTestCase
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from harness import DispatchTestCase


class TestWhereConverter(DispatchTestCase):
Expand All @@ -28,6 +28,20 @@ def forward(self, condition, x, y):
expected_ops={torch.ops.aten.where.self},
)

def test_0D_input(self):
class Where(nn.Module):
def forward(self, condition, x, y):
return torch.where(condition, x, y)

inputX = torch.randn((5, 6, 7, 1, 3))
inputOther = torch.tensor(8.0, dtype=torch.float)
condition = inputX < 0
self.run_test(
Where(),
(condition, inputX, inputOther),
expected_ops={torch.ops.aten.where.self},
)


if __name__ == "__main__":
run_tests()