From 44b753d0fe67a9265792ff8405fcf5c0bb323fad Mon Sep 17 00:00:00 2001 From: Prithayan Barua Date: Tue, 9 Apr 2024 15:33:31 -0700 Subject: [PATCH] Revert "[LowerFirReg] Reimplement the mux reachability analysis (#6709)" This reverts commit ac8f7767e1f02f4a2b0fc6c9d8b518551a8d0f33. --- lib/Conversion/SeqToSV/FirRegLowering.cpp | 102 ++++++---------------- lib/Conversion/SeqToSV/FirRegLowering.h | 67 ++------------ test/Dialect/Seq/firreg.mlir | 25 ++---- 3 files changed, 40 insertions(+), 154 deletions(-) diff --git a/lib/Conversion/SeqToSV/FirRegLowering.cpp b/lib/Conversion/SeqToSV/FirRegLowering.cpp index 9c500416efc1..39765d5f2e5d 100644 --- a/lib/Conversion/SeqToSV/FirRegLowering.cpp +++ b/lib/Conversion/SeqToSV/FirRegLowering.cpp @@ -12,7 +12,6 @@ #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/DenseSet.h" #include "llvm/Support/Debug.h" -#include using namespace circt; using namespace hw; @@ -21,69 +20,34 @@ using llvm::MapVector; #define DEBUG_TYPE "lower-seq-firreg" -std::function OpUserInfo::opAllowsReachability = - [](const Operation *op) -> bool { - return (isa(op)); -}; - -bool ReachableMuxes::isMuxReachableFrom(seq::FirRegOp regOp, - comb::MuxOp muxOp) { - return llvm::any_of(regOp.getResult().getUsers(), [&](Operation *user) { - if (!OpUserInfo::opAllowsReachability(user)) - return false; - buildReachabilityFrom(user); - return reachableMuxes[user].contains(muxOp); - }); -} - -void ReachableMuxes::buildReachabilityFrom(Operation *startNode) { - // This is a backward dataflow analysis. - // First build a graph rooted at the `startNode`. Every user of an operation - // that does not block the reachability is a child node. Then, the ops that - // are reachable from a node is computed as the union of the Reachability of - // all its child nodes. - // The dataflow can be expressed as, for all child in the Children(node) - // Reachability(node) = node + Union{Reachability(child)} - if (visited.contains(startNode)) - return; - - // The stack to record enough information for an iterative post-order - // traversal. - llvm::SmallVector stk; +// Reimplemented from SliceAnalysis to use a worklist rather than recursion and +// non-insert ordered set. +static void +getForwardSliceSimple(Operation *root, + llvm::DenseSet &forwardSlice, + llvm::function_ref filter = nullptr) { + SmallVector worklist({root}); - stk.emplace_back(startNode); - - while (!stk.empty()) { - auto &info = stk.back(); - Operation *currentNode = info.op; - - // Node is being visited for the first time. - if (info.getAndSetUnvisited()) - visited.insert(currentNode); - - if (info.userIter != info.userEnd) { - Operation *child = *info.userIter; - ++info.userIter; - if (!visited.contains(child)) - stk.emplace_back(child); + while (!worklist.empty()) { + Operation *op = worklist.pop_back_val(); - } else { // All children of the node have been visited - // Any op is reachable from itself. - reachableMuxes[currentNode].insert(currentNode); + if (!op) + continue; - for (auto *childOp : llvm::make_filter_range( - info.op->getUsers(), OpUserInfo::opAllowsReachability)) { - reachableMuxes[currentNode].insert(childOp); - // Propagate the reachability backwards from m to currentNode. - auto iter = reachableMuxes.find(childOp); - assert(iter != reachableMuxes.end()); + if (filter && !filter(op)) + continue; - // Add all the mux that was reachable from childOp, to currentNode. - reachableMuxes[currentNode].insert(iter->getSecond().begin(), - iter->getSecond().end()); - } - stk.pop_back(); - } + for (Region ®ion : op->getRegions()) + for (Block &block : region) + for (Operation &blockOp : block) + if (forwardSlice.insert(&blockOp).second) + worklist.push_back(&blockOp); + for (Value result : op->getResults()) + for (Operation *userOp : result.getUsers()) + if (forwardSlice.insert(userOp).second) + worklist.push_back(userOp); + + forwardSlice.insert(op); } } @@ -106,17 +70,6 @@ void FirRegLowering::addToIfBlock(OpBuilder &builder, Value cond, } } -FirRegLowering::FirRegLowering(TypeConverter &typeConverter, - hw::HWModuleOp module, - bool disableRegRandomization, - bool emitSeparateAlwaysBlocks) - : typeConverter(typeConverter), module(module), - disableRegRandomization(disableRegRandomization), - emitSeparateAlwaysBlocks(emitSeparateAlwaysBlocks) { - - reachableMuxes = std::make_unique(module); -} - void FirRegLowering::lower() { // Find all registers to lower in the module. auto regs = module.getOps(); @@ -405,6 +358,10 @@ void FirRegLowering::createTree(OpBuilder &builder, Value reg, Value term, // want to create if/else structure for logic unrelated to the register's // enable. auto firReg = term.getDefiningOp(); + DenseSet regMuxFanout; + getForwardSliceSimple(firReg, regMuxFanout, [&](Operation *op) { + return op == firReg || !isa(op); + }); SmallVector> worklist; auto addToWorklist = [&](Value reg, Value term, Value next) { @@ -432,8 +389,7 @@ void FirRegLowering::createTree(OpBuilder &builder, Value reg, Value term, // If this is a two-state mux within the fanout from the register, we use // if/else structure for proper enable inference. auto mux = next.getDefiningOp(); - if (mux && mux.getTwoState() && - reachableMuxes->isMuxReachableFrom(firReg, mux)) { + if (mux && mux.getTwoState() && regMuxFanout.contains(mux)) { addToIfBlock( builder, mux.getCond(), [&]() { addToWorklist(reg, term, mux.getTrueValue()); }, diff --git a/lib/Conversion/SeqToSV/FirRegLowering.h b/lib/Conversion/SeqToSV/FirRegLowering.h index a78c406e53b3..a65c00f2b11d 100644 --- a/lib/Conversion/SeqToSV/FirRegLowering.h +++ b/lib/Conversion/SeqToSV/FirRegLowering.h @@ -10,82 +10,26 @@ #ifndef CONVERSION_SEQTOSV_FIRREGLOWERING_H #define CONVERSION_SEQTOSV_FIRREGLOWERING_H -#include "circt/Dialect/Comb/CombOps.h" #include "circt/Dialect/HW/HWOps.h" #include "circt/Dialect/SV/SVOps.h" #include "circt/Dialect/Seq/SeqOps.h" #include "circt/Support/LLVM.h" #include "circt/Support/Namespace.h" #include "circt/Support/SymCache.h" -#include "mlir/IR/Visitors.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallPtrSet.h" -#include "llvm/ADT/SmallVector.h" -#include -#include -#include namespace circt { - -using namespace hw; -// This class computes the set of muxes that are reachable from an op. -// The heuristic propagates the reachability only through the 3 ops, mux, -// array_create and array_get. All other ops block the reachability. -// This analysis is built lazily on every query. -// The query: is a mux is reachable from a reg, results in a DFS traversal -// of the IR rooted at the register. This traversal is completed and the -// result is cached in a Map, for faster retrieval on any future query of any -// op in this subgraph. -class ReachableMuxes { -public: - ReachableMuxes(HWModuleOp m) : module(m) {} - - bool isMuxReachableFrom(seq::FirRegOp regOp, comb::MuxOp muxOp); - -private: - void buildReachabilityFrom(Operation *startNode); - HWModuleOp module; - llvm::DenseMap> reachableMuxes; - llvm::SmallPtrSet visited; -}; - -// The op and its users information that needs to be tracked on the stack -// for an iterative DFS traversal. -struct OpUserInfo { - Operation *op; - using ValidUsersIterator = - llvm::filter_iterator>; - - ValidUsersIterator userIter, userEnd; - static std::function opAllowsReachability; - - OpUserInfo(Operation *op) - : op(op), userIter(op->getUsers().begin(), op->getUsers().end(), - opAllowsReachability), - userEnd(op->getUsers().end(), op->getUsers().end(), - opAllowsReachability) {} - - bool getAndSetUnvisited() { - if (unvisited) { - unvisited = false; - return true; - } - return false; - } - -private: - bool unvisited = true; -}; - /// Lower FirRegOp to `sv.reg` and `sv.always`. class FirRegLowering { public: FirRegLowering(TypeConverter &typeConverter, hw::HWModuleOp module, bool disableRegRandomization = false, - bool emitSeparateAlwaysBlocks = false); + bool emitSeparateAlwaysBlocks = false) + : typeConverter(typeConverter), module(module), + disableRegRandomization(disableRegRandomization), + emitSeparateAlwaysBlocks(emitSeparateAlwaysBlocks){}; void lower(); + bool needsRegRandomization() const { return needsRandom; } unsigned numSubaccessRestored = 0; @@ -143,7 +87,6 @@ class FirRegLowering { llvm::SmallDenseMap constantCache; llvm::SmallDenseMap, Value> arrayIndexCache; - std::unique_ptr reachableMuxes; TypeConverter &typeConverter; hw::HWModuleOp module; diff --git a/test/Dialect/Seq/firreg.mlir b/test/Dialect/Seq/firreg.mlir index 16cfeb5f58b9..4557ff2507ae 100644 --- a/test/Dialect/Seq/firreg.mlir +++ b/test/Dialect/Seq/firreg.mlir @@ -290,12 +290,16 @@ hw.module private @InitReg1(in %clock: !seq.clock, in %reset: i1, in %io_d: i32, // COMMON-NEXT: %5 = comb.add %3, %4 : i33 // COMMON-NEXT: %6 = comb.extract %5 from 1 : (i33) -> i32 // COMMON-NEXT: %7 = comb.mux bin %io_en, %io_d, %6 : i32 - // COMMON-NEXT: sv.always posedge %clock, posedge %reset { + // COMMON-NEXT: sv.always posedge %clock, posedge %reset { // COMMON-NEXT: sv.if %reset { // COMMON-NEXT: sv.passign %reg, %c0_i32 : i32 // COMMON-NEXT: sv.passign %reg3, %c1_i32 : i32 // COMMON-NEXT: } else { - // COMMON-NEXT: sv.passign %reg, %7 : i32 + // COMMON-NEXT: sv.if %io_en { + // COMMON-NEXT: sv.passign %reg, %io_d : i32 + // COMMON-NEXT: } else { + // COMMON-NEXT: sv.passign %reg, %6 : i32 + // COMMON-NEXT: } // COMMON-NEXT: sv.passign %reg3, %2 : i32 // COMMON-NEXT: } // COMMON-NEXT: } @@ -911,20 +915,3 @@ hw.module @RegMuxInlining3(in %clock: !seq.clock, in %c: i1, out out: i8) { %0 = comb.mux bin %c, %r2, %r3 : i8 hw.output %r1 : i8 } - - // CHECK-LABEL: hw.module @SharedMux - hw.module @SharedMux(in %clock: !seq.clock, in %cond : i1, out o: i2){ - %mux = comb.mux bin %cond, %r1, %r2 : i2 - %r1 = seq.firreg %mux clock %clock : i2 - %r2 = seq.firreg %mux clock %clock : i2 - hw.output %r2: i2 - //CHECK: %r1 = sv.reg : !hw.inout - //CHECK: %[[V1:.+]] = sv.read_inout %r1 : !hw.inout - //CHECK: %r2 = sv.reg : !hw.inout - //CHECK: %[[V2:.+]] = sv.read_inout %r2 : !hw.inout - //CHECK: sv.always posedge %clock { - //CHECK: sv.if %cond { - //CHECK: sv.passign %r2, %[[V1]] : i2 - //CHECK: } else { - //CHECK: sv.passign %r1, %[[V2]] : i2 -}