From 1323129751c614882dbbdecc87baf5451b153e78 Mon Sep 17 00:00:00 2001 From: donald chen Date: Sun, 25 Aug 2024 19:21:47 +0800 Subject: [PATCH] [mlir] [dataflow] Refactoring the definition of program points in data flow analysis (#105656) This patch distinguishes between program points and lattice anchors in data flow analysis, where lattice anchors represent locations where a lattice can be attached, while program points denote points in program execution. Related discussions: https://discourse.llvm.org/t/rfc-unify-the-semantics-of-program-points/80671/8 --- .../mlir/Analysis/DataFlow/DeadCodeAnalysis.h | 16 +- .../mlir/Analysis/DataFlow/DenseAnalysis.h | 37 +-- .../Analysis/DataFlow/IntegerRangeAnalysis.h | 2 +- .../mlir/Analysis/DataFlow/SparseAnalysis.h | 8 +- .../include/mlir/Analysis/DataFlowFramework.h | 236 ++++++++++-------- .../Analysis/DataFlow/DeadCodeAnalysis.cpp | 32 +-- mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp | 22 +- .../DataFlow/IntegerRangeAnalysis.cpp | 2 +- mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp | 25 +- mlir/lib/Analysis/DataFlowFramework.cpp | 45 ++-- .../DataFlow/TestDeadCodeAnalysis.cpp | 2 +- .../DataFlow/TestDenseDataFlowAnalysis.h | 2 +- .../lib/Analysis/TestDataFlowFramework.cpp | 12 +- 13 files changed, 237 insertions(+), 204 deletions(-) diff --git a/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h index 10ef8b6ba5843a9..80c8b86c63678a4 100644 --- a/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h @@ -35,21 +35,21 @@ namespace dataflow { //===----------------------------------------------------------------------===// /// This is a simple analysis state that represents whether the associated -/// program point (either a block or a control-flow edge) is live. +/// lattice anchor (either a block or a control-flow edge) is live. class Executable : public AnalysisState { public: using AnalysisState::AnalysisState; - /// Set the state of the program point to live. + /// Set the state of the lattice anchor to live. ChangeResult setToLive(); - /// Get whether the program point is live. + /// Get whether the lattice anchor is live. bool isLive() const { return live; } /// Print the liveness. void print(raw_ostream &os) const override; - /// When the state of the program point is changed to live, re-invoke + /// When the state of the lattice anchor is changed to live, re-invoke /// subscribed analyses on the operations in the block and on the block /// itself. void onUpdate(DataFlowSolver *solver) const override; @@ -60,8 +60,8 @@ class Executable : public AnalysisState { } private: - /// Whether the program point is live. Optimistically assume that the program - /// point is dead. + /// Whether the lattice anchor is live. Optimistically assume that the lattice + /// anchor is dead. bool live = false; /// A set of analyses that should be updated when this state changes. @@ -140,10 +140,10 @@ class PredecessorState : public AnalysisState { // CFGEdge //===----------------------------------------------------------------------===// -/// This program point represents a control-flow edge between a block and one +/// This lattice anchor represents a control-flow edge between a block and one /// of its successors. class CFGEdge - : public GenericProgramPointBase> { + : public GenericLatticeAnchorBase> { public: using Base::Base; diff --git a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h index 4ad5f3fcd838c09..7917f1e3ba64853 100644 --- a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h @@ -91,15 +91,16 @@ class AbstractDenseForwardDataFlowAnalysis : public DataFlowAnalysis { const AbstractDenseLattice &before, AbstractDenseLattice *after) = 0; - /// Get the dense lattice after the execution of the given program point. - virtual AbstractDenseLattice *getLattice(ProgramPoint point) = 0; + /// Get the dense lattice after the execution of the given lattice anchor. + virtual AbstractDenseLattice *getLattice(LatticeAnchor anchor) = 0; /// Get the dense lattice after the execution of the given program point and - /// add it as a dependency to a program point. That is, every time the lattice - /// after point is updated, the dependent program point must be visited, and - /// the newly triggered visit might update the lattice after dependent. + /// add it as a dependency to a lattice anchor. That is, every time the + /// lattice after anchor is updated, the dependent program point must be + /// visited, and the newly triggered visit might update the lattice after + /// dependent. const AbstractDenseLattice *getLatticeFor(ProgramPoint dependent, - ProgramPoint point); + LatticeAnchor anchor); /// Set the dense lattice at control flow entry point and propagate an update /// if it changed. @@ -249,9 +250,9 @@ class DenseForwardDataFlowAnalysis } protected: - /// Get the dense lattice after this program point. - LatticeT *getLattice(ProgramPoint point) override { - return getOrCreate(point); + /// Get the dense lattice on this lattice anchor. + LatticeT *getLattice(LatticeAnchor anchor) override { + return getOrCreate(anchor); } /// Set the dense lattice at control flow entry point and propagate an update @@ -331,16 +332,16 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis { const AbstractDenseLattice &after, AbstractDenseLattice *before) = 0; - /// Get the dense lattice before the execution of the program point. That is, + /// Get the dense lattice before the execution of the lattice anchor. That is, /// before the execution of the given operation or after the execution of the /// block. - virtual AbstractDenseLattice *getLattice(ProgramPoint point) = 0; + virtual AbstractDenseLattice *getLattice(LatticeAnchor anchor) = 0; - /// Get the dense lattice before the execution of the program point `point` - /// and declare that the `dependent` program point must be updated every time - /// `point` is. + /// Get the dense lattice before the execution of the program point in + /// `anchor` and declare that the `dependent` program point must be updated + /// every time `point` is. const AbstractDenseLattice *getLatticeFor(ProgramPoint dependent, - ProgramPoint point); + LatticeAnchor anchor); /// Set the dense lattice before at the control flow exit point and propagate /// the update if it changed. @@ -500,9 +501,9 @@ class DenseBackwardDataFlowAnalysis } protected: - /// Get the dense lattice at the given program point. - LatticeT *getLattice(ProgramPoint point) override { - return getOrCreate(point); + /// Get the dense lattice at the given lattice anchor. + LatticeT *getLattice(LatticeAnchor anchor) override { + return getOrCreate(anchor); } /// Set the dense lattice at control flow exit point (after the terminator) diff --git a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h index d4a5472cfde8688..f99eae379596b65 100644 --- a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h @@ -50,7 +50,7 @@ class IntegerRangeAnalysis /// At an entry point, we cannot reason about interger value ranges. void setToEntryState(IntegerValueRangeLattice *lattice) override { propagateIfChanged(lattice, lattice->join(IntegerValueRange::getMaxRange( - lattice->getPoint()))); + lattice->getAnchor()))); } /// Visit an operation. Invoke the transfer function on each operation that diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h index 89726ae3a855c8f..933790b4f2a6eb6 100644 --- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h @@ -36,8 +36,8 @@ class AbstractSparseLattice : public AnalysisState { /// Lattices can only be created for values. AbstractSparseLattice(Value value) : AnalysisState(value) {} - /// Return the program point this lattice is located at. - Value getPoint() const { return AnalysisState::getPoint().get(); } + /// Return the value this lattice is located at. + Value getAnchor() const { return AnalysisState::getAnchor().get(); } /// Join the information contained in 'rhs' into this lattice. Returns /// if the value of the lattice changed. @@ -86,8 +86,8 @@ class Lattice : public AbstractSparseLattice { public: using AbstractSparseLattice::AbstractSparseLattice; - /// Return the program point this lattice is located at. - Value getPoint() const { return point.get(); } + /// Return the value this lattice is located at. + Value getAnchor() const { return anchor.get(); } /// Return the value held by this lattice. This requires that the value is /// initialized. diff --git a/mlir/include/mlir/Analysis/DataFlowFramework.h b/mlir/include/mlir/Analysis/DataFlowFramework.h index 2580ec28b51902a..b0450ecdbd99b8d 100644 --- a/mlir/include/mlir/Analysis/DataFlowFramework.h +++ b/mlir/include/mlir/Analysis/DataFlowFramework.h @@ -49,79 +49,93 @@ inline ChangeResult operator&(ChangeResult lhs, ChangeResult rhs) { /// Forward declare the analysis state class. class AnalysisState; +/// Program point represents a specific location in the execution of a program. +/// A sequence of program points can be combined into a control flow graph. +struct ProgramPoint : public PointerUnion { + using ParentTy = PointerUnion; + /// Inherit constructors. + using ParentTy::PointerUnion; + /// Allow implicit conversion from the parent type. + ProgramPoint(ParentTy point = nullptr) : ParentTy(point) {} + /// Allow implicit conversions from operation wrappers. + /// TODO: For Windows only. Find a better solution. + template ::value && + !std::is_same::value>> + ProgramPoint(OpT op) : ParentTy(op) {} + + /// Print the program point. + void print(raw_ostream &os) const; +}; + //===----------------------------------------------------------------------===// -// GenericProgramPoint +// GenericLatticeAnchor //===----------------------------------------------------------------------===// -/// Abstract class for generic program points. In classical data-flow analysis, -/// programs points represent positions in a program to which lattice elements +/// Abstract class for generic lattice anchor. In classical data-flow analysis, +/// lattice anchor represent positions in a program to which lattice elements /// are attached. In sparse data-flow analysis, these can be SSA values, and in /// dense data-flow analysis, these are the program points before and after /// every operation. /// -/// In the general MLIR data-flow analysis framework, program points are an -/// extensible concept. Program points are uniquely identifiable objects to -/// which analysis states can be attached. The semantics of program points are -/// defined by the analyses that specify their transfer functions. -/// -/// Program points are implemented using MLIR's storage uniquer framework and +/// Lattice anchor are implemented using MLIR's storage uniquer framework and /// type ID system to provide RTTI. -class GenericProgramPoint : public StorageUniquer::BaseStorage { +class GenericLatticeAnchor : public StorageUniquer::BaseStorage { public: - virtual ~GenericProgramPoint(); + virtual ~GenericLatticeAnchor(); - /// Get the abstract program point's type identifier. + /// Get the abstract lattice anchor's type identifier. TypeID getTypeID() const { return typeID; } - /// Get a derived source location for the program point. + /// Get a derived source location for the lattice anchor. virtual Location getLoc() const = 0; - /// Print the program point. + /// Print the lattice anchor. virtual void print(raw_ostream &os) const = 0; protected: - /// Create an abstract program point with type identifier. - explicit GenericProgramPoint(TypeID typeID) : typeID(typeID) {} + /// Create an abstract lattice anchor with type identifier. + explicit GenericLatticeAnchor(TypeID typeID) : typeID(typeID) {} private: - /// The type identifier of the program point. + /// The type identifier of the lattice anchor. TypeID typeID; }; //===----------------------------------------------------------------------===// -// GenericProgramPointBase +// GenericLatticeAnchorBase //===----------------------------------------------------------------------===// -/// Base class for generic program points based on a concrete program point +/// Base class for generic lattice anchor based on a concrete lattice anchor /// type and a content key. This class defines the common methods required for /// operability with the storage uniquer framework. /// -/// The provided key type uniquely identifies the concrete program point +/// The provided key type uniquely identifies the concrete lattice anchor /// instance and are the data members of the class. template -class GenericProgramPointBase : public GenericProgramPoint { +class GenericLatticeAnchorBase : public GenericLatticeAnchor { public: /// The concrete key type used by the storage uniquer. This class is uniqued /// by its contents. using KeyTy = Value; /// Alias for the base class. - using Base = GenericProgramPointBase; + using Base = GenericLatticeAnchorBase; - /// Construct an instance of the program point using the provided value and + /// Construct an instance of the lattice anchor using the provided value and /// the type ID of the concrete type. template - explicit GenericProgramPointBase(ValueT &&value) - : GenericProgramPoint(TypeID::get()), + explicit GenericLatticeAnchorBase(ValueT &&value) + : GenericLatticeAnchor(TypeID::get()), value(std::forward(value)) {} - /// Get a uniqued instance of this program point class with the given + /// Get a uniqued instance of this lattice anchor class with the given /// arguments. template static ConcreteT *get(StorageUniquer &uniquer, Args &&...args) { return uniquer.get(/*initFn=*/{}, std::forward(args)...); } - /// Allocate space for a program point and construct it in-place. + /// Allocate space for a lattice anchor and construct it in-place. template static ConcreteT *construct(StorageUniquer::StorageAllocator &alloc, ValueT &&value) { @@ -129,46 +143,48 @@ class GenericProgramPointBase : public GenericProgramPoint { ConcreteT(std::forward(value)); } - /// Two program points are equal if their values are equal. + /// Two lattice anchors are equal if their values are equal. bool operator==(const Value &value) const { return this->value == value; } /// Provide LLVM-style RTTI using type IDs. - static bool classof(const GenericProgramPoint *point) { + static bool classof(const GenericLatticeAnchor *point) { return point->getTypeID() == TypeID::get(); } - /// Get the contents of the program point. + /// Get the contents of the lattice anchor. const Value &getValue() const { return value; } private: - /// The program point value. + /// The lattice anchor value. Value value; }; //===----------------------------------------------------------------------===// -// ProgramPoint +// LatticeAnchor //===----------------------------------------------------------------------===// -/// Fundamental IR components are supported as first-class program points. -struct ProgramPoint - : public PointerUnion { - using ParentTy = - PointerUnion; +/// Fundamental IR components are supported as first-class lattice anchor. +struct LatticeAnchor + : public PointerUnion { + using ParentTy = PointerUnion; /// Inherit constructors. using ParentTy::PointerUnion; /// Allow implicit conversion from the parent type. - ProgramPoint(ParentTy point = nullptr) : ParentTy(point) {} + LatticeAnchor(ParentTy point = nullptr) : ParentTy(point) {} /// Allow implicit conversions from operation wrappers. /// TODO: For Windows only. Find a better solution. template ::value && !std::is_same::value>> - ProgramPoint(OpT op) : ParentTy(op) {} + LatticeAnchor(OpT op) : ParentTy(ProgramPoint(op)) {} - /// Print the program point. + LatticeAnchor(Operation *op) : ParentTy(ProgramPoint(op)) {} + LatticeAnchor(Block *block) : ParentTy(ProgramPoint(block)) {} + + /// Print the lattice anchor. void print(raw_ostream &os) const; - /// Get the source location of the program point. + /// Get the source location of the lattice anchor. Location getLoc() const; }; @@ -207,8 +223,8 @@ class DataFlowConfig { /// The general data-flow analysis solver. This class is responsible for /// orchestrating child data-flow analyses, running the fixed-point iteration -/// algorithm, managing analysis state and program point memory, and tracking -/// dependencies between analyses, program points, and analysis states. +/// algorithm, managing analysis state and lattice anchor memory, and tracking +/// dependencies between analyses, lattice anchor, and analysis states. /// /// Steps to run a data-flow analysis: /// @@ -232,32 +248,33 @@ class DataFlowSolver { /// operation and run the analysis until fixpoint. LogicalResult initializeAndRun(Operation *top); - /// Lookup an analysis state for the given program point. Returns null if one + /// Lookup an analysis state for the given lattice anchor. Returns null if one /// does not exist. - template - const StateT *lookupState(PointT point) const { - auto it = analysisStates.find({ProgramPoint(point), TypeID::get()}); + template + const StateT *lookupState(AnchorT anchor) const { + auto it = + analysisStates.find({LatticeAnchor(anchor), TypeID::get()}); if (it == analysisStates.end()) return nullptr; return static_cast(it->second.get()); } - /// Erase any analysis state associated with the given program point. - template - void eraseState(PointT point) { - ProgramPoint pp(point); + /// Erase any analysis state associated with the given lattice anchor. + template + void eraseState(AnchorT anchor) { + LatticeAnchor la(anchor); for (auto it = analysisStates.begin(); it != analysisStates.end(); ++it) { - if (it->first.first == pp) + if (it->first.first == la) analysisStates.erase(it); } } - /// Get a uniqued program point instance. If one is not present, it is + /// Get a uniqued lattice anchor instance. If one is not present, it is /// created with the provided arguments. - template - PointT *getProgramPoint(Args &&...args) { - return PointT::get(uniquer, std::forward(args)...); + template + AnchorT *getLatticeAnchor(Args &&...args) { + return AnchorT::get(uniquer, std::forward(args)...); } /// A work item on the solver queue is a program point, child analysis pair. @@ -267,10 +284,10 @@ class DataFlowSolver { /// Push a work item onto the worklist. void enqueue(WorkItem item) { worklist.push(std::move(item)); } - /// Get the state associated with the given program point. If it does not + /// Get the state associated with the given lattice anchor. If it does not /// exist, create an uninitialized state. - template - StateT *getOrCreateState(PointT point); + template + StateT *getOrCreateState(AnchorT anchor); /// Propagate an update to an analysis state if it changed by pushing /// dependent work items to the back of the queue. @@ -291,13 +308,13 @@ class DataFlowSolver { /// Type-erased instances of the children analyses. SmallVector> childAnalyses; - /// The storage uniquer instance that owns the memory of the allocated program - /// points. + /// The storage uniquer instance that owns the memory of the allocated lattice + /// anchors StorageUniquer uniquer; - /// A type-erased map of program points to associated analysis states for - /// first-class program points. - DenseMap, std::unique_ptr> + /// A type-erased map of lattice anchors to associated analysis states for + /// first-class lattice anchors. + DenseMap, std::unique_ptr> analysisStates; /// Allow the base child analysis class to access the internals of the solver. @@ -309,13 +326,13 @@ class DataFlowSolver { //===----------------------------------------------------------------------===// /// Base class for generic analysis states. Analysis states contain data-flow -/// information that are attached to program points and which evolve as the +/// information that are attached to lattice anchors and which evolve as the /// analysis iterates. /// /// This class places no restrictions on the semantics of analysis states beyond /// these requirements. /// -/// 1. Querying the state of a program point prior to visiting that point +/// 1. Querying the state of a lattice anchor prior to visiting that anchor /// results in uninitialized state. Analyses must be aware of unintialized /// states. /// 2. Analysis states can reach fixpoints, where subsequent updates will never @@ -326,20 +343,20 @@ class AnalysisState { public: virtual ~AnalysisState(); - /// Create the analysis state at the given program point. - AnalysisState(ProgramPoint point) : point(point) {} + /// Create the analysis state at the given lattice anchor. + AnalysisState(LatticeAnchor anchor) : anchor(anchor) {} - /// Returns the program point this state is located at. - ProgramPoint getPoint() const { return point; } + /// Returns the lattice anchor this state is located at. + LatticeAnchor getAnchor() const { return anchor; } /// Print the contents of the analysis state. virtual void print(raw_ostream &os) const = 0; LLVM_DUMP_METHOD void dump() const; - /// Add a dependency to this analysis state on a program point and an + /// Add a dependency to this analysis state on a lattice anchor and an /// analysis. If this state is updated, the analysis will be invoked on the - /// given program point again (in onUpdate()). - void addDependency(ProgramPoint dependent, DataFlowAnalysis *analysis); + /// given lattice anchor again (in onUpdate()). + void addDependency(ProgramPoint point, DataFlowAnalysis *analysis); protected: /// This function is called by the solver when the analysis state is updated @@ -351,8 +368,8 @@ class AnalysisState { solver->enqueue(item); } - /// The program point to which the state belongs. - ProgramPoint point; + /// The lattice anchor to which the state belongs. + LatticeAnchor anchor; #if LLVM_ENABLE_ABI_BREAKING_CHECKS /// When compiling with debugging, keep a name for the analysis state. @@ -361,8 +378,8 @@ class AnalysisState { private: /// The dependency relations originating from this analysis state. An entry - /// `state -> (analysis, point)` is created when `analysis` queries `state` - /// when updating `point`. + /// `state -> (analysis, anchor)` is created when `analysis` queries `state` + /// when updating `anchor`. /// /// When this state is updated, all dependent child analysis invocations are /// pushed to the back of the queue. Use a `SetVector` to keep the analysis @@ -403,7 +420,7 @@ class DataFlowAnalysis { explicit DataFlowAnalysis(DataFlowSolver &solver); /// Initialize the analysis from the provided top-level operation by building - /// an initial dependency graph between all program points of interest. This + /// an initial dependency graph between all lattice anchors of interest. This /// can be implemented by calling `visit` on all program points of interest /// below the top-level operation. /// @@ -432,39 +449,39 @@ class DataFlowAnalysis { virtual LogicalResult visit(ProgramPoint point) = 0; protected: - /// Create a dependency between the given analysis state and program point + /// Create a dependency between the given analysis state and lattice anchor /// on this analysis. void addDependency(AnalysisState *state, ProgramPoint point); /// Propagate an update to a state if it changed. void propagateIfChanged(AnalysisState *state, ChangeResult changed); - /// Register a custom program point class. - template - void registerPointKind() { - solver.uniquer.registerParametricStorageType(); + /// Register a custom lattice anchor class. + template + void registerAnchorKind() { + solver.uniquer.registerParametricStorageType(); } - /// Get or create a custom program point. - template - PointT *getProgramPoint(Args &&...args) { - return solver.getProgramPoint(std::forward(args)...); + /// Get or create a custom lattice anchor. + template + AnchorT *getLatticeAnchor(Args &&...args) { + return solver.getLatticeAnchor(std::forward(args)...); } - /// Get the analysis state associated with the program point. The returned + /// Get the analysis state associated with the lattice anchor. The returned /// state is expected to be "write-only", and any updates need to be /// propagated by `propagateIfChanged`. - template - StateT *getOrCreate(PointT point) { - return solver.getOrCreateState(point); + template + StateT *getOrCreate(AnchorT anchor) { + return solver.getOrCreateState(anchor); } /// Get a read-only analysis state for the given point and create a dependency /// on `dependent`. If the return state is updated elsewhere, this analysis is /// re-invoked on the dependent. - template - const StateT *getOrCreateFor(ProgramPoint dependent, PointT point) { - StateT *state = getOrCreate(point); + template + const StateT *getOrCreateFor(ProgramPoint dependent, AnchorT anchor) { + StateT *state = getOrCreate(anchor); addDependency(state, dependent); return state; } @@ -494,12 +511,12 @@ AnalysisT *DataFlowSolver::load(Args &&...args) { return static_cast(childAnalyses.back().get()); } -template -StateT *DataFlowSolver::getOrCreateState(PointT point) { +template +StateT *DataFlowSolver::getOrCreateState(AnchorT anchor) { std::unique_ptr &state = - analysisStates[{ProgramPoint(point), TypeID::get()}]; + analysisStates[{LatticeAnchor(anchor), TypeID::get()}]; if (!state) { - state = std::unique_ptr(new StateT(point)); + state = std::unique_ptr(new StateT(anchor)); #if LLVM_ENABLE_ABI_BREAKING_CHECKS state->debugName = llvm::getTypeName(); #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS @@ -512,20 +529,32 @@ inline raw_ostream &operator<<(raw_ostream &os, const AnalysisState &state) { return os; } -inline raw_ostream &operator<<(raw_ostream &os, ProgramPoint point) { - point.print(os); +inline raw_ostream &operator<<(raw_ostream &os, LatticeAnchor anchor) { + anchor.print(os); return os; } } // end namespace mlir namespace llvm { -/// Allow hashing of program points. +/// Allow hashing of lattice anchors and program points. +template <> +struct DenseMapInfo + : public DenseMapInfo {}; + template <> struct DenseMapInfo : public DenseMapInfo {}; // Allow llvm::cast style functions. +template +struct CastInfo + : public CastInfo {}; + +template +struct CastInfo + : public CastInfo {}; + template struct CastInfo : public CastInfo {}; @@ -534,6 +563,11 @@ template struct CastInfo : public CastInfo {}; +/// Allow stealing the low bits of a ProgramPoint. +template <> +struct PointerLikeTypeTraits + : public PointerLikeTypeTraits {}; + } // end namespace llvm #endif // MLIR_ANALYSIS_DATAFLOWFRAMEWORK_H diff --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp index fab2bd83888da8d..532480b6fad57d9 100644 --- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp @@ -46,17 +46,20 @@ void Executable::print(raw_ostream &os) const { void Executable::onUpdate(DataFlowSolver *solver) const { AnalysisState::onUpdate(solver); - if (auto *block = llvm::dyn_cast_if_present(point)) { - // Re-invoke the analyses on the block itself. - for (DataFlowAnalysis *analysis : subscribers) - solver->enqueue({block, analysis}); - // Re-invoke the analyses on all operations in the block. - for (DataFlowAnalysis *analysis : subscribers) - for (Operation &op : *block) - solver->enqueue({&op, analysis}); - } else if (auto *programPoint = llvm::dyn_cast_if_present(point)) { + if (ProgramPoint pp = llvm::dyn_cast_if_present(anchor)) { + if (Block *block = llvm::dyn_cast_if_present(pp)) { + // Re-invoke the analyses on the block itself. + for (DataFlowAnalysis *analysis : subscribers) + solver->enqueue({block, analysis}); + // Re-invoke the analyses on all operations in the block. + for (DataFlowAnalysis *analysis : subscribers) + for (Operation &op : *block) + solver->enqueue({&op, analysis}); + } + } else if (auto *latticeAnchor = + llvm::dyn_cast_if_present(anchor)) { // Re-invoke the analysis on the successor block. - if (auto *edge = dyn_cast(programPoint)) { + if (auto *edge = dyn_cast(latticeAnchor)) { for (DataFlowAnalysis *analysis : subscribers) solver->enqueue({edge->getTo(), analysis}); } @@ -114,7 +117,7 @@ void CFGEdge::print(raw_ostream &os) const { DeadCodeAnalysis::DeadCodeAnalysis(DataFlowSolver &solver) : DataFlowAnalysis(solver) { - registerPointKind(); + registerAnchorKind(); } LogicalResult DeadCodeAnalysis::initialize(Operation *top) { @@ -218,7 +221,8 @@ LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) { void DeadCodeAnalysis::markEdgeLive(Block *from, Block *to) { auto *state = getOrCreate(to); propagateIfChanged(state, state->setToLive()); - auto *edgeState = getOrCreate(getProgramPoint(from, to)); + auto *edgeState = + getOrCreate(getLatticeAnchor(from, to)); propagateIfChanged(edgeState, edgeState->setToLive()); } @@ -234,9 +238,7 @@ void DeadCodeAnalysis::markEntryBlocksLive(Operation *op) { LogicalResult DeadCodeAnalysis::visit(ProgramPoint point) { if (point.is()) return success(); - auto *op = llvm::dyn_cast_if_present(point); - if (!op) - return emitError(point.getLoc(), "unknown program point kind"); + auto *op = point.get(); // If the parent block is not executable, there is nothing to do. if (!getOrCreate(op->getBlock())->isLive()) diff --git a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp index 33c877f78f4bf66..37f4ceaaa56cee7 100644 --- a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp @@ -47,10 +47,7 @@ LogicalResult AbstractDenseForwardDataFlowAnalysis::initialize(Operation *top) { LogicalResult AbstractDenseForwardDataFlowAnalysis::visit(ProgramPoint point) { if (auto *op = llvm::dyn_cast_if_present(point)) return processOperation(op); - else if (auto *block = llvm::dyn_cast_if_present(point)) - visitBlock(block); - else - return failure(); + visitBlock(point.get()); return success(); } @@ -180,7 +177,7 @@ void AbstractDenseForwardDataFlowAnalysis::visitBlock(Block *block) { // Skip control edges that aren't executable. Block *predecessor = *it; if (!getOrCreateFor( - block, getProgramPoint(predecessor, block)) + block, getLatticeAnchor(predecessor, block)) ->isLive()) continue; @@ -248,8 +245,8 @@ void AbstractDenseForwardDataFlowAnalysis::visitRegionBranchOperation( const AbstractDenseLattice * AbstractDenseForwardDataFlowAnalysis::getLatticeFor(ProgramPoint dependent, - ProgramPoint point) { - AbstractDenseLattice *state = getLattice(point); + LatticeAnchor anchor) { + AbstractDenseLattice *state = getLattice(anchor); addDependency(state, dependent); return state; } @@ -279,10 +276,7 @@ AbstractDenseBackwardDataFlowAnalysis::initialize(Operation *top) { LogicalResult AbstractDenseBackwardDataFlowAnalysis::visit(ProgramPoint point) { if (auto *op = llvm::dyn_cast_if_present(point)) return processOperation(op); - else if (auto *block = llvm::dyn_cast_if_present(point)) - visitBlock(block); - else - return failure(); + visitBlock(point.get()); return success(); } @@ -424,7 +418,7 @@ void AbstractDenseBackwardDataFlowAnalysis::visitBlock(Block *block) { // Meet the state with the state before block's successors. for (Block *successor : block->getSuccessors()) { if (!getOrCreateFor(block, - getProgramPoint(block, successor)) + getLatticeAnchor(block, successor)) ->isLive()) continue; @@ -474,8 +468,8 @@ void AbstractDenseBackwardDataFlowAnalysis::visitRegionBranchOperation( const AbstractDenseLattice * AbstractDenseBackwardDataFlowAnalysis::getLatticeFor(ProgramPoint dependent, - ProgramPoint point) { - AbstractDenseLattice *state = getLattice(point); + LatticeAnchor anchor) { + AbstractDenseLattice *state = getLattice(anchor); addDependency(state, dependent); return state; } diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp index 35d38ea02d71629..9a95f172d5df48e 100644 --- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp @@ -42,7 +42,7 @@ void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const { // If the integer range can be narrowed to a constant, update the constant // value of the SSA value. std::optional constant = getValue().getValue().getConstantValue(); - auto value = point.get(); + auto value = anchor.get(); auto *cv = solver->getOrCreateState>(value); if (!constant) return solver->propagateIfChanged( diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp index d47d5fec8a9a6a9..4a73f21a18aae74 100644 --- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp @@ -34,7 +34,7 @@ void AbstractSparseLattice::onUpdate(DataFlowSolver *solver) const { AnalysisState::onUpdate(solver); // Push all users of the value to the queue. - for (Operation *user : point.get().getUsers()) + for (Operation *user : anchor.get().getUsers()) for (DataFlowAnalysis *analysis : useDefSubscribers) solver->enqueue({user, analysis}); } @@ -46,7 +46,7 @@ void AbstractSparseLattice::onUpdate(DataFlowSolver *solver) const { AbstractSparseForwardDataFlowAnalysis::AbstractSparseForwardDataFlowAnalysis( DataFlowSolver &solver) : DataFlowAnalysis(solver) { - registerPointKind(); + registerAnchorKind(); } LogicalResult @@ -86,10 +86,7 @@ AbstractSparseForwardDataFlowAnalysis::initializeRecursively(Operation *op) { LogicalResult AbstractSparseForwardDataFlowAnalysis::visit(ProgramPoint point) { if (Operation *op = llvm::dyn_cast_if_present(point)) return visitOperation(op); - else if (Block *block = llvm::dyn_cast_if_present(point)) - visitBlock(block); - else - return failure(); + visitBlock(point.get()); return success(); } @@ -217,7 +214,7 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) { // If the edge from the predecessor block to the current block is not live, // bail out. auto *edgeExecutable = - getOrCreate(getProgramPoint(predecessor, block)); + getOrCreate(getLatticeAnchor(predecessor, block)); edgeExecutable->blockContentSubscribe(this); if (!edgeExecutable->isLive()) continue; @@ -324,7 +321,7 @@ void AbstractSparseForwardDataFlowAnalysis::join( AbstractSparseBackwardDataFlowAnalysis::AbstractSparseBackwardDataFlowAnalysis( DataFlowSolver &solver, SymbolTableCollection &symbolTable) : DataFlowAnalysis(solver), symbolTable(symbolTable) { - registerPointKind(); + registerAnchorKind(); } LogicalResult @@ -355,14 +352,10 @@ LogicalResult AbstractSparseBackwardDataFlowAnalysis::visit(ProgramPoint point) { if (Operation *op = llvm::dyn_cast_if_present(point)) return visitOperation(op); - else if (llvm::dyn_cast_if_present(point)) - // For backward dataflow, we don't have to do any work for the blocks - // themselves. CFG edges between blocks are processed by the BranchOp - // logic in `visitOperation`, and entry blocks for functions are tied - // to the CallOp arguments by visitOperation. - return success(); - else - return failure(); + // For backward dataflow, we don't have to do any work for the blocks + // themselves. CFG edges between blocks are processed by the BranchOp + // logic in `visitOperation`, and entry blocks for functions are tied + // to the CallOp arguments by visitOperation. return success(); } diff --git a/mlir/lib/Analysis/DataFlowFramework.cpp b/mlir/lib/Analysis/DataFlowFramework.cpp index d0e827aa1c2b64e..a65ddc13143bae3 100644 --- a/mlir/lib/Analysis/DataFlowFramework.cpp +++ b/mlir/lib/Analysis/DataFlowFramework.cpp @@ -26,10 +26,10 @@ using namespace mlir; //===----------------------------------------------------------------------===// -// GenericProgramPoint +// GenericLatticeAnchor //===----------------------------------------------------------------------===// -GenericProgramPoint::~GenericProgramPoint() = default; +GenericLatticeAnchor::~GenericLatticeAnchor() = default; //===----------------------------------------------------------------------===// // AnalysisState @@ -44,7 +44,7 @@ void AnalysisState::addDependency(ProgramPoint dependent, DATAFLOW_DEBUG({ if (inserted) { llvm::dbgs() << "Creating dependency between " << debugName << " of " - << point << "\nand " << debugName << " on " << dependent + << anchor << "\nand " << debugName << " on " << dependent << "\n"; } }); @@ -53,7 +53,7 @@ void AnalysisState::addDependency(ProgramPoint dependent, void AnalysisState::dump() const { print(llvm::errs()); } //===----------------------------------------------------------------------===// -// ProgramPoint +// LatticeAnchor //===----------------------------------------------------------------------===// void ProgramPoint::print(raw_ostream &os) const { @@ -61,23 +61,36 @@ void ProgramPoint::print(raw_ostream &os) const { os << ""; return; } - if (auto *programPoint = llvm::dyn_cast(*this)) - return programPoint->print(os); - if (auto *op = llvm::dyn_cast(*this)) + if (Operation *op = llvm::dyn_cast(*this)) { return op->print(os, OpPrintingFlags().skipRegions()); - if (auto value = llvm::dyn_cast(*this)) - return value.print(os, OpPrintingFlags().skipRegions()); + } return get()->print(os); } -Location ProgramPoint::getLoc() const { - if (auto *programPoint = llvm::dyn_cast(*this)) - return programPoint->getLoc(); - if (auto *op = llvm::dyn_cast(*this)) - return op->getLoc(); +void LatticeAnchor::print(raw_ostream &os) const { + if (isNull()) { + os << ""; + return; + } + if (auto *LatticeAnchor = llvm::dyn_cast(*this)) + return LatticeAnchor->print(os); + if (auto value = llvm::dyn_cast(*this)) { + return value.print(os, OpPrintingFlags().skipRegions()); + } + + return get().print(os); +} + +Location LatticeAnchor::getLoc() const { + if (auto *LatticeAnchor = llvm::dyn_cast(*this)) + return LatticeAnchor->getLoc(); if (auto value = llvm::dyn_cast(*this)) return value.getLoc(); - return get()->getParent()->getLoc(); + + ProgramPoint pp = get(); + if (auto *op = llvm::dyn_cast(pp)) + return op->getLoc(); + return pp.get()->getParent()->getLoc(); } //===----------------------------------------------------------------------===// @@ -117,7 +130,7 @@ void DataFlowSolver::propagateIfChanged(AnalysisState *state, ChangeResult changed) { if (changed == ChangeResult::Change) { DATAFLOW_DEBUG(llvm::dbgs() << "Propagating update to " << state->debugName - << " of " << state->point << "\n" + << " of " << state->anchor << "\n" << "Value: " << *state << "\n"); state->onUpdate(this); } diff --git a/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp index 90973af9c2cf5df..d02efaaa3fe320b 100644 --- a/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp +++ b/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp @@ -40,7 +40,7 @@ static void printAnalysisResults(DataFlowSolver &solver, Operation *op, pred->printAsOperand(os); os << " = "; auto *live = solver.lookupState( - solver.getProgramPoint(pred, &block)); + solver.getLatticeAnchor(pred, &block)); if (live) os << *live; else diff --git a/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.h b/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.h index 57fe0ca458de217..86eb8651cb90c18 100644 --- a/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.h +++ b/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.h @@ -206,7 +206,7 @@ class UnderlyingValueAnalysis /// At an entry point, the underlying value of a value is itself. void setToEntryState(UnderlyingValueLattice *lattice) override { propagateIfChanged(lattice, - lattice->join(UnderlyingValue{lattice->getPoint()})); + lattice->join(UnderlyingValue{lattice->getAnchor()})); } /// Look for the most underlying value of a value. diff --git a/mlir/test/lib/Analysis/TestDataFlowFramework.cpp b/mlir/test/lib/Analysis/TestDataFlowFramework.cpp index b6b33182440cf42..9573ec1d1432574 100644 --- a/mlir/test/lib/Analysis/TestDataFlowFramework.cpp +++ b/mlir/test/lib/Analysis/TestDataFlowFramework.cpp @@ -115,15 +115,11 @@ LogicalResult FooAnalysis::initialize(Operation *top) { } LogicalResult FooAnalysis::visit(ProgramPoint point) { - if (auto *op = llvm::dyn_cast_if_present(point)) { + if (auto *op = llvm::dyn_cast_if_present(point)) visitOperation(op); - return success(); - } - if (auto *block = llvm::dyn_cast_if_present(point)) { - visitBlock(block); - return success(); - } - return emitError(point.getLoc(), "unknown point kind"); + else + visitBlock(point.get()); + return success(); } void FooAnalysis::visitBlock(Block *block) {