Skip to content

Commit

Permalink
[mlir][tensor] Fix tensor.reshape canonicalization (#90141)
Browse files Browse the repository at this point in the history
Canonicalization defaulted to replacement when the input dims were from
unknown source. This is obviously incorrect. Tweaked and included test
to prevent future issue.
  • Loading branch information
rsuderman authored Apr 26, 2024
1 parent c2170a3 commit 593f6fd
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 0 deletions.
3 changes: 3 additions & 0 deletions mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1609,6 +1609,9 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
cst.has_value() && cst.value() == static_cast<int64_t>(id);
continue;
}

dynamicNoop = false;
break;
}

if (dynamicNoop)
Expand Down
9 changes: 9 additions & 0 deletions mlir/test/Dialect/Tensor/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2431,6 +2431,15 @@ func.func @reshape_nofold_2d(%arg0 : tensor<?x?xi32>) -> tensor<?x?xi32> {
return %reshape : tensor<?x?xi32>
}

// -----

// CHECK-LABEL: @reshape_nofold_2d_ins
func.func @reshape_nofold_2d_ins(%arg0 : tensor<?x?xi32>, %arg1: index, %arg2: index) -> tensor<?x?xi32> {
%ds = tensor.from_elements %arg1, %arg2 : tensor<2xindex>
// CHECK: tensor.reshape
%reshape = tensor.reshape %arg0(%ds) : (tensor<?x?xi32>, tensor<2xindex>) -> tensor<?x?xi32>
return %reshape : tensor<?x?xi32>
}

// -----

Expand Down

0 comments on commit 593f6fd

Please sign in to comment.