Skip to content

Commit 6957fc5

Browse files
committed
address reviewer comments
1 parent 670c482 commit 6957fc5

File tree

11 files changed

+94
-151
lines changed

11 files changed

+94
-151
lines changed

mlir/include/mlir/Dialect/MemRef/IR/InferStridedMetadataInterfaceImpl.h

Lines changed: 0 additions & 25 deletions
This file was deleted.

mlir/include/mlir/Dialect/MemRef/IR/MemRef.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir/Interfaces/CastInterfaces.h"
1818
#include "mlir/Interfaces/ControlFlowInterfaces.h"
1919
#include "mlir/Interfaces/InferIntRangeInterface.h"
20+
#include "mlir/Interfaces/InferStridedMetadataInterface.h"
2021
#include "mlir/Interfaces/InferTypeOpInterface.h"
2122
#include "mlir/Interfaces/MemOpInterfaces.h"
2223
#include "mlir/Interfaces/MemorySlotInterfaces.h"

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ include "mlir/Dialect/MemRef/IR/MemRefBase.td"
1414
include "mlir/Interfaces/CastInterfaces.td"
1515
include "mlir/Interfaces/ControlFlowInterfaces.td"
1616
include "mlir/Interfaces/InferIntRangeInterface.td"
17+
include "mlir/Interfaces/InferStridedMetadataInterface.td"
1718
include "mlir/Interfaces/InferTypeOpInterface.td"
1819
include "mlir/Interfaces/MemOpInterfaces.td"
1920
include "mlir/Interfaces/MemorySlotInterfaces.td"
@@ -2084,6 +2085,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
20842085

20852086
def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
20862087
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
2088+
DeclareOpInterfaceMethods<InferStridedMetadataOpInterface>,
20872089
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
20882090
DeclareOpInterfaceMethods<ViewLikeOpInterface>,
20892091
AttrSizedOperandSegments,

mlir/include/mlir/Interfaces/InferIntRangeInterface.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,15 @@ using SetIntRangeFn =
168168
using SetIntLatticeFn =
169169
llvm::function_ref<void(Value, const IntegerValueRange &)>;
170170

171+
/// Helper callback type to get the integer range of a value.
172+
using GetIntRangeFn = function_ref<IntegerValueRange(Value)>;
173+
174+
/// Helper function to collect the integer range values of an array of op fold
175+
/// results.
176+
SmallVector<IntegerValueRange> getIntValueRanges(ArrayRef<OpFoldResult> values,
177+
GetIntRangeFn getIntRange,
178+
int32_t indexBitwidth);
179+
171180
class InferIntRangeInterface;
172181

173182
namespace intrange::detail {

mlir/include/mlir/Interfaces/InferStridedMetadataInterface.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -135,9 +135,6 @@ inline raw_ostream &operator<<(raw_ostream &os,
135135
return os;
136136
}
137137

138-
/// Callback function type to get the integer range of a value.
139-
using GetIntRangeFn = function_ref<IntegerValueRange(Value)>;
140-
141138
/// Callback function type for setting the strided metadata of a value.
142139
using SetStridedMetadataRangeFn =
143140
function_ref<void(Value, const StridedMetadataRange &)>;

mlir/include/mlir/Interfaces/InferStridedMetadataInterface.td

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,10 @@ def InferStridedMetadataOpInterface :
2828
InterfaceMethod<[{
2929
Infer the strided metadata bounds on the results of this op given
3030
the bounds on its operands.
31-
For each result value or block argument, the method should call
32-
`setMetadata` with that `Value` as an argument.
31+
For each result value or block argument of interest, the method should
32+
call `setMetadata` with that `Value` as an argument.
33+
The `operands` parameter contains the strided metadata ranges for all the
34+
operands of the operation in order.
3335
The `getIntRange` callback is provided for obtaining the int-range
3436
analysis result for a given value.
3537
}],

mlir/lib/Dialect/MemRef/IR/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ add_mlir_dialect_library(MLIRMemRefDialect
22
MemRefDialect.cpp
33
MemRefMemorySlot.cpp
44
MemRefOps.cpp
5-
InferStridedMetadataInterfaceImpl.cpp
65
ValueBoundsOpInterfaceImpl.cpp
76

87
ADDITIONAL_HEADER_DIRS

mlir/lib/Dialect/MemRef/IR/InferStridedMetadataInterfaceImpl.cpp

Lines changed: 0 additions & 118 deletions
This file was deleted.

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3437,6 +3437,65 @@ SubViewOp::bubbleDownCasts(OpBuilder &builder) {
34373437
return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
34383438
}
34393439

3440+
void SubViewOp::inferStridedMetadataRanges(
3441+
ArrayRef<StridedMetadataRange> ranges, GetIntRangeFn getIntRange,
3442+
SetStridedMetadataRangeFn setMetadata, int32_t indexBitwidth) {
3443+
auto isUninitialized =
3444+
+[](IntegerValueRange range) { return range.isUninitialized(); };
3445+
3446+
// Bail early if any of the operands metadata is not ready:
3447+
SmallVector<IntegerValueRange> offsetOperands =
3448+
getIntValueRanges(getMixedOffsets(), getIntRange, indexBitwidth);
3449+
if (llvm::any_of(offsetOperands, isUninitialized))
3450+
return;
3451+
3452+
SmallVector<IntegerValueRange> sizeOperands =
3453+
getIntValueRanges(getMixedSizes(), getIntRange, indexBitwidth);
3454+
if (llvm::any_of(sizeOperands, isUninitialized))
3455+
return;
3456+
3457+
SmallVector<IntegerValueRange> stridesOperands =
3458+
getIntValueRanges(getMixedStrides(), getIntRange, indexBitwidth);
3459+
if (llvm::any_of(stridesOperands, isUninitialized))
3460+
return;
3461+
3462+
StridedMetadataRange sourceRange =
3463+
ranges[getSourceMutable().getOperandNumber()];
3464+
if (sourceRange.isUninitialized())
3465+
return;
3466+
3467+
ArrayRef<ConstantIntRanges> srcStrides = sourceRange.getStrides();
3468+
3469+
// Get the dropped dims.
3470+
llvm::SmallBitVector droppedDims = getDroppedDims();
3471+
3472+
// Compute the new offset, strides and sizes.
3473+
ConstantIntRanges offset = sourceRange.getOffsets()[0];
3474+
SmallVector<ConstantIntRanges> strides, sizes;
3475+
3476+
for (size_t i = 0, e = droppedDims.size(); i < e; ++i) {
3477+
bool dropped = droppedDims.test(i);
3478+
// Compute the new offset.
3479+
ConstantIntRanges off =
3480+
intrange::inferMul({offsetOperands[i].getValue(), srcStrides[i]});
3481+
offset = intrange::inferAdd({offset, off});
3482+
3483+
// Skip dropped dimensions.
3484+
if (dropped)
3485+
continue;
3486+
// Multiply the strides.
3487+
strides.push_back(
3488+
intrange::inferMul({stridesOperands[i].getValue(), srcStrides[i]}));
3489+
// Get the sizes.
3490+
sizes.push_back(sizeOperands[i].getValue());
3491+
}
3492+
3493+
setMetadata(getResult(),
3494+
StridedMetadataRange::getRanked(
3495+
SmallVector<ConstantIntRanges>({std::move(offset)}),
3496+
std::move(sizes), std::move(strides)));
3497+
}
3498+
34403499
//===----------------------------------------------------------------------===//
34413500
// TransposeOp
34423501
//===----------------------------------------------------------------------===//

mlir/lib/Interfaces/InferIntRangeInterface.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,25 @@ raw_ostream &mlir::operator<<(raw_ostream &os, const IntegerValueRange &range) {
146146
return os;
147147
}
148148

149+
SmallVector<IntegerValueRange>
150+
mlir::getIntValueRanges(ArrayRef<OpFoldResult> values,
151+
GetIntRangeFn getIntRange, int32_t indexBitwidth) {
152+
SmallVector<IntegerValueRange> ranges;
153+
ranges.reserve(values.size());
154+
for (OpFoldResult ofr : values) {
155+
if (auto value = dyn_cast<Value>(ofr)) {
156+
ranges.push_back(getIntRange(value));
157+
continue;
158+
}
159+
160+
// Create a constant range.
161+
auto attr = cast<IntegerAttr>(cast<Attribute>(ofr));
162+
ranges.emplace_back(ConstantIntRanges::constant(
163+
attr.getValue().sextOrTrunc(indexBitwidth)));
164+
}
165+
return ranges;
166+
}
167+
149168
void mlir::intrange::detail::defaultInferResultRanges(
150169
InferIntRangeInterface interface, ArrayRef<IntegerValueRange> argRanges,
151170
SetIntLatticeFn setResultRanges) {

0 commit comments

Comments
 (0)