Skip to content

Commit e865351

Browse files
use IntegerRangeAnalysis and update launchOp::inferResultRanges.
1 parent c834f4d commit e865351

File tree

6 files changed

+74
-90
lines changed

6 files changed

+74
-90
lines changed

mlir/include/mlir/Dialect/Affine/Analysis/LoopAnalysis.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,10 @@ void getTripCountMapAndOperands(AffineForOp forOp, AffineMap *map,
4343
/// constant trip count in non-trivial cases.
4444
std::optional<uint64_t> getConstantTripCount(AffineForOp forOp);
4545

46-
/// In some scenarios, such as GPU, the number of trip of each thread in the
47-
/// loop is inconsistent. This function returns the maximum number of trip.
48-
std::optional<uint64_t> getMaxConstantTripCount(AffineForOp forOp);
46+
/// Returns the maximum trip count when the operand of forOp has a range. If the
47+
/// operand of forOp is a constant, the return value is the same as
48+
/// `getConstantTripCount`.
49+
std::optional<uint64_t> getUpperBoundOnTripCount(AffineForOp forOp);
4950

5051
/// Returns the greatest known integral divisor of the trip count. Affine
5152
/// expression analysis is used (indirectly through getTripCount), and

mlir/include/mlir/Dialect/GPU/IR/GPUOps.td

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1035,12 +1035,6 @@ def GPU_LaunchOp : GPU_Op<"launch", [
10351035
static StringRef getNumWorkgroupAttributionsAttrName() {
10361036
return "workgroup_attributions";
10371037
}
1038-
1039-
/// Find BlockSize via the BlockArgument of gpu.launch.
1040-
Value getBlockSizeOnAxis(Value threadId);
1041-
1042-
/// Find BlockSize via the Dimension Information.
1043-
Value getBlockSizeOnAxis(Dimension dimension);
10441038
}];
10451039

10461040
let hasCanonicalizer = 1;

mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp

Lines changed: 46 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,15 @@
1212

1313
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
1414

15+
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
16+
#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
1517
#include "mlir/Analysis/SliceAnalysis.h"
1618
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
1719
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
1820
#include "mlir/Dialect/Affine/Analysis/NestedMatcher.h"
1921
#include "mlir/Dialect/Affine/IR/AffineOps.h"
2022
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
21-
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
23+
#include "mlir/Interfaces/FunctionInterfaces.h"
2224
#include "llvm/Support/MathExtras.h"
2325

2426
#include "llvm/ADT/DenseSet.h"
@@ -31,6 +33,7 @@
3133

3234
using namespace mlir;
3335
using namespace mlir::affine;
36+
using namespace mlir::dataflow;
3437

3538
#define DEBUG_TYPE "affine-loop-analysis"
3639

@@ -85,48 +88,54 @@ void mlir::affine::getTripCountMapAndOperands(
8588
tripCountValueMap.getOperands().end());
8689
}
8790

88-
/// Replace thread_id with its maximum value, if `replaceWithZero` is true,
89-
/// thread_id will be replaced by its minimum value 0.
90-
static void replaceGPUOperands(AffineForOp forOp,
91-
SmallVectorImpl<Value> &operands,
92-
SmallVectorImpl<AffineExpr> &symReplacements,
93-
unsigned numDim, bool replaceWithZero = false) {
94-
auto launchOp = forOp->getParentOfType<gpu::LaunchOp>();
95-
if (!launchOp)
91+
/// By running `IntegerRangeAnalysis` to get the ranges of operand, then fill
92+
/// the `symReplacements` with range. If `replaceByMin` is set to true,
93+
/// construct `replacement` using the smallest value.By default, the largest
94+
/// value will be used for constructing `replacement`.
95+
static void replaceOperandByRange(AffineForOp forOp,
96+
SmallVectorImpl<Value> &operands,
97+
SmallVectorImpl<AffineExpr> &symReplacements,
98+
unsigned numDim, bool replaceByMin = false) {
99+
DataFlowSolver solver;
100+
solver.load<DeadCodeAnalysis>();
101+
solver.load<IntegerRangeAnalysis>();
102+
if (failed(solver.initializeAndRun(
103+
forOp->getParentOfType<FunctionOpInterface>())))
96104
return;
97105

98-
// `b` is only used to create `AffineExpr`.
106+
// `b` is used to create affineExpr
99107
Builder b(forOp.getContext());
100-
unsigned idx = 0;
101-
102108
for (unsigned i = numDim, e = operands.size(); i < e; ++i) {
103109
Value operand = operands[i];
104-
if (Value blockSize = launchOp.getBlockSizeOnAxis(operand)) {
105-
operands[i] = blockSize;
106-
if (!replaceWithZero)
107-
symReplacements.push_back(b.getAffineSymbolExpr(idx++) - 1);
108-
else
109-
symReplacements.push_back(b.getAffineConstantExpr(0));
110+
auto lattice =
111+
solver.lookupState<dataflow::IntegerValueRangeLattice>(operand);
112+
if (!lattice) {
113+
symReplacements.push_back(b.getAffineSymbolExpr(i - numDim));
110114
continue;
111115
}
112116

113-
Operation *defOp = operand.getDefiningOp();
114-
if (!defOp) {
115-
++idx;
117+
if (lattice->getValue().isUninitialized()) {
118+
symReplacements.push_back(b.getAffineSymbolExpr(i - numDim));
116119
continue;
117120
}
118121

119-
if (auto threadIdOp = mlir::dyn_cast<gpu::ThreadIdOp>(defOp)) {
120-
gpu::Dimension dimension = threadIdOp.getDimension();
121-
operands[i] = launchOp.getBlockSizeOnAxis(dimension);
122-
if (!replaceWithZero)
123-
symReplacements.push_back(b.getAffineSymbolExpr(idx++) - 1);
124-
else
125-
symReplacements.push_back(b.getAffineConstantExpr(0));
122+
ConstantIntRanges range = lattice->getValue().getValue();
123+
APInt max = range.smax();
124+
APInt min = range.smin();
125+
unsigned bitNums = max.getBitWidth();
126+
127+
if (APInt::getSignedMaxValue(bitNums) == max &&
128+
APInt::getSignedMinValue(bitNums) == min) {
129+
symReplacements.push_back(b.getAffineSymbolExpr(i - numDim));
126130
continue;
127131
}
128-
++idx;
132+
133+
if (!replaceByMin)
134+
symReplacements.push_back(b.getAffineConstantExpr(max.getZExtValue()));
135+
else
136+
symReplacements.push_back(b.getAffineConstantExpr(min.getZExtValue()));
129137
}
138+
return;
130139
}
131140

132141
/// Take the min if all trip counts are constant.
@@ -158,32 +167,28 @@ std::optional<uint64_t> mlir::affine::getConstantTripCount(AffineForOp forOp) {
158167
if (!map)
159168
return std::nullopt;
160169
SmallVector<AffineExpr, 4> symReplacements;
161-
replaceGPUOperands(forOp, operands, symReplacements, map.getNumDims());
170+
replaceOperandByRange(forOp, operands, symReplacements, map.getNumDims());
162171
map = map.replaceDimsAndSymbols({}, symReplacements, map.getNumDims(),
163172
map.getNumSymbols());
164-
affine::AffineValueMap valueMap(map, operands);
165-
(void)valueMap.canonicalize();
166-
map = valueMap.getAffineMap();
167173
return getConstantTripCountFromAffineMap(map);
168174
}
169175

170-
/// In some scenarios, such as GPU, the number of trip of each thread in the
171-
/// loop is inconsistent. This function returns the maximum number of trip.
176+
/// Returns the maximum trip count when the operand of forOp has a range. If the
177+
/// operand of forOp is a constant, the return value is the same as
178+
/// `getConstantTripCount`.
172179
std::optional<uint64_t>
173-
mlir::affine::getMaxConstantTripCount(AffineForOp forOp) {
180+
mlir::affine::getUpperBoundOnTripCount(AffineForOp forOp) {
174181
SmallVector<Value, 4> operands;
175182
AffineMap map;
176183
getTripCountMapAndOperands(forOp, &map, &operands);
177184

178185
if (!map)
179186
return std::nullopt;
180187
SmallVector<AffineExpr, 4> symReplacements;
181-
replaceGPUOperands(forOp, operands, symReplacements, map.getNumDims(), true);
188+
replaceOperandByRange(forOp, operands, symReplacements, map.getNumDims(),
189+
true);
182190
map = map.replaceDimsAndSymbols({}, symReplacements, map.getNumDims(),
183191
map.getNumSymbols());
184-
affine::AffineValueMap valueMap(map, operands);
185-
(void)valueMap.canonicalize();
186-
map = valueMap.getAffineMap();
187192
return getConstantTripCountFromAffineMap(map);
188193
}
189194

@@ -198,12 +203,9 @@ uint64_t mlir::affine::getLargestDivisorOfTripCount(AffineForOp forOp) {
198203
if (!map)
199204
return 1;
200205
SmallVector<AffineExpr, 4> symReplacements;
201-
replaceGPUOperands(forOp, operands, symReplacements, map.getNumDims());
206+
replaceOperandByRange(forOp, operands, symReplacements, map.getNumDims());
202207
map = map.replaceDimsAndSymbols({}, symReplacements, map.getNumDims(),
203208
map.getNumSymbols());
204-
affine::AffineValueMap valueMap(map, operands);
205-
(void)valueMap.canonicalize();
206-
map = valueMap.getAffineMap();
207209
// The largest divisor of the trip count is the GCD of the individual largest
208210
// divisors.
209211
assert(map.getNumResults() >= 1 && "expected one or more results");

mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
1818
#include "mlir/Dialect/Affine/Utils.h"
1919
#include "mlir/Dialect/Func/IR/FuncOps.h"
20-
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
2120
#include "mlir/Dialect/MemRef/IR/MemRef.h"
2221
#include "mlir/Dialect/SCF/IR/SCF.h"
2322
#include "mlir/IR/IRMapping.h"
@@ -118,7 +117,7 @@ static void replaceIterArgsAndYieldResults(AffineForOp forOp) {
118117
/// was known to have a single iteration.
119118
LogicalResult mlir::affine::promoteIfSingleIteration(AffineForOp forOp) {
120119
std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
121-
std::optional<uint64_t> maxTripCount = getMaxConstantTripCount(forOp);
120+
std::optional<uint64_t> maxTripCount = getUpperBoundOnTripCount(forOp);
122121
if (!tripCount || *tripCount != 1 || !maxTripCount || *maxTripCount != 1)
123122
return failure();
124123

@@ -888,7 +887,7 @@ void mlir::affine::getTileableBands(
888887
LogicalResult mlir::affine::loopUnrollFull(AffineForOp forOp) {
889888
std::optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
890889
std::optional<uint64_t> maxMayBeConstantTripCount =
891-
getMaxConstantTripCount(forOp);
890+
getUpperBoundOnTripCount(forOp);
892891

893892
if (!mayBeConstantTripCount.has_value() &&
894893
!maxMayBeConstantTripCount.has_value())
@@ -1025,7 +1024,7 @@ LogicalResult mlir::affine::loopUnrollByFactor(
10251024

10261025
std::optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
10271026
std::optional<uint64_t> maxMayBeConstantTripCount =
1028-
getMaxConstantTripCount(forOp);
1027+
getUpperBoundOnTripCount(forOp);
10291028
if (unrollFactor == 1) {
10301029
if (mayBeConstantTripCount && *mayBeConstantTripCount == 1 &&
10311030
maxMayBeConstantTripCount && *maxMayBeConstantTripCount == 1 &&

mlir/lib/Dialect/GPU/IR/GPUDialect.cpp

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -799,26 +799,6 @@ std::optional<KernelDim3> LaunchOp::getClusterSizeOperandValues() {
799799
return KernelDim3{operands[6], operands[7], operands[8]};
800800
}
801801

802-
Value LaunchOp::getBlockSizeOnAxis(Dimension dimension) {
803-
if (dimension == Dimension::x)
804-
return getBlockSizeX();
805-
else if (dimension == Dimension::y)
806-
return getBlockSizeY();
807-
else
808-
return getBlockSizeZ();
809-
}
810-
811-
Value LaunchOp::getBlockSizeOnAxis(Value threadId) {
812-
KernelDim3 threadIds = getThreadIds();
813-
if (threadIds.x == threadId)
814-
return getBlockSizeX();
815-
else if (threadIds.y == threadId)
816-
return getBlockSizeY();
817-
else if (threadIds.z == threadId)
818-
return getBlockSizeZ();
819-
return {};
820-
}
821-
822802
LogicalResult LaunchOp::verify() {
823803
if (!(hasClusterSize()) &&
824804
(getClusterSizeX() || getClusterSizeY() || getClusterSizeZ()))

mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -250,26 +250,34 @@ void SubgroupSizeOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
250250
void LaunchOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
251251
SetIntRangeFn setResultRange) {
252252
auto setRange = [&](const ConstantIntRanges &argRange, Value dimResult,
253-
Value idxResult) {
253+
Value idxResult, Value size) {
254254
if (argRange.umin().getBitWidth() != IndexType::kInternalStorageBitWidth)
255255
return;
256-
ConstantIntRanges dimRange =
257-
argRange.intersection(getIndexRange(1, kMaxDim));
258-
setResultRange(dimResult, dimRange);
259-
ConstantIntRanges idxRange =
260-
getIndexRange(0, dimRange.umax().getZExtValue() - 1);
261-
setResultRange(idxResult, idxRange);
256+
APInt sizeInt;
257+
if (matchPattern(size, m_ConstantInt(&sizeInt))) {
258+
ConstantIntRanges dimRange = ConstantIntRanges::constant(sizeInt);
259+
setResultRange(dimResult, dimRange);
260+
ConstantIntRanges idxRange = getIndexRange(0, sizeInt.getZExtValue() - 1);
261+
setResultRange(idxResult, idxRange);
262+
} else {
263+
ConstantIntRanges dimRange =
264+
argRange.intersection(getIndexRange(1, kMaxDim));
265+
setResultRange(dimResult, dimRange);
266+
ConstantIntRanges idxRange =
267+
getIndexRange(0, dimRange.umax().getZExtValue() - 1);
268+
setResultRange(idxResult, idxRange);
269+
}
262270
};
263271

264272
argRanges = argRanges.drop_front(getAsyncDependencies().size());
265273
KernelDim3 gridDims = getGridSize();
266274
KernelDim3 blockIds = getBlockIds();
267-
setRange(argRanges[0], gridDims.x, blockIds.x);
268-
setRange(argRanges[1], gridDims.y, blockIds.y);
269-
setRange(argRanges[2], gridDims.z, blockIds.z);
275+
setRange(argRanges[0], gridDims.x, blockIds.x, getGridSizeX());
276+
setRange(argRanges[1], gridDims.y, blockIds.y, getGridSizeY());
277+
setRange(argRanges[2], gridDims.z, blockIds.z, getGridSizeZ());
270278
KernelDim3 blockDims = getBlockSize();
271279
KernelDim3 threadIds = getThreadIds();
272-
setRange(argRanges[3], blockDims.x, threadIds.x);
273-
setRange(argRanges[4], blockDims.y, threadIds.y);
274-
setRange(argRanges[5], blockDims.z, threadIds.z);
280+
setRange(argRanges[3], blockDims.x, threadIds.x, getBlockSizeX());
281+
setRange(argRanges[4], blockDims.y, threadIds.y, getBlockSizeY());
282+
setRange(argRanges[5], blockDims.z, threadIds.z, getBlockSizeZ());
275283
}

0 commit comments

Comments
 (0)