Skip to content

Commit 55d5c03

Browse files
[mlir][vector] Fix crash in vector.extract folder (#95912)
Fix a bug in the `vector.extract` folder when the vector type is 0-d.
1 parent e42a4c7 commit 55d5c03

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1883,7 +1883,10 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
18831883
}
18841884

18851885
OpFoldResult ExtractOp::fold(FoldAdaptor) {
1886-
if (getNumIndices() == 0)
1886+
// Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
1887+
// Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
1888+
// mismatch).
1889+
if (getNumIndices() == 0 && getVector().getType() == getResult().getType())
18871890
return getVector();
18881891
if (succeeded(foldExtractOpFromExtractChain(*this)))
18891892
return getResult();

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2593,3 +2593,14 @@ func.func @rank_1_shuffle_to_interleave(%arg0: vector<6xi32>, %arg1: vector<6xi3
25932593
%0 = vector.shuffle %arg0, %arg1 [0, 6, 1, 7, 2, 8, 3, 9, 4, 10, 5, 11] : vector<6xi32>, vector<6xi32>
25942594
return %0 : vector<12xi32>
25952595
}
2596+
2597+
// -----
2598+
2599+
// CHECK-LABEL: func @extract_from_0d_regression(
2600+
// CHECK-SAME: %[[v:.*]]: vector<f32>)
2601+
// CHECK: %[[extract:.*]] = vector.extract %[[v]][] : f32 from vector<f32>
2602+
// CHECK: return %[[extract]]
2603+
func.func @extract_from_0d_regression(%v: vector<f32>) -> f32 {
2604+
%0 = vector.extract %v[] : f32 from vector<f32>
2605+
return %0 : f32
2606+
}

0 commit comments

Comments
 (0)