Skip to content

Commit

Permalink
Remove folder from AtenStackOp for single element list inputs (llvm…
Browse files Browse the repository at this point in the history
…#2626)

`AtenStackOp` defines this folder for list operand containing single
element:
```
OpFoldResult AtenStackOp::fold(FoldAdaptor adaptor) {
  auto list = getOperand(0).getDefiningOp<PrimListConstructOp>();
  if (!list || !list->hasOneUse() || list.getElements().size() != 1)
    return nullptr;
  return list.getElements()[0];
}
```
However, unlike `AtenCatOp`, `AtenStackOp` cannot be folded away for
single element list operand because the result from a stack operation
contains an additional dimension (of size 1, like expand_shape).

This PR removes the `AtenStackOp::fold` method, and adds an e2e test for
single element list input case, which fails on current `main` as
follows:
```
Unexpected outcome summary: (linalg)                                                                                                                                                                   
                                                                                                                                                                                                       
****** Failed tests - 1 tests                                                                                                                                                                          
    FAIL - "TensorsStackSingleElementListModule_basic"                                                                                                                                                 
        @ trace item #0 - call to "forward"                                                                                                                                                            
        @ output of call to "forward"                                                                                                                                                                  
        ERROR: shape (torch.Size([10, 32])) is not equal to golden shape (torch.Size([10, 1, 32]))     
```
Thanks Chris Lalau Keraly for the bug report.
  • Loading branch information
sjain-stanford authored Dec 11, 2023
1 parent 0b4422a commit 7acabaf
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 13 deletions.
1 change: 0 additions & 1 deletion include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -11632,7 +11632,6 @@ def Torch_AtenStackOp : Torch_Op<"aten.stack", [
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
let hasFolder = 1;
}

def Torch_AtenAppendTOp : Torch_Op<"aten.append.t", [
Expand Down
11 changes: 0 additions & 11 deletions lib/Dialect/Torch/IR/TorchOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2444,17 +2444,6 @@ OpFoldResult AtenCatOp::fold(FoldAdaptor adaptor) {
return list.getElements()[0];
}

//===----------------------------------------------------------------------===//
// AtenStackOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenStackOp::fold(FoldAdaptor adaptor) {
auto list = getOperand(0).getDefiningOp<PrimListConstructOp>();
if (!list || !list->hasOneUse() || list.getElements().size() != 1)
return nullptr;
return list.getElements()[0];
}

//===----------------------------------------------------------------------===//
// AtenBroadcastToOp
//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -691,7 +691,7 @@ def emit_with_mutating_variants(key, **kwargs):

# List ops.
emit("aten::cat : (Tensor[], int) -> (Tensor)", has_folder=True)
emit("aten::stack : (Tensor[], int) -> (Tensor)", has_folder=True)
emit("aten::stack : (Tensor[], int) -> (Tensor)")
emit("aten::append.t : (t[], t) -> (t[])")
emit("aten::add.t : (t[], t[]) -> (t[])", has_canonicalizer=True)
emit("aten::eq.int_list : (int[], int[]) -> (bool)", has_folder=True)
Expand Down
22 changes: 22 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -884,6 +884,28 @@ def TensorsStackModule_basic(module, tu: TestUtils):
# ==============================================================================


class TensorsStackSingleElementListModule(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, x):
return torch.stack([x], dim=1)


@register_test_case(module_factory=lambda: TensorsStackSingleElementListModule())
def TensorsStackSingleElementListModule_basic(module, tu: TestUtils):
module.forward(tu.rand(10, 32))


# ==============================================================================


class TensorsStackNegativeDimModule(torch.nn.Module):

def __init__(self):
Expand Down

0 comments on commit 7acabaf

Please sign in to comment.