Skip to content

Commit

Permalink
[LowerFirReg] Reimplement the mux reachability analysis (#6709)
Browse files Browse the repository at this point in the history
This PR implements a new heuristic to determine if a Mux is reachable from a
 FirReg. In general, an operation is reachable from a register if it is in the
 fanout of the register. For FirReg lowering, an if/else structure is required
 for proper enable inference if a mux is within the fanout from the register.
The fanout path can only consist of MuxOp, ArrayGetOp or ArrayCreateOp.
Thus any op other than MuxOp, ArrayGetOp or ArrayCreateOp blocks the
 reachability. The analysis is built lazily when it is queried, and the result
 is cached to avoid redundant traversal of the IR.
  • Loading branch information
prithayan committed Apr 9, 2024
1 parent fbef8b9 commit ac8f776
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 40 deletions.
102 changes: 73 additions & 29 deletions lib/Conversion/SeqToSV/FirRegLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/Support/Debug.h"
#include <cassert>

using namespace circt;
using namespace hw;
Expand All @@ -20,34 +21,69 @@ using llvm::MapVector;

#define DEBUG_TYPE "lower-seq-firreg"

// Reimplemented from SliceAnalysis to use a worklist rather than recursion and
// non-insert ordered set.
static void
getForwardSliceSimple(Operation *root,
llvm::DenseSet<Operation *> &forwardSlice,
llvm::function_ref<bool(Operation *)> filter = nullptr) {
SmallVector<Operation *> worklist({root});
// Predicate shared by the reachability analysis: reachability is propagated
// only through MuxOp, ArrayGetOp and ArrayCreateOp. Any other op yields false
// and therefore blocks the traversal.
std::function<bool(const Operation *op)> OpUserInfo::opAllowsReachability =
    [](const Operation *candidate) -> bool {
  return isa<comb::MuxOp>(candidate) || isa<ArrayGetOp>(candidate) ||
         isa<ArrayCreateOp>(candidate);
};

// Returns true if `muxOp` is reachable from `regOp`, i.e. if it is contained
// in the cached reachability set of at least one user of the register that
// allows reachability. The reachability information of each such user is
// built on demand before its set is consulted.
bool ReachableMuxes::isMuxReachableFrom(seq::FirRegOp regOp,
                                        comb::MuxOp muxOp) {
  for (Operation *regUser : regOp.getResult().getUsers()) {
    // A user that blocks reachability cannot lead to the mux.
    if (!OpUserInfo::opAllowsReachability(regUser))
      continue;

    // Compute the reachability of this user unless it was already visited.
    buildReachabilityFrom(regUser);
    if (reachableMuxes[regUser].contains(muxOp))
      return true;
  }
  return false;
}

while (!worklist.empty()) {
Operation *op = worklist.pop_back_val();
void ReachableMuxes::buildReachabilityFrom(Operation *startNode) {
// This is a backward dataflow analysis.
// First build a graph rooted at the `startNode`. Every user of an operation
// that does not block the reachability is a child node. Then, the set of ops
// reachable from a node is computed as the union of the reachability of
// all its child nodes.
// The dataflow can be expressed as, for every child in Children(node):
//   Reachability(node) = {node} + Union{Reachability(child)}
if (visited.contains(startNode))
return;

if (!op)
continue;
// The stack to record enough information for an iterative post-order
// traversal.
llvm::SmallVector<OpUserInfo> stk;

if (filter && !filter(op))
continue;
stk.emplace_back(startNode);

while (!stk.empty()) {
auto &info = stk.back();
Operation *currentNode = info.op;

// Node is being visited for the first time.
if (info.getAndSetUnvisited())
visited.insert(currentNode);

if (info.userIter != info.userEnd) {
Operation *child = *info.userIter;
++info.userIter;
if (!visited.contains(child))
stk.emplace_back(child);

for (Region &region : op->getRegions())
for (Block &block : region)
for (Operation &blockOp : block)
if (forwardSlice.insert(&blockOp).second)
worklist.push_back(&blockOp);
for (Value result : op->getResults())
for (Operation *userOp : result.getUsers())
if (forwardSlice.insert(userOp).second)
worklist.push_back(userOp);

forwardSlice.insert(op);
} else { // All children of the node have been visited
// Any op is reachable from itself.
reachableMuxes[currentNode].insert(currentNode);

for (auto *childOp : llvm::make_filter_range(
info.op->getUsers(), OpUserInfo::opAllowsReachability)) {
reachableMuxes[currentNode].insert(childOp);
// Propagate the reachability backwards from childOp to currentNode.
auto iter = reachableMuxes.find(childOp);
assert(iter != reachableMuxes.end());

// Add all the muxes that were reachable from childOp to currentNode.
reachableMuxes[currentNode].insert(iter->getSecond().begin(),
iter->getSecond().end());
}
stk.pop_back();
}
}
}

Expand All @@ -70,6 +106,17 @@ void FirRegLowering::addToIfBlock(OpBuilder &builder, Value cond,
}
}

// Constructs the lowering helper for `module`: records the type converter, the
// module and the two option flags, then allocates the ReachableMuxes analysis
// for that module.
FirRegLowering::FirRegLowering(TypeConverter &typeConverter,
hw::HWModuleOp module,
bool disableRegRandomization,
bool emitSeparateAlwaysBlocks)
: typeConverter(typeConverter), module(module),
disableRegRandomization(disableRegRandomization),
emitSeparateAlwaysBlocks(emitSeparateAlwaysBlocks) {

// Only the analysis object is created here; reachability itself is computed
// on demand when isMuxReachableFrom is queried.
reachableMuxes = std::make_unique<ReachableMuxes>(module);
}

void FirRegLowering::lower() {
// Find all registers to lower in the module.
auto regs = module.getOps<seq::FirRegOp>();
Expand Down Expand Up @@ -358,10 +405,6 @@ void FirRegLowering::createTree(OpBuilder &builder, Value reg, Value term,
// want to create if/else structure for logic unrelated to the register's
// enable.
auto firReg = term.getDefiningOp<seq::FirRegOp>();
DenseSet<Operation *> regMuxFanout;
getForwardSliceSimple(firReg, regMuxFanout, [&](Operation *op) {
return op == firReg || !isa<sv::RegOp, seq::FirRegOp, hw::InstanceOp>(op);
});

SmallVector<std::tuple<Block *, Value, Value, Value>> worklist;
auto addToWorklist = [&](Value reg, Value term, Value next) {
Expand Down Expand Up @@ -389,7 +432,8 @@ void FirRegLowering::createTree(OpBuilder &builder, Value reg, Value term,
// If this is a two-state mux within the fanout from the register, we use
// if/else structure for proper enable inference.
auto mux = next.getDefiningOp<comb::MuxOp>();
if (mux && mux.getTwoState() && regMuxFanout.contains(mux)) {
if (mux && mux.getTwoState() &&
reachableMuxes->isMuxReachableFrom(firReg, mux)) {
addToIfBlock(
builder, mux.getCond(),
[&]() { addToWorklist(reg, term, mux.getTrueValue()); },
Expand Down
67 changes: 62 additions & 5 deletions lib/Conversion/SeqToSV/FirRegLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,26 +10,82 @@
#ifndef CONVERSION_SEQTOSV_FIRREGLOWERING_H
#define CONVERSION_SEQTOSV_FIRREGLOWERING_H

#include "circt/Dialect/Comb/CombOps.h"
#include "circt/Dialect/HW/HWOps.h"
#include "circt/Dialect/SV/SVOps.h"
#include "circt/Dialect/Seq/SeqOps.h"
#include "circt/Support/LLVM.h"
#include "circt/Support/Namespace.h"
#include "circt/Support/SymCache.h"
#include "mlir/IR/Visitors.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include <mlir/IR/ValueRange.h>
#include <stack>
#include <unordered_set>

namespace circt {

using namespace hw;
// This class computes the set of muxes that are reachable from an op.
// The heuristic propagates the reachability only through three ops: mux,
// array_create and array_get. All other ops block the reachability.
// The analysis is built lazily, when it is queried.
// A query of whether a mux is reachable from a reg triggers a DFS traversal
// of the IR starting at the users of the register. The traversal is run to
// completion and the result is cached in a map, for faster retrieval on any
// future query of any op in this subgraph.
class ReachableMuxes {
public:
// Only records the module; no traversal is performed at construction time.
ReachableMuxes(HWModuleOp m) : module(m) {}

// Returns true if `muxOp` is in the cached reachability set of any user of
// `regOp` that allows reachability, building that set first if needed.
bool isMuxReachableFrom(seq::FirRegOp regOp, comb::MuxOp muxOp);

private:
// Populates `reachableMuxes` for `startNode` and the ops reachable from it.
// Returns immediately if `startNode` has already been visited.
void buildReachabilityFrom(Operation *startNode);
// The module this analysis was created for.
HWModuleOp module;
// Cache mapping an op to the set of ops reachable from it; the set always
// contains the op itself.
llvm::DenseMap<Operation *, llvm::SmallDenseSet<Operation *>> reachableMuxes;
// Ops whose traversal has already been started, so they are never re-pushed.
llvm::SmallPtrSet<Operation *, 16> visited;
};

// Bookkeeping for one entry of the explicit stack used by the iterative DFS
// traversal: the op being visited together with a cursor over those of its
// users that allow reachability.
struct OpUserInfo {
  // The op this stack entry describes.
  Operation *op;

  // Iterator over the users of an op that skips every user rejected by the
  // filtering predicate.
  using ValidUsersIterator =
      llvm::filter_iterator<Operation::user_iterator,
                            std::function<bool(const Operation *)>>;

  // `userIter` is the next user still to be handled; `userEnd` is the
  // past-the-end position of the same filtered range.
  ValidUsersIterator userIter, userEnd;

  // Predicate deciding whether reachability is propagated through an op.
  static std::function<bool(const Operation *op)> opAllowsReachability;

  // Positions the cursor at the first user of `op` accepted by
  // `opAllowsReachability`.
  OpUserInfo(Operation *op)
      : op(op), userIter(op->getUsers().begin(), op->getUsers().end(),
                         opAllowsReachability),
        userEnd(op->getUsers().end(), op->getUsers().end(),
                opAllowsReachability) {}

  // Returns true exactly once, on the first call for this entry, and clears
  // the flag so that every later call returns false.
  bool getAndSetUnvisited() {
    const bool firstVisit = unvisited;
    unvisited = false;
    return firstVisit;
  }

private:
  // True until getAndSetUnvisited() has been called for the first time.
  bool unvisited = true;
};

/// Lower FirRegOp to `sv.reg` and `sv.always`.
class FirRegLowering {
public:
FirRegLowering(TypeConverter &typeConverter, hw::HWModuleOp module,
bool disableRegRandomization = false,
bool emitSeparateAlwaysBlocks = false)
: typeConverter(typeConverter), module(module),
disableRegRandomization(disableRegRandomization),
emitSeparateAlwaysBlocks(emitSeparateAlwaysBlocks){};
bool emitSeparateAlwaysBlocks = false);

void lower();

bool needsRegRandomization() const { return needsRandom; }

unsigned numSubaccessRestored = 0;
Expand Down Expand Up @@ -87,6 +143,7 @@ class FirRegLowering {

llvm::SmallDenseMap<APInt, hw::ConstantOp> constantCache;
llvm::SmallDenseMap<std::pair<Value, unsigned>, Value> arrayIndexCache;
std::unique_ptr<ReachableMuxes> reachableMuxes;

TypeConverter &typeConverter;
hw::HWModuleOp module;
Expand Down
25 changes: 19 additions & 6 deletions test/Dialect/Seq/firreg.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -290,16 +290,12 @@ hw.module private @InitReg1(in %clock: !seq.clock, in %reset: i1, in %io_d: i32,
// COMMON-NEXT: %5 = comb.add %3, %4 : i33
// COMMON-NEXT: %6 = comb.extract %5 from 1 : (i33) -> i32
// COMMON-NEXT: %7 = comb.mux bin %io_en, %io_d, %6 : i32
// COMMON-NEXT: sv.always posedge %clock, posedge %reset {
// COMMON-NEXT: sv.always posedge %clock, posedge %reset {
// COMMON-NEXT: sv.if %reset {
// COMMON-NEXT: sv.passign %reg, %c0_i32 : i32
// COMMON-NEXT: sv.passign %reg3, %c1_i32 : i32
// COMMON-NEXT: } else {
// COMMON-NEXT: sv.if %io_en {
// COMMON-NEXT: sv.passign %reg, %io_d : i32
// COMMON-NEXT: } else {
// COMMON-NEXT: sv.passign %reg, %6 : i32
// COMMON-NEXT: }
// COMMON-NEXT: sv.passign %reg, %7 : i32
// COMMON-NEXT: sv.passign %reg3, %2 : i32
// COMMON-NEXT: }
// COMMON-NEXT: }
Expand Down Expand Up @@ -915,3 +911,20 @@ hw.module @RegMuxInlining3(in %clock: !seq.clock, in %c: i1, out out: i8) {
%0 = comb.mux bin %c, %r2, %r3 : i8
hw.output %r1 : i8
}

// Two registers (%r1 and %r2) use the same two-state mux as their next value.
// The expected output has one always block containing a single `sv.if` on
// %cond, with one register assignment in each branch.
// CHECK-LABEL: hw.module @SharedMux
hw.module @SharedMux(in %clock: !seq.clock, in %cond : i1, out o: i2){
%mux = comb.mux bin %cond, %r1, %r2 : i2
%r1 = seq.firreg %mux clock %clock : i2
%r2 = seq.firreg %mux clock %clock : i2
hw.output %r2: i2
//CHECK: %r1 = sv.reg : !hw.inout<i2>
//CHECK: %[[V1:.+]] = sv.read_inout %r1 : !hw.inout<i2>
//CHECK: %r2 = sv.reg : !hw.inout<i2>
//CHECK: %[[V2:.+]] = sv.read_inout %r2 : !hw.inout<i2>
//CHECK: sv.always posedge %clock {
//CHECK: sv.if %cond {
//CHECK: sv.passign %r2, %[[V1]] : i2
//CHECK: } else {
//CHECK: sv.passign %r1, %[[V2]] : i2
}

0 comments on commit ac8f776

Please sign in to comment.