Skip to content

Commit 428e4e9

Browse files
[MLIR][MemRef] Normalize memref.alloc ops with non trivial layout map
1 parent 6d93280 commit 428e4e9

File tree

7 files changed

+148
-74
lines changed

7 files changed

+148
-74
lines changed

mlir/lib/Analysis/FlatLinearValueConstraints.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -118,9 +118,11 @@ struct SemiAffineExprFlattener : public AffineExprFlattener {
118118
// with a positive value." (see AffineExprKind in AffineExpr.h). If this
119119
// assumption does not hold constraints (added above) are a contradiction.
120120

121+
return success();
122+
} else if (localExpr.getKind() == AffineExprKind::Mul) {
123+
(void)localVarCst.appendVar(VarKind::Local);
121124
return success();
122125
}
123-
124126
// TODO: Support other semi-affine expressions.
125127
return failure();
126128
}
@@ -163,7 +165,6 @@ getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs, unsigned numDims,
163165

164166
return success();
165167
};
166-
167168
if (addConservativeSemiAffineBounds) {
168169
SemiAffineExprFlattener flattener(numDims, numSymbols);
169170
return flattenExprs(flattener);
@@ -229,7 +230,8 @@ LogicalResult FlatLinearConstraints::composeMatchingMap(AffineMap other) {
229230
assert(other.getNumSymbols() == getNumSymbolVars() && "symbol mismatch");
230231

231232
std::vector<SmallVector<int64_t, 8>> flatExprs;
232-
if (failed(flattenAlignedMapAndMergeLocals(other, &flatExprs)))
233+
if (failed(flattenAlignedMapAndMergeLocals(
234+
other, &flatExprs, /*addConservativeSemiAffineBounds=*/true)))
233235
return failure();
234236
assert(flatExprs.size() == other.getNumResults());
235237

@@ -796,8 +798,6 @@ LogicalResult FlatLinearConstraints::flattenAlignedMapAndMergeLocals(
796798
<< "composition unimplemented for semi-affine maps\n");
797799
return failure();
798800
}
799-
800-
// Add localCst information.
801801
if (localCst.getNumLocalVars() > 0) {
802802
unsigned numLocalVars = getNumLocalVars();
803803
// Insert local dims of localCst at the beginning.

mlir/lib/Dialect/Affine/Utils/Utils.cpp

Lines changed: 41 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1786,7 +1786,6 @@ static void createNewDynamicSizes(MemRefType oldMemRefType,
17861786
}
17871787
}
17881788

1789-
// TODO: Currently works for static memrefs with a single layout map.
17901789
template <typename AllocLikeOp>
17911790
LogicalResult mlir::affine::normalizeMemRef(AllocLikeOp *allocOp) {
17921791
MemRefType memrefType = allocOp->getType();
@@ -1799,7 +1798,6 @@ LogicalResult mlir::affine::normalizeMemRef(AllocLikeOp *allocOp) {
17991798
// Either memrefType already had an identity map or the map couldn't be
18001799
// transformed to an identity map.
18011800
return failure();
1802-
18031801
Value oldMemRef = allocOp->getResult();
18041802

18051803
SmallVector<Value, 4> symbolOperands(allocOp->getSymbolOperands());
@@ -1819,8 +1817,40 @@ LogicalResult mlir::affine::normalizeMemRef(AllocLikeOp *allocOp) {
18191817
b.create<AllocLikeOp>(allocOp->getLoc(), newMemRefType, newDynamicSizes,
18201818
allocOp->getAlignmentAttr());
18211819
} else {
1822-
newAlloc = b.create<AllocLikeOp>(allocOp->getLoc(), newMemRefType,
1823-
allocOp->getAlignmentAttr());
1820+
mlir::ValueRange dynamicSizes = allocOp->getDynamicSizes();
1821+
mlir::ValueRange symbolOperands = allocOp->getSymbolOperands();
1822+
ArrayRef<int64_t> newShape = newMemRefType.getShape();
1823+
ArrayRef<int64_t> oldShape = memrefType.getShape();
1824+
SmallVector<Value> mapOperands(oldShape.size() + symbolOperands.size());
1825+
SmallVector<Value> dimensionOperands;
1826+
unsigned dimId = 0, symId = 0;
1827+
// Collect all the map operands of `allocOp` (both dynamic sizes and symbol
1828+
// operands), which will help us to compute the dynamic sizes of the new
1829+
// alloc op we are going to create.
1830+
for (unsigned i = 0, e = oldShape.size(); i < e; i++) {
1831+
if (oldShape[i] == ShapedType::kDynamic)
1832+
mapOperands[i] = dynamicSizes[dimId++];
1833+
else
1834+
mapOperands[i] =
1835+
b.create<arith::ConstantIndexOp>(allocOp->getLoc(), oldShape[i]);
1836+
}
1837+
for (unsigned i = oldShape.size(), e = mapOperands.size(); i < e; i++)
1838+
mapOperands[i] = symbolOperands[symId++];
1839+
// Compute the dynamic sizes operands for the new alloc op. If `newShape` is
1840+
// dynamic along a dimension, compute its shape using the layout map and
1841+
// dynamic sizes and symbol operands of the old `allocOp`.
1842+
for (unsigned i = 0, e = newShape.size(); i < e; i++) {
1843+
if (newShape[i] != ShapedType::kDynamic)
1844+
continue;
1845+
AffineExpr resExpr = layoutMap.getResult(i);
1846+
auto resMap = AffineMap::get(layoutMap.getNumDims(),
1847+
layoutMap.getNumSymbols(), resExpr);
1848+
dimensionOperands.push_back(
1849+
b.create<AffineApplyOp>(allocOp->getLoc(), resMap, mapOperands));
1850+
}
1851+
newAlloc =
1852+
b.create<AllocLikeOp>(allocOp->getLoc(), newMemRefType,
1853+
dimensionOperands, allocOp->getAlignmentAttr());
18241854
}
18251855
// Replace all uses of the old memref.
18261856
if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newAlloc,
@@ -1868,11 +1898,8 @@ MemRefType mlir::affine::normalizeMemRefType(MemRefType memrefType) {
18681898

18691899
// Normalize only static memrefs and dynamic memrefs with a tiled-layout map
18701900
// for now.
1871-
// TODO: Normalize the other types of dynamic memrefs.
18721901
SmallVector<std::tuple<AffineExpr, unsigned, unsigned>> tileSizePos;
18731902
(void)getTileSizePos(layoutMap, tileSizePos);
1874-
if (memrefType.getNumDynamicDims() > 0 && tileSizePos.empty())
1875-
return memrefType;
18761903

18771904
// We have a single map that is not an identity map. Create a new memref
18781905
// with the right shape and an identity layout map.
@@ -1894,7 +1921,6 @@ MemRefType mlir::affine::normalizeMemRefType(MemRefType memrefType) {
18941921
unsigned newRank = layoutMap.getNumResults();
18951922
if (failed(fac.composeMatchingMap(layoutMap)))
18961923
return memrefType;
1897-
// TODO: Handle semi-affine maps.
18981924
// Project out the old data dimensions.
18991925
fac.projectOut(newRank, fac.getNumVars() - newRank - fac.getNumLocalVars());
19001926
SmallVector<int64_t, 4> newShape(newRank);
@@ -1910,14 +1936,14 @@ MemRefType mlir::affine::normalizeMemRefType(MemRefType memrefType) {
19101936
// For a static memref and an affine map with no symbols, this is
19111937
// always bounded. However, when we have symbols, we may not be able to
19121938
// obtain a constant upper bound. Also, mapping to a negative space is
1913-
// invalid for normalization.
1914-
if (!ubConst.has_value() || *ubConst < 0) {
1915-
LLVM_DEBUG(llvm::dbgs()
1916-
<< "can't normalize map due to unknown/invalid upper bound");
1939+
// invalid for normalization. If dimension of new memrefType is dynamic,
1940+
// the value is `ShapedType::kDynamic`.
1941+
if (!ubConst.has_value())
1942+
newShape[d] = ShapedType::kDynamic;
1943+
else if (*ubConst >= 0)
1944+
newShape[d] = *ubConst + 1;
1945+
else
19171946
return memrefType;
1918-
}
1919-
// If dimension of new memrefType is dynamic, the value is -1.
1920-
newShape[d] = *ubConst + 1;
19211947
}
19221948

19231949
// Create the new memref type after trivializing the old layout map.

mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -445,8 +445,10 @@ void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp,
445445
if (oldMemRefType == newMemRefType)
446446
continue;
447447
// TODO: Assume single layout map. Multiple maps not supported.
448+
// TODO: Semi-affine layout not supported.
448449
AffineMap layoutMap = oldMemRefType.getLayout().getAffineMap();
449-
if (failed(replaceAllMemRefUsesWith(oldMemRef,
450+
if (!layoutMap.getResult(0).isPureAffine() ||
451+
failed(replaceAllMemRefUsesWith(oldMemRef,
450452
/*newMemRef=*/newMemRef,
451453
/*extraIndices=*/{},
452454
/*indexRemap=*/layoutMap,

mlir/test/Dialect/Affine/memref-bound-check.mlir

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,13 +124,14 @@ func.func @mod_floordiv_nested() {
124124
return
125125
}
126126

127-
// CHECK-LABEL: func @test_semi_affine_bailout
128-
func.func @test_semi_affine_bailout(%N : index) {
127+
// CHECK-LABEL: func @test_semi_affine_access
128+
func.func @test_semi_affine_access(%N : index) {
129129
%B = memref.alloc() : memref<10 x i32>
130130
affine.for %i = 0 to 10 {
131131
%idx = affine.apply affine_map<(d0)[s0] -> (d0 * s0)>(%i)[%N]
132132
%y = affine.load %B[%idx] : memref<10 x i32>
133-
// expected-error@-1 {{getMemRefRegion: compose affine map failed}}
133+
// expected-error@-1 {{'affine.load' op memref out of upper bound access along dimension #1}}
134+
// expected-error@-2 {{'affine.load' op memref out of lower bound access along dimension #1}}
134135
}
135136
return
136137
}

0 commit comments

Comments
 (0)