@@ -902,6 +902,8 @@ struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
 };

 /// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex:
+///
+/// Example:
 /// ```
 /// %a = vector.broadcast %arg1 : index to vector<1x4xindex>
 /// %b = vector.broadcast %arg2 : index to vector<1x4xindex>
@@ -987,6 +989,8 @@ struct ReorderElementwiseOpsOnBroadcast final
 /// This may result in cleaner code when extracting a single value
 /// from multi-element vector and also to help canonicalize 1-element vectors to
 /// scalars.
+///
+/// Example:
 /// ```
 /// %0 = arith.addf %arg0, %arg1 : vector<4xf32>
 /// %1 = vector.extract %0[1] : f32 from vector<4xf32>
@@ -1043,6 +1047,150 @@ class ExtractOpFromElementwise final
   }
 };

+/// Check if the element type is suitable for vector.load/store sinking.
+/// Element type must be index or byte-aligned integer or floating-point type.
+static bool isSupportedMemSinkElementType(Type type) {
+  if (isa<IndexType>(type))
+    return true;
+
+  return type.isIntOrFloat() && type.getIntOrFloatBitWidth() % 8 == 0;
+}
+
+/// Pattern to rewrite `vector.extract(vector.load)` -> `vector/memref.load`.
+/// Only index and byte-aligned integer and floating-point element types are
+/// supported for now.
+///
+/// Example:
+/// ```
+/// %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+/// %1 = vector.extract %0[1] : f32 from vector<4xf32>
+/// ```
+/// Gets converted to:
+/// ```
+/// %c1 = arith.constant 1 : index
+/// %0 = arith.addi %arg1, %c1 overflow<nsw> : index
+/// %1 = memref.load %arg0[%0] : memref<?xf32>
+/// ```
+class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ExtractOp op,
+                                PatternRewriter &rewriter) const override {
+    auto loadOp = op.getVector().getDefiningOp<vector::LoadOp>();
+    if (!loadOp)
+      return rewriter.notifyMatchFailure(op, "expected a load op");
+
+    // Checking for single use so we won't duplicate load ops.
+    if (!loadOp->hasOneUse())
+      return rewriter.notifyMatchFailure(op, "expected single op use");
+
+    VectorType loadVecType = loadOp.getVectorType();
+    if (loadVecType.isScalable())
+      return rewriter.notifyMatchFailure(op,
+                                         "scalable vectors are not supported");
+
+    MemRefType memType = loadOp.getMemRefType();
+
+    // Non-byte-aligned types are tricky and may require special handling,
+    // ignore them for now.
+    if (!isSupportedMemSinkElementType(memType.getElementType()))
+      return rewriter.notifyMatchFailure(op, "unsupported element type");
+
+    int64_t rankOffset = memType.getRank() - loadVecType.getRank();
+    if (rankOffset < 0)
+      return rewriter.notifyMatchFailure(op, "unsupported ranks combination");
+
+    auto extractVecType = dyn_cast<VectorType>(op.getResult().getType());
+    int64_t finalRank = 0;
+    if (extractVecType)
+      finalRank = extractVecType.getRank();
+
+    SmallVector<Value> indices = loadOp.getIndices();
+    SmallVector<OpFoldResult> extractPos = op.getMixedPosition();
+
+    // There may be memory stores between the load and the extract op, so we
+    // need to make sure that the new load op is inserted at the same place as
+    // the original load op.
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(loadOp);
+    Location loc = loadOp.getLoc();
+    ArithIndexingBuilder idxBuilderf(rewriter, loc);
+    for (auto i : llvm::seq<int64_t>(rankOffset, indices.size() - finalRank)) {
+      OpFoldResult pos = extractPos[i - rankOffset];
+      if (isConstantIntValue(pos, 0))
+        continue;
+
+      Value offset = getValueOrCreateConstantIndexOp(rewriter, loc, pos);
+      indices[i] = idxBuilderf.add(indices[i], offset);
+    }
+
+    Value base = loadOp.getBase();
+    if (extractVecType) {
+      rewriter.replaceOpWithNewOp<vector::LoadOp>(op, extractVecType, base,
+                                                  indices);
+    } else {
+      rewriter.replaceOpWithNewOp<memref::LoadOp>(op, base, indices);
+    }
+    // We checked for single use so we can safely erase the load op.
+    rewriter.eraseOp(loadOp);
+    return success();
+  }
+};
+
+/// Pattern to rewrite `vector.store(vector.splat)` -> `vector/memref.store`.
+///
+/// Example:
+/// ```
+/// %0 = vector.splat %arg2 : vector<1xf32>
+/// vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
+/// ```
+/// Gets converted to:
+/// ```
+/// memref.store %arg2, %arg0[%arg1] : memref<?xf32>
+/// ```
+class StoreOpFromSplatOrBroadcast final
+    : public OpRewritePattern<vector::StoreOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::StoreOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType vecType = op.getVectorType();
+    if (vecType.isScalable())
+      return rewriter.notifyMatchFailure(op,
+                                         "scalable vectors are not supported");
+
+    if (isa<VectorType>(op.getMemRefType().getElementType()))
+      return rewriter.notifyMatchFailure(
+          op, "memrefs of vectors are not supported");
+
+    if (vecType.getNumElements() != 1)
+      return rewriter.notifyMatchFailure(
+          op, "only 1-element vectors are supported");
+
+    Operation *splat = op.getValueToStore().getDefiningOp();
+    if (!isa_and_present<vector::BroadcastOp, vector::SplatOp>(splat))
+      return rewriter.notifyMatchFailure(op, "neither a splat nor a broadcast");
+
+    // Checking for single use so we can remove splat.
+    if (!splat->hasOneUse())
+      return rewriter.notifyMatchFailure(op, "expected single op use");
+
+    Value source = splat->getOperand(0);
+    Value base = op.getBase();
+    ValueRange indices = op.getIndices();
+
+    if (isa<VectorType>(source.getType())) {
+      rewriter.replaceOpWithNewOp<vector::StoreOp>(op, source, base, indices);
+    } else {
+      rewriter.replaceOpWithNewOp<memref::StoreOp>(op, source, base, indices);
+    }
+    rewriter.eraseOp(splat);
+    return success();
+  }
+};
+
 // Helper that returns a vector comparison that constructs a mask:
 //     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
 //
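The in-code comment in `ExtractOpFromLoad` about intervening memory stores deserves a concrete illustration. Below is a hypothetical IR sketch (the value names are invented, not taken from this patch or its tests) of why the rewritten scalar load must be created at the original `vector.load`'s position rather than at the `vector.extract`:

```mlir
// Before the rewrite: the extract observes the value loaded *before* the store.
%v = vector.load %mem[%i] : memref<?xf32>, vector<4xf32>
memref.store %x, %mem[%i] : memref<?xf32>
%e = vector.extract %v[0] : f32 from vector<4xf32>

// After the rewrite: the scalar load stays where the vector.load was, so it
// still reads the pre-store value. Materializing it at the extract's position
// instead would incorrectly observe %x.
%e = memref.load %mem[%i] : memref<?xf32>
memref.store %x, %mem[%i] : memref<?xf32>
```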
@@ -2109,6 +2257,13 @@ void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
       patterns.getContext(), benefit);
 }

+void mlir::vector::populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
+                                                    PatternBenefit benefit) {
+  // TODO: Consider converting these patterns to canonicalizations.
+  patterns.add<ExtractOpFromLoad, StoreOpFromSplatOrBroadcast>(
+      patterns.getContext(), benefit);
+}
+
 void mlir::vector::populateChainedVectorReductionFoldingPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<ChainedReduction>(patterns.getContext(), benefit);
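For readers who want to exercise the new entry point, here is a minimal sketch of how it might be wired into a test pass. The pass itself is not part of this change: the pass name is invented, and the call to the greedy driver assumes the current `applyPatternsGreedily` API.

```cpp
// Hypothetical test pass exercising the new sinking patterns (not in this PR).
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

namespace {
struct SinkVectorMemOpsPass
    : public PassWrapper<SinkVectorMemOpsPass, OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SinkVectorMemOpsPass)

  StringRef getArgument() const final { return "test-sink-vector-mem-ops"; }

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    // Registers ExtractOpFromLoad and StoreOpFromSplatOrBroadcast.
    vector::populateSinkVectorMemOpsPatterns(patterns);
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace
```

Keeping these patterns behind their own populate function, rather than folding them into `populateSinkVectorOpsPatterns`, lets users opt into the memory-op rewrites separately; the TODO suggests they may eventually become canonicalizations instead.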