Skip to content

Commit aa84998

Browse files
authored
[mlir] Add strided metadata range dataflow analysis (#161280)
Introduces a dataflow analysis for tracking offset, size, and stride ranges of operations. Inference of the metadata is accomplished through the implementation of the interface `InferStridedMetadataOpInterface`. To keep the size of the patch small, this patch only implements the interface for the `memref.subview` operation. It's future work to add more operations. Example: ```mlir func.func @memref_subview(%arg0: memref<8x16x4xf32, strided<[64, 4, 1]>>) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c2 = arith.constant 2 : index %0 = test.with_bounds {smax = 13 : index, smin = 11 : index, umax = 13 : index, umin = 11 : index} : index %1 = test.with_bounds {smax = 7 : index, smin = 5 : index, umax = 7 : index, umin = 5 : index} : index %subview = memref.subview %arg0[%c0, %c0, %c1] [%1, %0, %c2] [%c1, %c1, %c1] : memref<8x16x4xf32, strided<[64, 4, 1]>> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> return } ``` Applying `mlir-opt --test-strided-metadata-range-analysis` prints: ``` Op: %subview = memref.subview %arg0[%c0, %c0, %c1] [%1, %0, %c2] [%c1, %c1, %c1] : memref<8x16x4xf32, strided<[64, 4, 1]>> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> result[0]: strided_metadata<offset = [{unsigned : [1, 1] signed : [1, 1]}], sizes = [{unsigned : [5, 7] signed : [5, 7]}, {unsigned : [11, 13] signed : [11, 13]}, {unsigned : [2, 2] signed : [2, 2]}], strides = [{unsigned : [64, 64] signed : [64, 64]}, {unsigned : [4, 4] signed : [4, 4]}, {unsigned : [1, 1] signed : [1, 1]}]> ``` --------- Signed-off-by: Fabian Mora <fabian.mora-cordero@amd.com>
1 parent 07e4907 commit aa84998

File tree

18 files changed

+661
-2
lines changed

18 files changed

+661
-2
lines changed
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
//===- StridedMetadataRange.h - Strided metadata range analysis -*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_ANALYSIS_DATAFLOW_STRIDEDMETADATARANGE_H
10+
#define MLIR_ANALYSIS_DATAFLOW_STRIDEDMETADATARANGE_H
11+
12+
#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
13+
#include "mlir/Interfaces/InferStridedMetadataInterface.h"
14+
15+
namespace mlir {
16+
namespace dataflow {
17+
18+
/// This lattice element represents the strided metadata of an SSA value.
19+
class StridedMetadataRangeLattice : public Lattice<StridedMetadataRange> {
20+
public:
21+
using Lattice::Lattice;
22+
};
23+
24+
/// Strided metadata range analysis determines the strided metadata ranges of
25+
/// SSA values using operations that define `InferStridedMetadataInterface`.
26+
///
27+
/// This analysis depends on DeadCodeAnalysis, SparseConstantPropagation, and
28+
/// IntegerRangeAnalysis, and will be a silent no-op if the analyses are not
29+
/// loaded in the same solver context.
30+
class StridedMetadataRangeAnalysis
31+
: public SparseForwardDataFlowAnalysis<StridedMetadataRangeLattice> {
32+
public:
33+
StridedMetadataRangeAnalysis(DataFlowSolver &solver,
34+
int32_t indexBitwidth = 64);
35+
36+
/// At an entry point, we cannot reason about strided metadata ranges unless
37+
/// the type also encodes the data. For example, a memref with static layout.
38+
void setToEntryState(StridedMetadataRangeLattice *lattice) override;
39+
40+
/// Visit an operation. Invoke the transfer function on each operation that
41+
/// implements `InferStridedMetadataInterface`.
42+
LogicalResult
43+
visitOperation(Operation *op,
44+
ArrayRef<const StridedMetadataRangeLattice *> operands,
45+
ArrayRef<StridedMetadataRangeLattice *> results) override;
46+
47+
private:
48+
/// Index bitwidth to use when operating with the int-ranges.
49+
int32_t indexBitwidth = 64;
50+
};
51+
} // namespace dataflow
52+
} // end namespace mlir
53+
54+
#endif // MLIR_ANALYSIS_DATAFLOW_STRIDEDMETADATARANGE_H

mlir/include/mlir/Dialect/MemRef/IR/MemRef.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir/Interfaces/CastInterfaces.h"
1818
#include "mlir/Interfaces/ControlFlowInterfaces.h"
1919
#include "mlir/Interfaces/InferIntRangeInterface.h"
20+
#include "mlir/Interfaces/InferStridedMetadataInterface.h"
2021
#include "mlir/Interfaces/InferTypeOpInterface.h"
2122
#include "mlir/Interfaces/MemOpInterfaces.h"
2223
#include "mlir/Interfaces/MemorySlotInterfaces.h"

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ include "mlir/Dialect/MemRef/IR/MemRefBase.td"
1414
include "mlir/Interfaces/CastInterfaces.td"
1515
include "mlir/Interfaces/ControlFlowInterfaces.td"
1616
include "mlir/Interfaces/InferIntRangeInterface.td"
17+
include "mlir/Interfaces/InferStridedMetadataInterface.td"
1718
include "mlir/Interfaces/InferTypeOpInterface.td"
1819
include "mlir/Interfaces/MemOpInterfaces.td"
1920
include "mlir/Interfaces/MemorySlotInterfaces.td"
@@ -2085,6 +2086,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
20852086

20862087
def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
20872088
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
2089+
DeclareOpInterfaceMethods<InferStridedMetadataOpInterface>,
20882090
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
20892091
DeclareOpInterfaceMethods<ViewLikeOpInterface>,
20902092
AttrSizedOperandSegments,

mlir/include/mlir/Interfaces/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ add_mlir_interface(DestinationStyleOpInterface)
66
add_mlir_interface(FunctionInterfaces)
77
add_mlir_interface(IndexingMapOpInterface)
88
add_mlir_interface(InferIntRangeInterface)
9+
add_mlir_interface(InferStridedMetadataInterface)
910
add_mlir_interface(InferTypeOpInterface)
1011
add_mlir_interface(LoopLikeInterface)
1112
add_mlir_interface(MemOpInterfaces)

mlir/include/mlir/Interfaces/InferIntRangeInterface.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,8 @@ class IntegerValueRange {
117117
IntegerValueRange(ConstantIntRanges value) : value(std::move(value)) {}
118118

119119
/// Create an integer value range lattice value.
120-
IntegerValueRange(std::optional<ConstantIntRanges> value = std::nullopt)
120+
explicit IntegerValueRange(
121+
std::optional<ConstantIntRanges> value = std::nullopt)
121122
: value(std::move(value)) {}
122123

123124
/// Whether the range is uninitialized. This happens when the state hasn't
@@ -167,6 +168,15 @@ using SetIntRangeFn =
167168
using SetIntLatticeFn =
168169
llvm::function_ref<void(Value, const IntegerValueRange &)>;
169170

171+
/// Helper callback type to get the integer range of a value.
172+
using GetIntRangeFn = function_ref<IntegerValueRange(Value)>;
173+
174+
/// Helper function to collect the integer range values of an array of op fold
175+
/// results.
176+
SmallVector<IntegerValueRange> getIntValueRanges(ArrayRef<OpFoldResult> values,
177+
GetIntRangeFn getIntRange,
178+
int32_t indexBitwidth);
179+
170180
class InferIntRangeInterface;
171181

172182
namespace intrange::detail {
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
//===- InferStridedMetadataInterface.h - Strided Metadata Inference -C++-*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file contains definitions of the strided metadata inference interface
10+
// defined in `InferStridedMetadataInterface.td`
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#ifndef MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE_H
15+
#define MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE_H
16+
17+
#include "mlir/Interfaces/InferIntRangeInterface.h"
18+
19+
namespace mlir {
20+
/// A class that represents the strided metadata range information, including
21+
/// offsets, sizes, and strides as integer ranges.
22+
class StridedMetadataRange {
23+
public:
24+
/// Default constructor creates uninitialized ranges.
25+
StridedMetadataRange() = default;
26+
27+
/// Returns a ranked strided metadata range.
28+
static StridedMetadataRange
29+
getRanked(SmallVectorImpl<ConstantIntRanges> &&offsets,
30+
SmallVectorImpl<ConstantIntRanges> &&sizes,
31+
SmallVectorImpl<ConstantIntRanges> &&strides) {
32+
return StridedMetadataRange(std::move(offsets), std::move(sizes),
33+
std::move(strides));
34+
}
35+
36+
/// Returns a strided metadata range with maximum ranges.
37+
static StridedMetadataRange getMaxRanges(int32_t indexBitwidth,
38+
int32_t offsetsRank,
39+
int32_t sizeRank,
40+
int32_t stridedRank) {
41+
return StridedMetadataRange(
42+
SmallVector<ConstantIntRanges>(
43+
offsetsRank, ConstantIntRanges::maxRange(indexBitwidth)),
44+
SmallVector<ConstantIntRanges>(
45+
sizeRank, ConstantIntRanges::maxRange(indexBitwidth)),
46+
SmallVector<ConstantIntRanges>(
47+
stridedRank, ConstantIntRanges::maxRange(indexBitwidth)));
48+
}
49+
50+
static StridedMetadataRange getMaxRanges(int32_t indexBitwidth,
51+
int32_t rank) {
52+
return getMaxRanges(indexBitwidth, 1, rank, rank);
53+
}
54+
55+
/// Returns whether the metadata is uninitialized.
56+
bool isUninitialized() const { return !offsets.has_value(); }
57+
58+
/// Get the offsets range.
59+
ArrayRef<ConstantIntRanges> getOffsets() const {
60+
return offsets ? *offsets : ArrayRef<ConstantIntRanges>();
61+
}
62+
MutableArrayRef<ConstantIntRanges> getOffsets() {
63+
return offsets ? *offsets : MutableArrayRef<ConstantIntRanges>();
64+
}
65+
66+
/// Get the sizes ranges.
67+
ArrayRef<ConstantIntRanges> getSizes() const { return sizes; }
68+
MutableArrayRef<ConstantIntRanges> getSizes() { return sizes; }
69+
70+
/// Get the strides ranges.
71+
ArrayRef<ConstantIntRanges> getStrides() const { return strides; }
72+
MutableArrayRef<ConstantIntRanges> getStrides() { return strides; }
73+
74+
/// Compare two strided metadata ranges.
75+
bool operator==(const StridedMetadataRange &other) const {
76+
return offsets == other.offsets && sizes == other.sizes &&
77+
strides == other.strides;
78+
}
79+
80+
/// Print the strided metadata range.
81+
void print(raw_ostream &os) const;
82+
83+
/// Join two strided metadata ranges, by taking the element-wise union of the
84+
/// metadata.
85+
static StridedMetadataRange join(const StridedMetadataRange &lhs,
86+
const StridedMetadataRange &rhs) {
87+
if (lhs.isUninitialized())
88+
return rhs;
89+
if (rhs.isUninitialized())
90+
return lhs;
91+
92+
// Helper fuction to compute the range union of constant ranges.
93+
auto rangeUnion =
94+
+[](const std::tuple<ConstantIntRanges, ConstantIntRanges> &lhsRhs)
95+
-> ConstantIntRanges {
96+
return std::get<0>(lhsRhs).rangeUnion(std::get<1>(lhsRhs));
97+
};
98+
99+
// Get the elementwise range union. Note, that `zip_equal` will assert if
100+
// sizes are not equal.
101+
SmallVector<ConstantIntRanges> offsets = llvm::map_to_vector(
102+
llvm::zip_equal(*lhs.offsets, *rhs.offsets), rangeUnion);
103+
SmallVector<ConstantIntRanges> sizes =
104+
llvm::map_to_vector(llvm::zip_equal(lhs.sizes, rhs.sizes), rangeUnion);
105+
SmallVector<ConstantIntRanges> strides = llvm::map_to_vector(
106+
llvm::zip_equal(lhs.strides, rhs.strides), rangeUnion);
107+
108+
// Return the joined metadata.
109+
return StridedMetadataRange(std::move(offsets), std::move(sizes),
110+
std::move(strides));
111+
}
112+
113+
private:
114+
/// Create a strided metadata range with the given offset, sizes, and strides.
115+
StridedMetadataRange(SmallVectorImpl<ConstantIntRanges> &&offsets,
116+
SmallVectorImpl<ConstantIntRanges> &&sizes,
117+
SmallVectorImpl<ConstantIntRanges> &&strides)
118+
: offsets(std::move(offsets)), sizes(std::move(sizes)),
119+
strides(std::move(strides)) {}
120+
121+
/// The offsets range.
122+
std::optional<SmallVector<ConstantIntRanges>> offsets;
123+
124+
/// The sizes ranges.
125+
SmallVector<ConstantIntRanges> sizes;
126+
127+
/// The strides ranges.
128+
SmallVector<ConstantIntRanges> strides;
129+
};
130+
131+
/// Print the strided metadata to `os`.
132+
inline raw_ostream &operator<<(raw_ostream &os,
133+
const StridedMetadataRange &range) {
134+
range.print(os);
135+
return os;
136+
}
137+
138+
/// Callback function type for setting the strided metadata of a value.
139+
using SetStridedMetadataRangeFn =
140+
function_ref<void(Value, const StridedMetadataRange &)>;
141+
} // end namespace mlir
142+
143+
#include "mlir/Interfaces/InferStridedMetadataInterface.h.inc"
144+
145+
#endif // MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE_H
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
//===- InferStridedMetadataInterface.td - Strided MD Inference ----------*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Defines the interface for strided metadata range analysis
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE
14+
#define MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE
15+
16+
include "mlir/IR/OpBase.td"
17+
18+
def InferStridedMetadataOpInterface :
19+
OpInterface<"InferStridedMetadataOpInterface"> {
20+
let description = [{
21+
Allows operations to participate in strided metadata analysis by providing
22+
methods that allow them to specify bounds on offsets, sizes, and strides
23+
of their result(s) given bounds on their input(s) if known.
24+
}];
25+
let cppNamespace = "::mlir";
26+
27+
let methods = [
28+
InterfaceMethod<[{
29+
Infer the strided metadata bounds on the results of this op given
30+
the bounds on its operands.
31+
For each result value or block argument of interest, the method should
32+
call `setMetadata` with that `Value` as an argument.
33+
The `operands` parameter contains the strided metadata ranges for all the
34+
operands of the operation in order.
35+
The `getIntRange` callback is provided for obtaining the int-range
36+
analysis result for a given value.
37+
}],
38+
"void", "inferStridedMetadataRanges",
39+
(ins "::llvm::ArrayRef<::mlir::StridedMetadataRange>":$operands,
40+
"::mlir::GetIntRangeFn":$getIntRange,
41+
"::mlir::SetStridedMetadataRangeFn":$setMetadata,
42+
"int32_t":$indexBitwidth)>
43+
];
44+
}
45+
#endif // MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE

mlir/lib/Analysis/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ add_mlir_library(MLIRAnalysis
4040
DataFlow/IntegerRangeAnalysis.cpp
4141
DataFlow/LivenessAnalysis.cpp
4242
DataFlow/SparseAnalysis.cpp
43+
DataFlow/StridedMetadataRangeAnalysis.cpp
4344

4445
ADDITIONAL_HEADER_DIRS
4546
${MLIR_MAIN_INCLUDE_DIR}/mlir/Analysis

0 commit comments

Comments
 (0)