Skip to content

Commit cfaef9d

Browse files
committed
ignore non-byte-aligned types
1 parent 1b5b408 commit cfaef9d

File tree

2 files changed

+32
-3
lines changed

2 files changed

+32
-3
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1047,6 +1047,14 @@ class ExtractOpFromElementwise final
10471047
}
10481048
};
10491049

1050+
static bool isSupportedMemSinkElementType(Type type) {
1051+
if (isa<IndexType>(type))
1052+
return true;
1053+
1054+
// Non-byte-aligned types are tricky, skip them.
1055+
return type.isIntOrFloat() && type.getIntOrFloatBitWidth() % 8 == 0;
1056+
}
1057+
10501058
/// Pattern to rewrite vector.extract(vector.load) -> vector/memref.load.
10511059
///
10521060
/// Example:
@@ -1080,9 +1088,8 @@ class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
10801088
"scalable vectors are not supported");
10811089

10821090
MemRefType memType = loadOp.getMemRefType();
1083-
if (isa<VectorType>(memType.getElementType()))
1084-
return rewriter.notifyMatchFailure(
1085-
op, "memrefs of vectors are not supported");
1091+
if (!isSupportedMemSinkElementType(memType.getElementType()))
1092+
return rewriter.notifyMatchFailure(op, "unsupported memref element type");
10861093

10871094
int64_t rankOffset = memType.getRank() - loadVecType.getRank();
10881095
if (rankOffset < 0)

mlir/test/Dialect/Vector/vector-sink.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -528,6 +528,16 @@ func.func @extract_load_scalar(%arg0: memref<?xf32>, %arg1: index) -> f32 {
528528
return %1 : f32
529529
}
530530

531+
// CHECK-LABEL: @extract_load_index
532+
// CHECK-SAME: (%[[ARG0:.*]]: memref<?xindex>, %[[ARG1:.*]]: index)
533+
func.func @extract_load_index(%arg0: memref<?xindex>, %arg1: index) -> index {
534+
// CHECK: %[[RES:.*]] = memref.load %[[ARG0]][%[[ARG1]]] : memref<?xindex>
535+
// CHECK: return %[[RES]] : index
536+
%0 = vector.load %arg0[%arg1] : memref<?xindex>, vector<4xindex>
537+
%1 = vector.extract %0[0] : index from vector<4xindex>
538+
return %1 : index
539+
}
540+
531541
// CHECK-LABEL: @extract_load_scalar_non_zero_off
532542
// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
533543
func.func @extract_load_scalar_non_zero_off(%arg0: memref<?xf32>, %arg1: index) -> f32 {
@@ -598,6 +608,18 @@ func.func @negative_extract_load_scalar_from_memref_of_vec(%arg0: memref<?xvecto
598608
return %1 : f32
599609
}
600610

611+
// CHECK-LABEL: @negative_extract_load_scalar_from_memref_of_i1
612+
// CHECK-SAME: (%[[ARG0:.*]]: memref<?xi1>, %[[ARG1:.*]]: index)
613+
func.func @negative_extract_load_scalar_from_memref_of_i1(%arg0: memref<?xi1>, %arg1: index) -> i1 {
614+
// Subbyte types are tricky, ignore them for now.
615+
// CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xi1>, vector<8xi1>
616+
// CHECK: %[[EXT:.*]] = vector.extract %[[RES]][0] : i1 from vector<8xi1>
617+
// CHECK: return %[[EXT]] : i1
618+
%0 = vector.load %arg0[%arg1] : memref<?xi1>, vector<8xi1>
619+
%1 = vector.extract %0[0] : i1 from vector<8xi1>
620+
return %1 : i1
621+
}
622+
601623
// CHECK-LABEL: @negative_extract_load_no_single_use
602624
// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
603625
func.func @negative_extract_load_no_single_use(%arg0: memref<?xf32>, %arg1: index) -> (f32, vector<4xf32>) {

0 commit comments

Comments
 (0)