[mlir][vector] Sink vector.extract/splat into load/store ops #134389
@@ -161,6 +161,20 @@ void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
void populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
                                   PatternBenefit benefit = 1);

/// Patterns that remove redundant Vector Ops by merging them with load/store
/// ops.
/// ```
/// %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
/// %1 = vector.extract %0[1] : f32 from vector<4xf32>
/// ```
/// Gets converted to:
/// ```
/// %c1 = arith.constant 1 : index
/// %0 = arith.addi %arg1, %c1 overflow<nsw> : index
/// %1 = memref.load %arg0[%0] : memref<?xf32>
/// ```
Review thread on this new entry point (several reviewers; truncated passages are marked with [...]):

- I think this should be a canonicalization pattern iff there's only one use, which is a vector.extract. I can't think of a reason why we would want to load the redundant elements. I would clearly document that this only applies to cases with one use/extract op.
- No objections from me, but from a purely maintenance point of view, I'd leave the implementation and most of the tests where they are. Otherwise, we risk "bloating" canonicalization.mlir and e.g. VectorOps.cpp.
- One potential use case where keeping vector.load + extract may be useful is when we load a vector at an aligned address for perf reasons and then use extract with an offset to reach unaligned data. I don't have such examples in practice, though.
- Increasing the granularity of memory accesses may prevent you from using wider load/store instructions, and undoing this later on and proving that you can use a wider memory access may be hard. We'd be losing information about how many bits are dereferenceable and potentially misaligning the access. For this reason, I don't think this should be on by default.
- Also, say one of your memory regions is dword-sized but your memory accesses take byte offsets:
      %x = vector.load ... : vector<4xi8>
      %y = vector.extract %x[2] : i8
  The original load is efficient because you are accessing a full dword. However, if you turn it into [...]. For example, the buffer instructions on amdgpu allow you to get a default value for any OOB accesses. Looking at the example above, it could be that only the last byte is OOB, but this alone makes the whole [...].
- I don't think we actually need any special handling or tests for sub-byte types. The only ways we can have [...].
- Applying this pattern to a vector of bits would lead to [...]. Also, in cases like this:
      %x = vector.load ... : vector<8xi1>
      %y = vector.extract %x[5] : i1
  the vector load is probably just a scalar load anyway. My suggestion is to restrict this pattern to multi-byte element types (*) and rely on narrow-type-emulation to help with sub-bytes. (*) Multi-byte: at least one byte.
- Thanks @kuhar, those examples were helpful! I'm still kind of borderline, but let's move forward with this as an independent pattern. The proliferation of dangling "populate" methods is concerning, but this case may be worth it.
- For that example, I would expect the alignment information to be explicit somewhere as [...].
- Yes, but we can't attribute hardware-specific semantics to [...].
- A valid lowering of [...].
- I'd be surprised if there were no issue with the data layout, as the vector load assumes a packed layout and the scalar one would be unpacked. Looking at the generated LLVM IR for both cases would help.
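To make the trade-off above concrete, here is a sketch (not taken from the PR) of what the documented rewrite would produce for the byte-offset example, with hypothetical `%mem`/`%i` operands standing in for the elided ones:

```mlir
// Before: a single 4-byte (dword-sized) access.
%x = vector.load %mem[%i] : memref<?xi8>, vector<4xi8>
%y = vector.extract %x[2] : i8 from vector<4xi8>

// After sinking: a single 1-byte access at offset %i + 2.
%c2 = arith.constant 2 : index
%0 = arith.addi %i, %c2 overflow<nsw> : index
%y = memref.load %mem[%0] : memref<?xi8>
```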
void populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
                                      PatternBenefit benefit = 1);

/// Patterns that fold chained vector reductions. These patterns assume that
/// elementwise operations (e.g., `arith.addf` with vector operands) are
/// cheaper than vector reduction.
@@ -902,6 +902,8 @@ struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
};

/// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex:
///
/// Example:
/// ```
/// %a = vector.broadcast %arg1 : index to vector<1x4xindex>
/// %b = vector.broadcast %arg2 : index to vector<1x4xindex>
@@ -987,6 +989,8 @@ struct ReorderElementwiseOpsOnBroadcast final
/// This may result in cleaner code when extracting a single value
/// from multi-element vector and also to help canonicalize 1-element vectors to
/// scalars.
///
/// Example:
/// ```
/// %0 = arith.addf %arg0, %arg1 : vector<4xf32>
/// %1 = vector.extract %0[1] : f32 from vector<4xf32>

@@ -1043,6 +1047,150 @@ class ExtractOpFromElementwise final
  }
};
/// Check if the element type is suitable for vector.load/store sinking.
/// Element type must be index or byte-aligned integer or floating-point type.
static bool isSupportedMemSinkElementType(Type type) {
  if (isa<IndexType>(type))
    return true;

  return type.isIntOrFloat() && type.getIntOrFloatBitWidth() % 8 == 0;
}
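As an illustration of the byte-alignment rule the reviewers settled on (not part of the PR), the predicate behaves roughly as follows; the demo function and the assumption that an `MLIRContext` is available are hypothetical:

```cpp
// Illustrative only: which element types the sinking predicate accepts.
static void demoSupportedElementTypes(MLIRContext *ctx) {
  Builder b(ctx);
  assert(isSupportedMemSinkElementType(b.getIndexType()));     // index: ok
  assert(isSupportedMemSinkElementType(b.getF32Type()));       // f32:   ok
  assert(isSupportedMemSinkElementType(b.getI8Type()));        // i8:    ok
  assert(!isSupportedMemSinkElementType(b.getI1Type()));       // i1: sub-byte, rejected
  assert(!isSupportedMemSinkElementType(b.getIntegerType(4))); // i4: sub-byte, rejected
}
```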
/// Pattern to rewrite `vector.extract(vector.load) -> vector/memref.load`.
/// Only index and byte-aligned integer and floating-point element types are
/// supported for now.
///
/// Example:
/// ```
/// %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
/// %1 = vector.extract %0[1] : f32 from vector<4xf32>
/// ```
/// Gets converted to:
/// ```
/// %c1 = arith.constant 1 : index
/// %0 = arith.addi %arg1, %c1 overflow<nsw> : index
/// %1 = memref.load %arg0[%0] : memref<?xf32>
/// ```
||
class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> { | ||
public: | ||
using OpRewritePattern::OpRewritePattern; | ||
|
||
LogicalResult matchAndRewrite(vector::ExtractOp op, | ||
PatternRewriter &rewriter) const override { | ||
auto loadOp = op.getVector().getDefiningOp<vector::LoadOp>(); | ||
if (!loadOp) | ||
return rewriter.notifyMatchFailure(op, "expected a load op"); | ||
|
||
// Checking for single use so we won't duplicate load ops. | ||
if (!loadOp->hasOneUse()) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If moving this to canonicalization, I would add a comment here stating that this condition is the one that makes this a canonicalization pattern and shouldn't be changed. |
||
return rewriter.notifyMatchFailure(op, "expected single op use"); | ||
|
||
VectorType loadVecType = loadOp.getVectorType(); | ||
if (loadVecType.isScalable()) | ||
return rewriter.notifyMatchFailure(op, | ||
"scalable vectors are not supported"); | ||
|
||
MemRefType memType = loadOp.getMemRefType(); | ||
|
||
// Non-byte-aligned types are tricky and may require special handling, | ||
// ignore them for now. | ||
if (!isSupportedMemSinkElementType(memType.getElementType())) | ||
return rewriter.notifyMatchFailure(op, "unsupported element type"); | ||
|
||
int64_t rankOffset = memType.getRank() - loadVecType.getRank(); | ||
if (rankOffset < 0) | ||
return rewriter.notifyMatchFailure(op, "unsupported ranks combination"); | ||
|
||
auto extractVecType = dyn_cast<VectorType>(op.getResult().getType()); | ||
int64_t finalRank = 0; | ||
if (extractVecType) | ||
finalRank = extractVecType.getRank(); | ||
|
||
SmallVector<Value> indices = loadOp.getIndices(); | ||
SmallVector<OpFoldResult> extractPos = op.getMixedPosition(); | ||
|
||
// There may be memory stores between the load and the extract op, so we | ||
// need to make sure that the new load op is inserted at the same place as | ||
// the original load op. | ||
OpBuilder::InsertionGuard g(rewriter); | ||
rewriter.setInsertionPoint(loadOp); | ||
banach-space marked this conversation as resolved.
Show resolved
Hide resolved
|
||
Location loc = loadOp.getLoc(); | ||
ArithIndexingBuilder idxBuilderf(rewriter, loc); | ||
for (auto i : llvm::seq<int64_t>(rankOffset, indices.size() - finalRank)) { | ||
OpFoldResult pos = extractPos[i - rankOffset]; | ||
if (isConstantIntValue(pos, 0)) | ||
continue; | ||
|
||
Value offset = getValueOrCreateConstantIndexOp(rewriter, loc, pos); | ||
indices[i] = idxBuilderf.add(indices[i], offset); | ||
} | ||
|
||
Value base = loadOp.getBase(); | ||
if (extractVecType) { | ||
rewriter.replaceOpWithNewOp<vector::LoadOp>(op, extractVecType, base, | ||
indices); | ||
} else { | ||
rewriter.replaceOpWithNewOp<memref::LoadOp>(op, base, indices); | ||
} | ||
// We checked for single use so we can safely erase the load op. | ||
rewriter.eraseOp(loadOp); | ||
Hardcode84 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return success(); | ||
} | ||
}; | ||
|
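The index arithmetic in the loop above also covers loads whose memref rank exceeds the rank of the loaded vector. A sketch with hypothetical shapes (not from the PR's tests): here `rankOffset = 1`, so only the trailing index is shifted by the extract position.

```mlir
%0 = vector.load %mem[%i, %j] : memref<8x16xf32>, vector<4xf32>
%1 = vector.extract %0[2] : f32 from vector<4xf32>

// Becomes a scalar load with the trailing index offset by 2:
%c2 = arith.constant 2 : index
%jj = arith.addi %j, %c2 overflow<nsw> : index
%1 = memref.load %mem[%i, %jj] : memref<8x16xf32>
```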
/// Pattern to rewrite vector.store(vector.splat) -> vector/memref.store.
///
/// Example:
/// ```
/// %0 = vector.splat %arg2 : vector<1xf32>
/// vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
/// ```
/// Gets converted to:
/// ```
/// memref.store %arg2, %arg0[%arg1] : memref<?xf32>
/// ```
class StoreOpFromSplatOrBroadcast final
    : public OpRewritePattern<vector::StoreOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::StoreOp op,
                                PatternRewriter &rewriter) const override {
    VectorType vecType = op.getVectorType();
    if (vecType.isScalable())
      return rewriter.notifyMatchFailure(op,
                                         "scalable vectors are not supported");

    if (isa<VectorType>(op.getMemRefType().getElementType()))
      return rewriter.notifyMatchFailure(
          op, "memrefs of vectors are not supported");

    if (vecType.getNumElements() != 1)
      return rewriter.notifyMatchFailure(
          op, "only 1-element vectors are supported");

    Operation *splat = op.getValueToStore().getDefiningOp();
    if (!isa_and_present<vector::BroadcastOp, vector::SplatOp>(splat))
      return rewriter.notifyMatchFailure(op, "neither a splat nor a broadcast");

    // Checking for single use so we can remove splat.
    if (!splat->hasOneUse())
      return rewriter.notifyMatchFailure(op, "expected single op use");

    Value source = splat->getOperand(0);
    Value base = op.getBase();
    ValueRange indices = op.getIndices();

    if (isa<VectorType>(source.getType())) {
      rewriter.replaceOpWithNewOp<vector::StoreOp>(op, source, base, indices);
    } else {
      rewriter.replaceOpWithNewOp<memref::StoreOp>(op, source, base, indices);
    }
    rewriter.eraseOp(splat);

    return success();
  }
};
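The pattern also matches `vector.broadcast`; for a scalar source the effect is the same as in the splat example above. A sketch with hypothetical operands (not from the PR's tests):

```mlir
%0 = vector.broadcast %arg2 : f32 to vector<1xf32>
vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>

// Collapses to a plain scalar store:
memref.store %arg2, %arg0[%arg1] : memref<?xf32>
```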
// Helper that returns a vector comparison that constructs a mask:
//   mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
//

@@ -2109,6 +2257,13 @@ void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
      patterns.getContext(), benefit);
}

void mlir::vector::populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
                                                    PatternBenefit benefit) {
  // TODO: Consider converting these patterns to canonicalizations.
  patterns.add<ExtractOpFromLoad, StoreOpFromSplatOrBroadcast>(
      patterns.getContext(), benefit);
}

void mlir::vector::populateChainedVectorReductionFoldingPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<ChainedReduction>(patterns.getContext(), benefit);
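For context, a minimal sketch of how a downstream pass might pick up the new entry point; the helper function is hypothetical, while the two `populate*` calls and the greedy driver are existing MLIR APIs:

```cpp
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Hypothetical helper: gather the vector sinking patterns (including the new
// memory-op variants) and apply them greedily to `root`.
static LogicalResult sinkVectorOps(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  vector::populateSinkVectorOpsPatterns(patterns);
  vector::populateSinkVectorMemOpsPatterns(patterns);
  return applyPatternsGreedily(root, std::move(patterns));
}
```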