[mlir][Affine] take strides into account for contiguity check #126579

Status: Open. Wants to merge 1 commit into base: main.
67 changes: 65 additions & 2 deletions mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp
@@ -18,6 +18,9 @@
 #include "mlir/Dialect/Affine/Analysis/NestedMatcher.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Visitors.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/Support/MathExtras.h"
 
 #include "llvm/ADT/DenseSet.h"
@@ -190,7 +193,57 @@ DenseSet<Value> mlir::affine::getInvariantAccesses(Value iv,
   return res;
 }
 
-// TODO: check access stride.
+/// Check that `offset` is an additive offset in `resultExpr`, i.e. that the
+/// result is of the shape `offset + ...`.
+static bool isOffset(AffineExpr resultExpr, int numDims,
+                     ArrayRef<Value> operands, Value offset) {
+  // Check if the expression is only the offset.
+  if (auto dimExpr = dyn_cast<AffineDimExpr>(resultExpr))
+    return operands[dimExpr.getPosition()] == offset;
+  if (auto symExpr = dyn_cast<AffineSymbolExpr>(resultExpr))
+    return operands[symExpr.getPosition() + numDims] == offset;
+
+  // Otherwise, walk the expression and check that it has one of the shapes:
+  //   - offset + ...
+  //   - (offset + ...) mod ...
+  // The second shape leads to piecewise contiguous accesses, which can be
+  // treated as contiguous for vectorization if the vectorization factor
+  // divides the modulo's right-hand side.
+  WalkResult walkRes = resultExpr.walk([&](AffineExpr expr) {
+    if (!isa<AffineBinaryOpExpr>(expr))
+      return WalkResult::skip();
+    if (expr.getKind() == AffineExprKind::Add) {
+      AffineExpr lhs = cast<AffineBinaryOpExpr>(expr).getLHS();
+      AffineExpr rhs = cast<AffineBinaryOpExpr>(expr).getRHS();
+      if (auto dimExpr = dyn_cast<AffineDimExpr>(lhs)) {
+        if (operands[dimExpr.getPosition()] == offset)
+          return WalkResult::interrupt();
+      } else if (auto symExpr = dyn_cast<AffineSymbolExpr>(lhs)) {
+        if (operands[symExpr.getPosition() + numDims] == offset)
+          return WalkResult::interrupt();
+      }
+      if (auto dimExpr = dyn_cast<AffineDimExpr>(rhs)) {
+        if (operands[dimExpr.getPosition()] == offset)
+          return WalkResult::interrupt();
+      } else if (auto symExpr = dyn_cast<AffineSymbolExpr>(rhs)) {
+        if (operands[symExpr.getPosition() + numDims] == offset)
+          return WalkResult::interrupt();
+      }
+      return WalkResult::advance();
+    }
+    if (expr.getKind() == AffineExprKind::Mod) {
+      AffineExpr lhs = cast<AffineBinaryOpExpr>(expr).getLHS();
+      if (auto dimExpr = dyn_cast<AffineDimExpr>(lhs)) {
+        if (operands[dimExpr.getPosition()] == offset)
+          return WalkResult::interrupt();
+      } else if (auto symExpr = dyn_cast<AffineSymbolExpr>(lhs)) {
+        if (operands[symExpr.getPosition() + numDims] == offset)
+          return WalkResult::interrupt();
+      }
+    }
+    return WalkResult::skip();
+  });
+  return walkRes.wasInterrupted();
+}
+
 template <typename LoadOrStoreOp>
 bool mlir::affine::isContiguousAccess(Value iv, LoadOrStoreOp memoryOp,
                                       int *memRefDim) {
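For intuition, here is a minimal sketch (not from the patch) of the shapes the new isOffset helper is meant to accept or reject, built with MLIR's C++ AffineExpr operators; the dimension names d0/d1 are illustrative assumptions:

#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;

void offsetShapes() {
  MLIRContext ctx;
  AffineExpr d0 = getAffineDimExpr(0, &ctx); // some other loop IV
  AffineExpr d1 = getAffineDimExpr(1, &ctx); // the candidate offset
  // d1 + 8 * d0: d1 appears as an additive offset -> accepted.
  AffineExpr add = d1 + d0 * 8;
  // (d1 + 8 * d0) mod 16: offset under a mod -> piecewise contiguous,
  // accepted (vectorizable when the vector width divides 16).
  AffineExpr mod = (d1 + d0 * 8) % 16;
  // 8 * d1: d1 is scaled, so consecutive d1 values land 8 apart -> rejected.
  AffineExpr mul = d1 * 8;
  (void)add; (void)mod; (void)mul;
}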
@@ -219,7 +272,17 @@ bool mlir::affine::isContiguousAccess(Value iv, LoadOrStoreOp memoryOp,
   });
   // Check access invariance of each operand in 'exprOperands'.
   for (Value exprOperand : exprOperands) {
-    if (!isAccessIndexInvariant(iv, exprOperand)) {
+    // If the access depends on the induction variable, verify that it is
+    // contiguous along it by checking that at most one result of the op's
+    // access map is of the shape `iv + constant`.
+    auto map = AffineMap::getMultiDimIdentityMap(/*numDims=*/1, iv.getContext());
+    SmallVector<Value> operands = {exprOperand};
+    AffineValueMap avm(map, operands);
+    avm.composeSimplifyAndCanonicalize();
+    if (avm.isFunctionOf(0, iv)) {
+      if (!isOffset(resultExpr, numDims, mapOperands, exprOperand) ||
+          !isOffset(avm.getResult(0), avm.getNumDims(), avm.getOperands(), iv)) {
+        return false;
+      }
       if (uniqueVaryingIndexAlongIv != -1) {
         // 2+ varying indices -> do not vectorize along iv.
         return false;
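As a usage sketch (hypothetical driver code, not part of this patch), the predicate can be queried once per enclosing loop; the helper name and the assumption that `loops` holds the enclosing affine.for ops from outermost to innermost are mine:

#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"

using namespace mlir;

void reportContiguity(affine::AffineLoadOp load,
                      ArrayRef<affine::AffineForOp> loops) {
  for (unsigned i = 0, e = loops.size(); i < e; ++i) {
    int memRefDim = -1;
    // With this patch, isContiguousAccess only returns true when the access
    // moves with stride 1 (or is piecewise stride-1 under a mod) along the
    // i-th loop's induction variable.
    if (affine::isContiguousAccess(loops[i].getInductionVar(), load,
                                   &memRefDim))
      load.emitRemark("contiguous along loop ") << i;
  }
}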
7 changes: 0 additions & 7 deletions mlir/test/Dialect/Affine/access-analysis.mlir
@@ -11,17 +11,12 @@ func.func @loop_simple(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
       // expected-remark@above {{invariant along loop 1}}
       affine.load %A[%c0, 8 * %i + %j] : memref<?x?xf32>
       // expected-remark@above {{contiguous along loop 1}}
-      // Note/FIXME: access stride isn't being checked.
-      // expected-remark@-3 {{contiguous along loop 0}}
 
       // These are all non-contiguous along both loops. Nothing is emitted.
       affine.load %A[%i, %c0] : memref<?x?xf32>
       // expected-remark@above {{invariant along loop 1}}
-      // Note/FIXME: access stride isn't being checked.
       affine.load %A[%i, 8 * %j] : memref<?x?xf32>
-      // expected-remark@above {{contiguous along loop 1}}
       affine.load %A[%j, 4 * %i] : memref<?x?xf32>
-      // expected-remark@above {{contiguous along loop 0}}
     }
   }
   return
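The deleted remarks follow directly from the new stride check: an index like 8 * %j advances eight elements per step of %j, so the access is strided rather than contiguous. A standalone illustration (plain C++, not from the patch):

#include <cstdio>

int main() {
  // Column indices touched by %A[%i, 8 * %j] as %j increments:
  for (int j = 0; j < 4; ++j)
    std::printf("%d ", 8 * j); // prints "0 8 16 24": stride 8, not 1
  return 0;
}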
@@ -70,7 +65,6 @@ func.func @tiled(%arg0: memref<*xf32>) {
           // expected-remark@above {{invariant along loop 4}}
           affine.store %0, %alloc_0[0, %arg1 * -16 + %arg4, 0, %arg3 * -16 + %arg5] : memref<1x16x1x16xf32>
           // expected-remark@above {{contiguous along loop 4}}
-          // expected-remark@above {{contiguous along loop 2}}
           // expected-remark@above {{invariant along loop 1}}
         }
       }
@@ -79,7 +73,6 @@ func.func @tiled(%arg0: memref<*xf32>) {
       affine.for %arg6 = #map(%arg3) to #map1(%arg3) {
         %0 = affine.load %alloc_0[0, %arg1 * -16 + %arg4, -%arg2 + %arg5, %arg3 * -16 + %arg6] : memref<1x16x1x16xf32>
         // expected-remark@above {{contiguous along loop 5}}
-        // expected-remark@above {{contiguous along loop 2}}
         affine.store %0, %alloc[0, %arg5, %arg6, %arg4] : memref<1x224x224x64xf32>
         // expected-remark@above {{contiguous along loop 3}}
         // expected-remark@above {{invariant along loop 0}}