[mlir][vector] Sink vector.extract/splat into load/store ops #134389


Merged
9 commits, merged Apr 22, 2025
14 changes: 13 additions & 1 deletion mlir/include/mlir/Dialect/Arith/Utils/Utils.h
@@ -101,7 +101,10 @@ Type getType(OpFoldResult ofr);
/// Helper struct to build simple arithmetic quantities with minimal type
/// inference support.
struct ArithBuilder {
ArithBuilder(OpBuilder &b, Location loc) : b(b), loc(loc) {}
ArithBuilder(
OpBuilder &b, Location loc,
arith::IntegerOverflowFlags ovf = arith::IntegerOverflowFlags::none)
: b(b), loc(loc), ovf(ovf) {}

Value _and(Value lhs, Value rhs);
Value add(Value lhs, Value rhs);
@@ -114,6 +117,15 @@ struct ArithBuilder {
private:
OpBuilder &b;
Location loc;
arith::IntegerOverflowFlags ovf;
};

/// ArithBuilder specialized for tensor/memref indexing calculations. Such
/// calculations should generally never overflow in the signed sense and
/// always use signed integers, so we can set the overflow flags accordingly.
struct ArithIndexingBuilder : public ArithBuilder {
ArithIndexingBuilder(OpBuilder &b, Location loc)
: ArithBuilder(b, loc, arith::IntegerOverflowFlags::nsw) {}
};

namespace arith {
@@ -458,7 +458,9 @@ def ApplySinkVectorPatternsOp : Op<Transform_Dialect,
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
Patterns that remove redundant Vector Ops by re-ordering them with
e.g. elementwise Ops:
e.g. elementwise Ops.

Example:
```
%at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
%bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
@@ -469,8 +471,32 @@ def ApplySinkVectorPatternsOp : Op<Transform_Dialect,
%0 = arith.addf %a, %b : vector<4x2xf32>
%r = vector.transpose %0, [1, 0] : vector<2x4xf32>
```
At the moment, these patterns are limited to vector.broadcast and
vector.transpose.
At the moment, these patterns are limited to vector.broadcast,
vector.transpose and vector.extract.
}];

let assemblyFormat = "attr-dict";
}
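For the newly listed vector.extract case, the reordering these patterns perform looks roughly as follows (a sketch based on the ExtractOpFromElementwise pattern shown later in this diff):
```
%0 = arith.addf %arg0, %arg1 : vector<4xf32>
%1 = vector.extract %0[1] : f32 from vector<4xf32>
```
becomes:
```
%0 = vector.extract %arg0[1] : f32 from vector<4xf32>
%1 = vector.extract %arg1[1] : f32 from vector<4xf32>
%2 = arith.addf %0, %1 : f32
```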

def ApplySinkVectorMemPatternsOp : Op<Transform_Dialect,
"apply_patterns.vector.sink_mem_ops",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
Patterns that replace redundant Vector Ops (paired with
`vector.load`/`vector.store`) with either `vector.load`/`vector.store` or
`memref.load`/`memref.store`. The store-side rewrite (splat/broadcast +
`vector.store`) is currently limited to 1-element vectors.

Example:
```
%0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
%1 = vector.extract %0[1] : f32 from vector<4xf32>
```
Gets converted to:
```
%c1 = arith.constant 1 : index
%0 = arith.addi %arg1, %c1 overflow<nsw> : index
%1 = memref.load %arg0[%0] : memref<?xf32>
```
}];

let assemblyFormat = "attr-dict";
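For reference, a minimal transform-dialect sequence applying the new patterns to a function could look like the sketch below. It mirrors the vector-sink-transform.mlir test updated at the bottom of this PR; the sequence name and argument attributes follow the usual conventions rather than anything required by the new op.
```
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
    // Limit the rewrite to function bodies.
    %func = transform.structured.match ops{["func.func"]} in %module_op
        : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %func {
      transform.apply_patterns.vector.sink_ops
      transform.apply_patterns.vector.sink_mem_ops
    } : !transform.any_op
    transform.yield
  }
}
```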
@@ -161,6 +161,20 @@ void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
void populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);

/// Patterns that remove redundant Vector Ops by merging them with load/store
/// ops
/// ```
/// %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
/// %1 = vector.extract %0[1] : f32 from vector<4xf32>
/// ```
/// Gets converted to:
/// ```
/// %c1 = arith.constant 1 : index
/// %0 = arith.addi %arg1, %c1 overflow<nsw> : index
/// %1 = memref.load %arg0[%0] : memref<?xf32>
/// ```
Contributor:

I think this should be a canonicalization pattern iff there's only one use which is a vector.extract. I can't think of a reason why we would want to load the redundant elements. I would clearly document that this only applies to cases with one use/extract op.

Contributor:

> I think this should be a canonicalization pattern iff there's only one use which is a vector.extract.

No objections from me, but from a purely maintenance point of view, I'd leave the implementation and most of the tests where they are. Otherwise, we risk "bloating" canonicalization.mlir and e.g. VectorOps.cpp.

Contributor Author:

One potential use case where keeping vector.load + extract may be useful is when we load a vector from an aligned address for performance reasons and then use an extract with an offset to get at unaligned data. I don't have such examples in practice, though.

Member (@kuhar, Apr 11, 2025):

> I can't think of a reason why we would want to load the redundant elements

Increasing the granularity of memory accesses may cause you not to be able to use wider load/store instructions, and undoing this later on and proving that you can use a wider memory access may be hard. We'd be losing information about how many bits are dereferenceable and potentially misaligning the access.
This may also change the semantics of load instructions that support OOB behavior -- you can turn an OOB access into an in-bounds access.

For this reason, I don't think this should be on by default.

Contributor:

Also consider vector.extract ... [5] : vector<8xi1>. Applying the pattern in this case means losing byte alignment, which also doesn't seem like a good fit for a canonicalization to me.

Member (@kuhar, Apr 11, 2025):

> This pattern is constrained to using/extracting only one element so we wouldn't be dropping access pattern information for that case, right? Do you have something else in mind?

Say one of your memory regions is dword-sized but your memory accesses take byte offsets:

%x = vector.load ... : vector<4xi8>
%y = vector.extract %x [2]: i8

The original load is efficient because you are accessing a full dword. However, if you turn it into memref.load ... : i8, you may no longer know, once the index calculation simplifies with something else, that to get an aligned dword load you need to also load the preceding bytes vs. only the bytes following this i8 (unaligned). You could resolve that with some masking + shifting, but that comes with some overhead.

> Could you help me understand this? We should be able to remove any load operation that is dead, regardless of whether it's in-bounds or OOB, right? What makes this case different?

For example, the buffer instructions on amdgpu allow you to get a default value for any OOB accesses. Looking at the example above, it could be that only the last byte is OOB, but this alone makes the whole vector<4xi8> have the default value. If you no longer load that last byte, the access would be in-bounds and you would observe a different value.

Contributor Author:

I don't think we actually need any special handling or tests for sub-byte types. The only ways we can have load ... vector<8xi1> are either loading from memref<...xi1>, for which the semantics are fully consistent, or loading from memref<...xvector<8xi1>>, which is ignored by the current pattern.

Contributor:

Applying this pattern to a vector of bits would lead to memref.load %src[%idx] : memref<8xi1>, i.e. a load of a single bit. That doesn't feel sane.

Also, in cases like this:

%x = vector.load ... : vector<8xi1>
%y = vector.extract %x [5]: i1

the vector load is probably just a scalar load anyway.

My suggestion is to restrict this pattern to multi-byte element types (*) and rely on "narrow-type-emulation" to help with sub-bytes.

(*) Multi-byte - at least one byte.

Contributor:

Thanks @kuhar, those examples were helpful! I'm still kind of borderline but let’s move forward with this as an independent pattern. The proliferation of dangling “populate” methods is concerning but this case may be worth it.

> The original load is efficient because you are accessing a full dword. However, if you turn it into memref.load ... : i8, you may no longer know,

For that example, I would expect the alignment information to be explicit somewhere, as vector.load doesn't have any default alignment. In the absence of alignment information, I'm still not sure this transformation is dropping information.

> For example, the buffer instructions on amdgpu allow you to get a default value for any OOB accesses. Looking at the example above, it could be that only the last byte is OOB, but this alone makes the whole vector<4xi8> have the default value. If you no longer load that last byte, the access would be in-bounds and you would observe a different value.

Yes but we can’t attribute hardware-specific semantics to vector.load. We allow OOB reads to accommodate those targets that can “handle” OOB accesses. However, we can’t make assumptions on what the target will do or the actual values of those OOB elements. Doc may need some refinement but we defined it along those lines:

> Representation-wise, the ‘vector.load’ operation permits out-of-bounds reads. Support and implementation of out-of-bounds vector loads is target-specific. No assumptions should be made on the value of elements loaded out of bounds. Not all targets may support out-of-bounds vector loads.

A valid lowering of vector.load could be a scalarized version of it that checks, element by element, whether the access is OOB and only loads the in-bounds elements, so the OOB accesses might not happen at all. I'd even say that OOB accesses are not observable, as using the OOB elements should be poison, right? I think the behavior you are describing would better fit a masked vector load, where the masked-off (OOB) elements are replaced with a padding value.
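To illustrate that last point, a scalarized, bounds-checked lowering of a small vector.load might look roughly like the hand-written sketch below (it is not the output of any existing pass; %base, %i, %c0, %c1 and %c2 are assumed to be defined elsewhere):
```
// Scalarized equivalent of: %v = vector.load %base[%i] : memref<?xf32>, vector<2xf32>
%dim  = memref.dim %base, %c0 : memref<?xf32>
%pad  = arith.constant 0.0 : f32
%init = arith.constant dense<0.0> : vector<2xf32>
%v = scf.for %k = %c0 to %c2 step %c1 iter_args(%acc = %init) -> (vector<2xf32>) {
  %idx = arith.addi %i, %k : index
  %inb = arith.cmpi slt, %idx, %dim : index
  %elt = scf.if %inb -> (f32) {
    %e = memref.load %base[%idx] : memref<?xf32>
    scf.yield %e : f32
  } else {
    // OOB lanes never touch memory; they get a padding value instead.
    scf.yield %pad : f32
  }
  %next = vector.insert %elt, %acc[%k] : f32 into vector<2xf32>
  scf.yield %next : vector<2xf32>
}
```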

Contributor:

> I don't think we actually need any special handling or tests for sub-byte types. The only ways we can have load ... vector<8xi1> are either loading from memref<...xi1>, for which the semantics are fully consistent, or loading from memref<...xvector<8xi1>>, which is ignored by the current pattern.

I'd be surprised if there is no issue with the data layout, as the vector one assumes a packed layout and the scalar one would be unpacked. Looking at the generated LLVM IR for both cases would help.

void populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);

/// Patterns that fold chained vector reductions. These patterns assume that
/// elementwise operations (e.g., `arith.addf` with vector operands) are
/// cheaper than vector reduction.
6 changes: 3 additions & 3 deletions mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -315,17 +315,17 @@ Value ArithBuilder::_and(Value lhs, Value rhs) {
Value ArithBuilder::add(Value lhs, Value rhs) {
if (isa<FloatType>(lhs.getType()))
return b.create<arith::AddFOp>(loc, lhs, rhs);
return b.create<arith::AddIOp>(loc, lhs, rhs);
return b.create<arith::AddIOp>(loc, lhs, rhs, ovf);
}
Value ArithBuilder::sub(Value lhs, Value rhs) {
if (isa<FloatType>(lhs.getType()))
return b.create<arith::SubFOp>(loc, lhs, rhs);
return b.create<arith::SubIOp>(loc, lhs, rhs);
return b.create<arith::SubIOp>(loc, lhs, rhs, ovf);
}
Value ArithBuilder::mul(Value lhs, Value rhs) {
if (isa<FloatType>(lhs.getType()))
return b.create<arith::MulFOp>(loc, lhs, rhs);
return b.create<arith::MulIOp>(loc, lhs, rhs);
return b.create<arith::MulIOp>(loc, lhs, rhs, ovf);
}
Value ArithBuilder::sgt(Value lhs, Value rhs) {
if (isa<FloatType>(lhs.getType()))
5 changes: 5 additions & 0 deletions mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -212,6 +212,11 @@ void transform::ApplySinkVectorPatternsOp::populatePatterns(
vector::populateSinkVectorOpsPatterns(patterns);
}

void transform::ApplySinkVectorMemPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
vector::populateSinkVectorMemOpsPatterns(patterns);
}

//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
155 changes: 155 additions & 0 deletions mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -902,6 +902,8 @@ struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
};

/// Reorders elementwise(broadcast/splat) to broadcast(elementwise).
///
/// Example:
/// ```
/// %a = vector.broadcast %arg1 : index to vector<1x4xindex>
/// %b = vector.broadcast %arg2 : index to vector<1x4xindex>
@@ -987,6 +989,8 @@ struct ReorderElementwiseOpsOnBroadcast final
/// This may result in cleaner code when extracting a single value
/// from a multi-element vector, and also helps canonicalize 1-element vectors
/// to scalars.
///
/// Example:
/// ```
/// %0 = arith.addf %arg0, %arg1 : vector<4xf32>
/// %1 = vector.extract %0[1] : f32 from vector<4xf32>
@@ -1043,6 +1047,150 @@ class ExtractOpFromElementwise final
}
};

/// Check if the element type is suitable for vector.load/store sinking.
/// Element type must be index or byte-aligned integer or floating-point type.
static bool isSupportedMemSinkElementType(Type type) {
if (isa<IndexType>(type))
return true;

return type.isIntOrFloat() && type.getIntOrFloatBitWidth() % 8 == 0;
}

/// Pattern to rewrite `vector.extract(vector.load) -> vector/memref.load`.
/// Only index and byte-aligned integer and floating-point element types are
/// supported for now.
///
/// Example:
/// ```
/// %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
/// %1 = vector.extract %0[1] : f32 from vector<4xf32>
/// ```
/// Gets converted to:
/// ```
/// %c1 = arith.constant 1 : index
/// %0 = arith.addi %arg1, %c1 overflow<nsw> : index
/// %1 = memref.load %arg0[%0] : memref<?xf32>
/// ```
class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
public:
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(vector::ExtractOp op,
PatternRewriter &rewriter) const override {
auto loadOp = op.getVector().getDefiningOp<vector::LoadOp>();
if (!loadOp)
return rewriter.notifyMatchFailure(op, "expected a load op");

// Checking for single use so we won't duplicate load ops.
if (!loadOp->hasOneUse())
Contributor:

If moving this to canonicalization, I would add a comment here stating that this condition is the one that makes this a canonicalization pattern and shouldn't be changed.

return rewriter.notifyMatchFailure(op, "expected single op use");

VectorType loadVecType = loadOp.getVectorType();
if (loadVecType.isScalable())
return rewriter.notifyMatchFailure(op,
"scalable vectors are not supported");

MemRefType memType = loadOp.getMemRefType();

// Non-byte-aligned types are tricky and may require special handling, so
// ignore them for now.
if (!isSupportedMemSinkElementType(memType.getElementType()))
return rewriter.notifyMatchFailure(op, "unsupported element type");

int64_t rankOffset = memType.getRank() - loadVecType.getRank();
if (rankOffset < 0)
return rewriter.notifyMatchFailure(op, "unsupported ranks combination");

auto extractVecType = dyn_cast<VectorType>(op.getResult().getType());
int64_t finalRank = 0;
if (extractVecType)
finalRank = extractVecType.getRank();

SmallVector<Value> indices = loadOp.getIndices();
SmallVector<OpFoldResult> extractPos = op.getMixedPosition();

// There may be memory stores between the load and the extract op, so we
// need to make sure that the new load op is inserted at the same place as
// the original load op.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(loadOp);
Location loc = loadOp.getLoc();
ArithIndexingBuilder idxBuilderf(rewriter, loc);
for (auto i : llvm::seq<int64_t>(rankOffset, indices.size() - finalRank)) {
OpFoldResult pos = extractPos[i - rankOffset];
if (isConstantIntValue(pos, 0))
continue;

Value offset = getValueOrCreateConstantIndexOp(rewriter, loc, pos);
indices[i] = idxBuilderf.add(indices[i], offset);
}

Value base = loadOp.getBase();
if (extractVecType) {
rewriter.replaceOpWithNewOp<vector::LoadOp>(op, extractVecType, base,
indices);
} else {
rewriter.replaceOpWithNewOp<memref::LoadOp>(op, base, indices);
}
// We checked for single use so we can safely erase the load op.
rewriter.eraseOp(loadOp);
return success();
}
};

/// Pattern to rewrite vector.store(vector.splat) -> vector/memref.store.
///
/// Example:
/// ```
/// %0 = vector.splat %arg2 : vector<1xf32>
/// vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
/// ```
/// Gets converted to:
/// ```
/// memref.store %arg2, %arg0[%arg1] : memref<?xf32>
/// ```
class StoreOpFromSplatOrBroadcast final
: public OpRewritePattern<vector::StoreOp> {
public:
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(vector::StoreOp op,
PatternRewriter &rewriter) const override {
VectorType vecType = op.getVectorType();
if (vecType.isScalable())
return rewriter.notifyMatchFailure(op,
"scalable vectors are not supported");

if (isa<VectorType>(op.getMemRefType().getElementType()))
return rewriter.notifyMatchFailure(
op, "memrefs of vectors are not supported");

if (vecType.getNumElements() != 1)
return rewriter.notifyMatchFailure(
op, "only 1-element vectors are supported");

Operation *splat = op.getValueToStore().getDefiningOp();
if (!isa_and_present<vector::BroadcastOp, vector::SplatOp>(splat))
return rewriter.notifyMatchFailure(op, "neither a splat nor a broadcast");

// Checking for single use so we can remove splat.
if (!splat->hasOneUse())
return rewriter.notifyMatchFailure(op, "expected single op use");

Value source = splat->getOperand(0);
Value base = op.getBase();
ValueRange indices = op.getIndices();

if (isa<VectorType>(source.getType())) {
rewriter.replaceOpWithNewOp<vector::StoreOp>(op, source, base, indices);
} else {
rewriter.replaceOpWithNewOp<memref::StoreOp>(op, source, base, indices);
}
rewriter.eraseOp(splat);
return success();
}
};

// Helper that returns a vector comparison that constructs a mask:
// mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
//
@@ -2109,6 +2257,13 @@ void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
patterns.getContext(), benefit);
}

void mlir::vector::populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
PatternBenefit benefit) {
// TODO: Consider converting these patterns to canonicalizations.
patterns.add<ExtractOpFromLoad, StoreOpFromSplatOrBroadcast>(
patterns.getContext(), benefit);
}

void mlir::vector::populateChainedVectorReductionFoldingPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<ChainedReduction>(patterns.getContext(), benefit);
1 change: 1 addition & 0 deletions mlir/test/Dialect/Vector/vector-sink-transform.mlir
@@ -7,6 +7,7 @@ module attributes {transform.with_named_sequence} {
%func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.vector.sink_ops
transform.apply_patterns.vector.sink_mem_ops
} : !transform.any_op
transform.yield
}