🐛 Describe the bug
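Subtraction returns a wrongly shaped result when both `x` and `y` have to be broadcast. In the repro below, the method first runs with inputs of shape (1, 2) and (2, 1). A second call with (1, 4) and (4, 1) then returns a (1, 4) tensor instead of the expected (4, 4) tensor.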
Repro code snippet:
```python
import torch
from torch.export import Dim
from executorch.exir import to_edge
from executorch.runtime import Runtime


class Model(torch.nn.Module):
    def forward(self, x, y):
        return x - y


# x has a dynamic last dim and y has a dynamic first dim; both share `dim`
dim = Dim("dim", min=1, max=1024)
ep = torch.export.export(
    Model(), (torch.ones(1, 4), torch.ones(4, 1)), dynamic_shapes=({1: dim}, {0: dim})
)
pte = to_edge(ep).to_executorch().buffer

runtime = Runtime.get()
program = runtime.load_program(pte)
method = program.load_method("forward")

# run sub with (1, 2) and (2, 1) first; result should be zeros(2, 2)
print(method.execute((torch.ones(1, 2), torch.ones(2, 1))))
# then run sub with (1, 4) and (4, 1); result should be zeros(4, 4) but is actually zeros(1, 4)
print(method.execute((torch.ones(1, 4), torch.ones(4, 1))))
"""
Actual
[tensor([[0., 0., 0., 0.]])]
Expected
[tensor([[0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]])]
"""
```
Versions
Latest