Commit e87aa0c
[mlir][vector] Sink vector.extract/splat into load/store ops (#134389)
```
%0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
%1 = vector.extract %0[1] : f32 from vector<4xf32>
```
Gets converted to:
```
%c1 = arith.constant 1 : index
%0 = arith.addi %arg1, %c1 overflow<nsw> : index
%1 = memref.load %arg0[%0] : memref<?xf32>
```

```
%0 = vector.splat %arg2 : vector<1xf32>
vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
```
Gets converted to:
```
memref.store %arg2, %arg0[%arg1] : memref<?xf32>
```
1 parent d20604e commit e87aa0c

File tree

9 files changed: +421 −7 lines changed


mlir/include/mlir/Dialect/Arith/Utils/Utils.h

+13 −1

```diff
@@ -101,7 +101,10 @@ Type getType(OpFoldResult ofr);
 /// Helper struct to build simple arithmetic quantities with minimal type
 /// inference support.
 struct ArithBuilder {
-  ArithBuilder(OpBuilder &b, Location loc) : b(b), loc(loc) {}
+  ArithBuilder(
+      OpBuilder &b, Location loc,
+      arith::IntegerOverflowFlags ovf = arith::IntegerOverflowFlags::none)
+      : b(b), loc(loc), ovf(ovf) {}

   Value _and(Value lhs, Value rhs);
   Value add(Value lhs, Value rhs);
@@ -114,6 +117,15 @@ struct ArithBuilder {
 private:
   OpBuilder &b;
   Location loc;
+  arith::IntegerOverflowFlags ovf;
+};
+
+/// ArithBuilder specialized specifically for tensor/memref indexing
+/// calculations. Those calculations generally should never overflow in a
+/// signed sense and always use signed integers, so we can set overflow flags
+/// accordingly.
+struct ArithIndexingBuilder : public ArithBuilder {
+  ArithIndexingBuilder(OpBuilder &b, Location loc)
+      : ArithBuilder(b, loc, arith::IntegerOverflowFlags::nsw) {}
 };

 namespace arith {
```
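The point of `ArithIndexingBuilder` is to get `nsw` overflow flags on index arithmetic without threading them through every call site. A minimal sketch of how it might be used inside a rewrite, assuming a typical pattern context (the helper function and value names are illustrative, not part of this commit):

```cpp
// Hypothetical usage: add a constant offset to a memref index. Because
// ArithIndexingBuilder passes arith::IntegerOverflowFlags::nsw down to
// ArithBuilder, the emitted op is `arith.addi ... overflow<nsw>` rather
// than a plain `arith.addi`.
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"

using namespace mlir;

Value addIndexOffset(PatternRewriter &rewriter, Location loc, Value index,
                     int64_t offset) {
  ArithIndexingBuilder idxBuilder(rewriter, loc);
  Value cst = rewriter.create<arith::ConstantIndexOp>(loc, offset);
  return idxBuilder.add(index, cst); // arith.addi %index, %cst overflow<nsw>
}
```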

mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td

+29 −3

```diff
@@ -458,7 +458,9 @@ def ApplySinkVectorPatternsOp : Op<Transform_Dialect,
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
   let description = [{
     Patterns that remove redundant Vector Ops by re-ordering them with
-    e.g. elementwise Ops:
+    e.g. elementwise Ops.
+
+    Example:
     ```
     %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
     %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
@@ -469,8 +471,32 @@ def ApplySinkVectorPatternsOp : Op<Transform_Dialect,
     %0 = arith.addf %a, %b : vector<4x2xf32>
     %r = vector.transpose %0, [1, 0] : vector<2x4xf32>
     ```
-    At the moment, these patterns are limited to vector.broadcast and
-    vector.transpose.
+    At the moment, these patterns are limited to vector.broadcast,
+    vector.transpose and vector.extract.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
+def ApplySinkVectorMemPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.vector.sink_mem_ops",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Patterns that replace redundant Vector Ops (followed by
+    `vector.load`/`vector.store`) with either `vector.load`/`vector.store` or
+    `memref.load`/`memref.store`. Currently limited to 1-element vectors.
+
+    Example:
+    ```
+    %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+    %1 = vector.extract %0[1] : f32 from vector<4xf32>
+    ```
+    Gets converted to:
+    ```
+    %c1 = arith.constant 1 : index
+    %0 = arith.addi %arg1, %c1 overflow<nsw> : index
+    %1 = memref.load %arg0[%0] : memref<?xf32>
+    ```
   }];

   let assemblyFormat = "attr-dict";
```
mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h

+14 −0

```diff
@@ -161,6 +161,20 @@ void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
 void populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
                                    PatternBenefit benefit = 1);

+/// Patterns that remove redundant Vector Ops by merging them with load/store
+/// ops:
+/// ```
+/// %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+/// %1 = vector.extract %0[1] : f32 from vector<4xf32>
+/// ```
+/// Gets converted to:
+/// ```
+/// %c1 = arith.constant 1 : index
+/// %0 = arith.addi %arg1, %c1 overflow<nsw> : index
+/// %1 = memref.load %arg0[%0] : memref<?xf32>
+/// ```
+void populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
+                                      PatternBenefit benefit = 1);
+
 /// Patterns that fold chained vector reductions. These patterns assume that
 /// elementwise operations (e.g., `arith.addf` with vector operands) are
 /// cheaper than vector reduction.
```
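For reference, a minimal sketch of how this new entry point might be wired into a pass, assuming the usual greedy-driver setup (the wrapper function is illustrative, not part of this commit):

```cpp
// Hypothetical driver applying the sink patterns; only
// populateSinkVectorMemOpsPatterns is new in this commit, the rest is
// standard MLIR pattern-driver boilerplate.
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

static LogicalResult sinkVectorMemOps(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  vector::populateSinkVectorOpsPatterns(patterns);
  vector::populateSinkVectorMemOpsPatterns(patterns);
  // Walks `root`, applying the patterns until a fixed point is reached.
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}
```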

mlir/lib/Dialect/Arith/Utils/Utils.cpp

+3 −3

```diff
@@ -315,17 +315,17 @@ Value ArithBuilder::_and(Value lhs, Value rhs) {
 Value ArithBuilder::add(Value lhs, Value rhs) {
   if (isa<FloatType>(lhs.getType()))
     return b.create<arith::AddFOp>(loc, lhs, rhs);
-  return b.create<arith::AddIOp>(loc, lhs, rhs);
+  return b.create<arith::AddIOp>(loc, lhs, rhs, ovf);
 }
 Value ArithBuilder::sub(Value lhs, Value rhs) {
   if (isa<FloatType>(lhs.getType()))
     return b.create<arith::SubFOp>(loc, lhs, rhs);
-  return b.create<arith::SubIOp>(loc, lhs, rhs);
+  return b.create<arith::SubIOp>(loc, lhs, rhs, ovf);
 }
 Value ArithBuilder::mul(Value lhs, Value rhs) {
   if (isa<FloatType>(lhs.getType()))
     return b.create<arith::MulFOp>(loc, lhs, rhs);
-  return b.create<arith::MulIOp>(loc, lhs, rhs);
+  return b.create<arith::MulIOp>(loc, lhs, rhs, ovf);
 }
 Value ArithBuilder::sgt(Value lhs, Value rhs) {
   if (isa<FloatType>(lhs.getType()))
```

mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp

+5 −0

```diff
@@ -212,6 +212,11 @@ void transform::ApplySinkVectorPatternsOp::populatePatterns(
   vector::populateSinkVectorOpsPatterns(patterns);
 }

+void transform::ApplySinkVectorMemPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  vector::populateSinkVectorMemOpsPatterns(patterns);
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
```

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

+155 −0

```diff
@@ -902,6 +902,8 @@ struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
 };

 /// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex:
+///
+/// Example:
 /// ```
 /// %a = vector.broadcast %arg1 : index to vector<1x4xindex>
 /// %b = vector.broadcast %arg2 : index to vector<1x4xindex>
@@ -987,6 +989,8 @@ struct ReorderElementwiseOpsOnBroadcast final
 /// This may result in cleaner code when extracting a single value
 /// from multi-element vector and also to help canonicalize 1-element vectors to
 /// scalars.
+///
+/// Example:
 /// ```
 /// %0 = arith.addf %arg0, %arg1 : vector<4xf32>
 /// %1 = vector.extract %0[1] : f32 from vector<4xf32>
@@ -1043,6 +1047,150 @@ class ExtractOpFromElementwise final
   }
 };

+/// Check if the element type is suitable for vector.load/store sinking.
+/// Element type must be index or byte-aligned integer or floating-point type.
+static bool isSupportedMemSinkElementType(Type type) {
+  if (isa<IndexType>(type))
+    return true;
+
+  return type.isIntOrFloat() && type.getIntOrFloatBitWidth() % 8 == 0;
+}
+
+/// Pattern to rewrite `vector.extract(vector.load) -> vector/memref.load`.
+/// Only index and byte-aligned integer and floating-point element types are
+/// supported for now.
+///
+/// Example:
+/// ```
+/// %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+/// %1 = vector.extract %0[1] : f32 from vector<4xf32>
+/// ```
+/// Gets converted to:
+/// ```
+/// %c1 = arith.constant 1 : index
+/// %0 = arith.addi %arg1, %c1 overflow<nsw> : index
+/// %1 = memref.load %arg0[%0] : memref<?xf32>
+/// ```
+class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ExtractOp op,
+                                PatternRewriter &rewriter) const override {
+    auto loadOp = op.getVector().getDefiningOp<vector::LoadOp>();
+    if (!loadOp)
+      return rewriter.notifyMatchFailure(op, "expected a load op");
+
+    // Checking for single use so we won't duplicate load ops.
+    if (!loadOp->hasOneUse())
+      return rewriter.notifyMatchFailure(op, "expected single op use");
+
+    VectorType loadVecType = loadOp.getVectorType();
+    if (loadVecType.isScalable())
+      return rewriter.notifyMatchFailure(op,
+                                         "scalable vectors are not supported");
+
+    MemRefType memType = loadOp.getMemRefType();
+
+    // Non-byte-aligned types are tricky and may require special handling,
+    // ignore them for now.
+    if (!isSupportedMemSinkElementType(memType.getElementType()))
+      return rewriter.notifyMatchFailure(op, "unsupported element type");
+
+    int64_t rankOffset = memType.getRank() - loadVecType.getRank();
+    if (rankOffset < 0)
+      return rewriter.notifyMatchFailure(op, "unsupported ranks combination");
+
+    auto extractVecType = dyn_cast<VectorType>(op.getResult().getType());
+    int64_t finalRank = 0;
+    if (extractVecType)
+      finalRank = extractVecType.getRank();
+
+    SmallVector<Value> indices = loadOp.getIndices();
+    SmallVector<OpFoldResult> extractPos = op.getMixedPosition();
+
+    // There may be memory stores between the load and the extract op, so we
+    // need to make sure that the new load op is inserted at the same place as
+    // the original load op.
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(loadOp);
+    Location loc = loadOp.getLoc();
+    ArithIndexingBuilder idxBuilderf(rewriter, loc);
+    for (auto i : llvm::seq<int64_t>(rankOffset, indices.size() - finalRank)) {
+      OpFoldResult pos = extractPos[i - rankOffset];
+      if (isConstantIntValue(pos, 0))
+        continue;
+
+      Value offset = getValueOrCreateConstantIndexOp(rewriter, loc, pos);
+      indices[i] = idxBuilderf.add(indices[i], offset);
+    }
+
+    Value base = loadOp.getBase();
+    if (extractVecType) {
+      rewriter.replaceOpWithNewOp<vector::LoadOp>(op, extractVecType, base,
+                                                  indices);
+    } else {
+      rewriter.replaceOpWithNewOp<memref::LoadOp>(op, base, indices);
+    }
+    // We checked for single use so we can safely erase the load op.
+    rewriter.eraseOp(loadOp);
+    return success();
+  }
+};
+
+/// Pattern to rewrite vector.store(vector.splat) -> vector/memref.store.
+///
+/// Example:
+/// ```
+/// %0 = vector.splat %arg2 : vector<1xf32>
+/// vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
+/// ```
+/// Gets converted to:
+/// ```
+/// memref.store %arg2, %arg0[%arg1] : memref<?xf32>
+/// ```
+class StoreOpFromSplatOrBroadcast final
+    : public OpRewritePattern<vector::StoreOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::StoreOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType vecType = op.getVectorType();
+    if (vecType.isScalable())
+      return rewriter.notifyMatchFailure(op,
+                                         "scalable vectors are not supported");
+
+    if (isa<VectorType>(op.getMemRefType().getElementType()))
+      return rewriter.notifyMatchFailure(
+          op, "memrefs of vectors are not supported");
+
+    if (vecType.getNumElements() != 1)
+      return rewriter.notifyMatchFailure(
+          op, "only 1-element vectors are supported");
+
+    Operation *splat = op.getValueToStore().getDefiningOp();
+    if (!isa_and_present<vector::BroadcastOp, vector::SplatOp>(splat))
+      return rewriter.notifyMatchFailure(op, "neither a splat nor a broadcast");
+
+    // Checking for single use so we can remove splat.
+    if (!splat->hasOneUse())
+      return rewriter.notifyMatchFailure(op, "expected single op use");
+
+    Value source = splat->getOperand(0);
+    Value base = op.getBase();
+    ValueRange indices = op.getIndices();
+
+    if (isa<VectorType>(source.getType())) {
+      rewriter.replaceOpWithNewOp<vector::StoreOp>(op, source, base, indices);
+    } else {
+      rewriter.replaceOpWithNewOp<memref::StoreOp>(op, source, base, indices);
+    }
+    rewriter.eraseOp(splat);
+    return success();
+  }
+};
+
 // Helper that returns a vector comparison that constructs a mask:
 //     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
 //
@@ -2109,6 +2257,13 @@ void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
                patterns.getContext(), benefit);
 }

+void mlir::vector::populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
+                                                    PatternBenefit benefit) {
+  // TODO: Consider converting these patterns to canonicalizations.
+  patterns.add<ExtractOpFromLoad, StoreOpFromSplatOrBroadcast>(
+      patterns.getContext(), benefit);
+}
+
 void mlir::vector::populateChainedVectorReductionFoldingPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<ChainedReduction>(patterns.getContext(), benefit);
```
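The rank bookkeeping in `ExtractOpFromLoad` is the subtle part: only the extract positions that select within the loaded vector's leading dimensions are folded into the load indices, while trailing dimensions kept by a vector-typed result stay untouched. Below is a standalone sketch of the same arithmetic on plain integers; the function and the scalar stand-ins are illustrative only, and in the actual pattern each addition is an `arith.addi ... overflow<nsw>`:

```cpp
// Illustrative mirror of the index folding in ExtractOpFromLoad, using
// int64_t where the pattern uses Value/OpFoldResult.
#include <cstdint>
#include <vector>

std::vector<int64_t>
foldExtractIntoLoadIndices(std::vector<int64_t> loadIndices, // per memref dim
                           const std::vector<int64_t> &extractPos,
                           int64_t loadVecRank, int64_t resultVecRank) {
  // Leading memref dims not covered by the loaded vector keep their indices.
  int64_t rankOffset = static_cast<int64_t>(loadIndices.size()) - loadVecRank;
  // Fold each extract position into the corresponding load index; the last
  // `resultVecRank` dims remain part of the (narrower) vector result.
  for (int64_t i = rankOffset;
       i < static_cast<int64_t>(loadIndices.size()) - resultVecRank; ++i)
    loadIndices[i] += extractPos[i - rankOffset];
  return loadIndices;
}

// E.g. a vector<4xf32> load from memref<?x?xf32> at (%i, %j) followed by
// `vector.extract %v[1]` becomes a memref.load at (%i, %j + 1):
// foldExtractIntoLoadIndices({i, j}, {1}, /*loadVecRank=*/1,
//                            /*resultVecRank=*/0) == {i, j + 1}
```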

mlir/test/Dialect/Vector/vector-sink-transform.mlir

+1 −0

```diff
@@ -7,6 +7,7 @@ module attributes {transform.with_named_sequence} {
     %func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %func {
       transform.apply_patterns.vector.sink_ops
+      transform.apply_patterns.vector.sink_mem_ops
     } : !transform.any_op
     transform.yield
   }
```
