Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions mlir/include/mlir/Analysis/DataFlow/StridedMetadataRangeAnalysis.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
//===- StridedMetadataRange.h - Strided metadata range analysis -*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_ANALYSIS_DATAFLOW_STRIDEDMETADATARANGE_H
#define MLIR_ANALYSIS_DATAFLOW_STRIDEDMETADATARANGE_H

#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
#include "mlir/Interfaces/InferStridedMetadataInterface.h"

namespace mlir {
namespace dataflow {

/// This lattice element represents the strided metadata of an SSA value.
class StridedMetadataRangeLattice : public Lattice<StridedMetadataRange> {
public:
using Lattice::Lattice;
};

/// Strided metadata range analysis determines the strided metadata ranges of
/// SSA values using operations that define `InferStridedMetadataInterface`.
///
/// This analysis depends on DeadCodeAnalysis, SparseConstantPropagation, and
/// IntegerRangeAnalysis, and will be a silent no-op if the analyses are not
/// loaded in the same solver context.
class StridedMetadataRangeAnalysis
: public SparseForwardDataFlowAnalysis<StridedMetadataRangeLattice> {
public:
StridedMetadataRangeAnalysis(DataFlowSolver &solver,
int32_t indexBitwidth = 64);

/// At an entry point, we cannot reason about strided metadata ranges unless
/// the type also encodes the data. For example, a memref with static layout.
void setToEntryState(StridedMetadataRangeLattice *lattice) override;

/// Visit an operation. Invoke the transfer function on each operation that
/// implements `InferStridedMetadataInterface`.
LogicalResult
visitOperation(Operation *op,
ArrayRef<const StridedMetadataRangeLattice *> operands,
ArrayRef<StridedMetadataRangeLattice *> results) override;

private:
/// Index bitwidth to use when operating with the int-ranges.
int32_t indexBitwidth = 64;
};
} // namespace dataflow
} // end namespace mlir

#endif // MLIR_ANALYSIS_DATAFLOW_STRIDEDMETADATARANGE_H
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/InferStridedMetadataInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/MemOpInterfaces.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
Expand Down
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ include "mlir/Dialect/MemRef/IR/MemRefBase.td"
include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/InferIntRangeInterface.td"
include "mlir/Interfaces/InferStridedMetadataInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/MemOpInterfaces.td"
include "mlir/Interfaces/MemorySlotInterfaces.td"
Expand Down Expand Up @@ -2085,6 +2086,7 @@ def MemRef_StoreOp : MemRef_Op<"store",

def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
DeclareOpInterfaceMethods<InferStridedMetadataOpInterface>,
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<ViewLikeOpInterface>,
AttrSizedOperandSegments,
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/Interfaces/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ add_mlir_interface(DestinationStyleOpInterface)
add_mlir_interface(FunctionInterfaces)
add_mlir_interface(IndexingMapOpInterface)
add_mlir_interface(InferIntRangeInterface)
add_mlir_interface(InferStridedMetadataInterface)
add_mlir_interface(InferTypeOpInterface)
add_mlir_interface(LoopLikeInterface)
add_mlir_interface(MemOpInterfaces)
Expand Down
12 changes: 11 additions & 1 deletion mlir/include/mlir/Interfaces/InferIntRangeInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ class IntegerValueRange {
IntegerValueRange(ConstantIntRanges value) : value(std::move(value)) {}

/// Create an integer value range lattice value.
IntegerValueRange(std::optional<ConstantIntRanges> value = std::nullopt)
explicit IntegerValueRange(
std::optional<ConstantIntRanges> value = std::nullopt)
: value(std::move(value)) {}

/// Whether the range is uninitialized. This happens when the state hasn't
Expand Down Expand Up @@ -167,6 +168,15 @@ using SetIntRangeFn =
using SetIntLatticeFn =
llvm::function_ref<void(Value, const IntegerValueRange &)>;

/// Helper callback type to get the integer range of a value.
using GetIntRangeFn = function_ref<IntegerValueRange(Value)>;

/// Helper function to collect the integer range values of an array of op fold
/// results.
SmallVector<IntegerValueRange> getIntValueRanges(ArrayRef<OpFoldResult> values,
GetIntRangeFn getIntRange,
int32_t indexBitwidth);

class InferIntRangeInterface;

namespace intrange::detail {
Expand Down
145 changes: 145 additions & 0 deletions mlir/include/mlir/Interfaces/InferStridedMetadataInterface.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
//===- InferStridedMetadataInterface.h - Strided Metadata Inference -C++-*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains definitions of the strided metadata inference interface
// defined in `InferStridedMetadataInterface.td`
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE_H
#define MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE_H

#include "mlir/Interfaces/InferIntRangeInterface.h"

namespace mlir {
/// A class that represents the strided metadata range information, including
/// offsets, sizes, and strides as integer ranges.
class StridedMetadataRange {
public:
/// Default constructor creates uninitialized ranges.
StridedMetadataRange() = default;

/// Returns a ranked strided metadata range.
static StridedMetadataRange
getRanked(SmallVectorImpl<ConstantIntRanges> &&offsets,
SmallVectorImpl<ConstantIntRanges> &&sizes,
SmallVectorImpl<ConstantIntRanges> &&strides) {
return StridedMetadataRange(std::move(offsets), std::move(sizes),
std::move(strides));
}

/// Returns a strided metadata range with maximum ranges.
static StridedMetadataRange getMaxRanges(int32_t indexBitwidth,
int32_t offsetsRank,
int32_t sizeRank,
int32_t stridedRank) {
return StridedMetadataRange(
SmallVector<ConstantIntRanges>(
offsetsRank, ConstantIntRanges::maxRange(indexBitwidth)),
SmallVector<ConstantIntRanges>(
sizeRank, ConstantIntRanges::maxRange(indexBitwidth)),
SmallVector<ConstantIntRanges>(
stridedRank, ConstantIntRanges::maxRange(indexBitwidth)));
}

static StridedMetadataRange getMaxRanges(int32_t indexBitwidth,
int32_t rank) {
return getMaxRanges(indexBitwidth, 1, rank, rank);
}

/// Returns whether the metadata is uninitialized.
bool isUninitialized() const { return !offsets.has_value(); }

/// Get the offsets range.
ArrayRef<ConstantIntRanges> getOffsets() const {
return offsets ? *offsets : ArrayRef<ConstantIntRanges>();
}
MutableArrayRef<ConstantIntRanges> getOffsets() {
return offsets ? *offsets : MutableArrayRef<ConstantIntRanges>();
}

/// Get the sizes ranges.
ArrayRef<ConstantIntRanges> getSizes() const { return sizes; }
MutableArrayRef<ConstantIntRanges> getSizes() { return sizes; }

/// Get the strides ranges.
ArrayRef<ConstantIntRanges> getStrides() const { return strides; }
MutableArrayRef<ConstantIntRanges> getStrides() { return strides; }

/// Compare two strided metadata ranges.
bool operator==(const StridedMetadataRange &other) const {
return offsets == other.offsets && sizes == other.sizes &&
strides == other.strides;
}

/// Print the strided metadata range.
void print(raw_ostream &os) const;

/// Join two strided metadata ranges, by taking the element-wise union of the
/// metadata.
static StridedMetadataRange join(const StridedMetadataRange &lhs,
const StridedMetadataRange &rhs) {
if (lhs.isUninitialized())
return rhs;
if (rhs.isUninitialized())
return lhs;

// Helper fuction to compute the range union of constant ranges.
auto rangeUnion =
+[](const std::tuple<ConstantIntRanges, ConstantIntRanges> &lhsRhs)
-> ConstantIntRanges {
return std::get<0>(lhsRhs).rangeUnion(std::get<1>(lhsRhs));
};

// Get the elementwise range union. Note, that `zip_equal` will assert if
// sizes are not equal.
SmallVector<ConstantIntRanges> offsets = llvm::map_to_vector(
llvm::zip_equal(*lhs.offsets, *rhs.offsets), rangeUnion);
SmallVector<ConstantIntRanges> sizes =
llvm::map_to_vector(llvm::zip_equal(lhs.sizes, rhs.sizes), rangeUnion);
SmallVector<ConstantIntRanges> strides = llvm::map_to_vector(
llvm::zip_equal(lhs.strides, rhs.strides), rangeUnion);

// Return the joined metadata.
return StridedMetadataRange(std::move(offsets), std::move(sizes),
std::move(strides));
}

private:
/// Create a strided metadata range with the given offset, sizes, and strides.
StridedMetadataRange(SmallVectorImpl<ConstantIntRanges> &&offsets,
SmallVectorImpl<ConstantIntRanges> &&sizes,
SmallVectorImpl<ConstantIntRanges> &&strides)
: offsets(std::move(offsets)), sizes(std::move(sizes)),
strides(std::move(strides)) {}

/// The offsets range.
std::optional<SmallVector<ConstantIntRanges>> offsets;

/// The sizes ranges.
SmallVector<ConstantIntRanges> sizes;

/// The strides ranges.
SmallVector<ConstantIntRanges> strides;
};

/// Print the strided metadata to `os`.
inline raw_ostream &operator<<(raw_ostream &os,
const StridedMetadataRange &range) {
range.print(os);
return os;
}

/// Callback function type for setting the strided metadata of a value.
using SetStridedMetadataRangeFn =
function_ref<void(Value, const StridedMetadataRange &)>;
} // end namespace mlir

#include "mlir/Interfaces/InferStridedMetadataInterface.h.inc"

#endif // MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE_H
45 changes: 45 additions & 0 deletions mlir/include/mlir/Interfaces/InferStridedMetadataInterface.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
//===- InferStridedMetadataInterface.td - Strided MD Inference ----------*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Defines the interface for strided metadata range analysis
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE
#define MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE

include "mlir/IR/OpBase.td"

def InferStridedMetadataOpInterface :
OpInterface<"InferStridedMetadataOpInterface"> {
let description = [{
Allows operations to participate in strided metadata analysis by providing
methods that allow them to specify bounds on offsets, sizes, and strides
of their result(s) given bounds on their input(s) if known.
}];
let cppNamespace = "::mlir";

let methods = [
InterfaceMethod<[{
Infer the strided metadata bounds on the results of this op given
the bounds on its operands.
For each result value or block argument of interest, the method should
call `setMetadata` with that `Value` as an argument.
The `operands` parameter contains the strided metadata ranges for all the
operands of the operation in order.
The `getIntRange` callback is provided for obtaining the int-range
analysis result for a given value.
}],
"void", "inferStridedMetadataRanges",
(ins "::llvm::ArrayRef<::mlir::StridedMetadataRange>":$operands,
"::mlir::GetIntRangeFn":$getIntRange,
"::mlir::SetStridedMetadataRangeFn":$setMetadata,
"int32_t":$indexBitwidth)>
];
}
#endif // MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE
2 changes: 2 additions & 0 deletions mlir/lib/Analysis/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ add_mlir_library(MLIRAnalysis
DataFlow/IntegerRangeAnalysis.cpp
DataFlow/LivenessAnalysis.cpp
DataFlow/SparseAnalysis.cpp
DataFlow/StridedMetadataRangeAnalysis.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Analysis
Expand All @@ -53,6 +54,7 @@ add_mlir_library(MLIRAnalysis
MLIRDataLayoutInterfaces
MLIRFunctionInterfaces
MLIRInferIntRangeInterface
MLIRInferStridedMetadataInterface
MLIRInferTypeOpInterface
MLIRLoopLikeInterface
MLIRPresburger
Expand Down
Loading