Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
prithayan committed Apr 9, 2024
1 parent b9f26c7 commit 12a9689
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 39 deletions.
36 changes: 21 additions & 15 deletions lib/Conversion/SeqToSV/FirRegLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/Support/Debug.h"
#include <cassert>

using namespace circt;
using namespace hw;
Expand All @@ -20,10 +21,15 @@ using llvm::MapVector;

#define DEBUG_TYPE "lower-seq-firreg"

std::function<bool(const Operation *op)> OpUserInfo::opAllowsReachability =
[](const Operation *op) -> bool {
return (isa<comb::MuxOp, ArrayGetOp, ArrayCreateOp>(op));
};

bool ReachableMuxes::isMuxReachableFrom(seq::FirRegOp regOp,
comb::MuxOp muxOp) {
return llvm::any_of(regOp.getResult().getUsers(), [&](Operation *user) {
if (opBlocksReachability(user))
if (!OpUserInfo::opAllowsReachability(user))
return false;
buildReachabilityFrom(user);
return reachableMuxes[user].contains(muxOp);
Expand All @@ -33,16 +39,16 @@ bool ReachableMuxes::isMuxReachableFrom(seq::FirRegOp regOp,
void ReachableMuxes::buildReachabilityFrom(Operation *startNode) {
// This is a backward dataflow analysis.
// First build a graph rooted at the `startNode`. Every user of an operation
// that doesnot block the reachability is a child node. Then, the ops that
// that does not block the reachability is a child node. Then, the ops that
// are reachable from a node is computed as the union of the Reachability of
// all its child nodes.
// The dataflow can be expressed as, for all child in the Children(node)
// Reachability(node) = node + Union{Reachability(child)}
if (visited.find(startNode) != visited.end())
if (visited.contains(startNode))
return;

// The stack to record enough information for an iterative post-order
// traversal.

llvm::SmallVector<OpUserInfo> stk;

stk.emplace_back(startNode);
Expand All @@ -52,29 +58,29 @@ void ReachableMuxes::buildReachabilityFrom(Operation *startNode) {
Operation *currentNode = info.op;

// Node is being visited for the first time.
if (info.userIter == info.userRange.begin())
if (info.getAndSetUnvisited())
visited.insert(currentNode);
if (info.getNextValid(info.userIter)) {

if (info.userIter != info.userEnd) {
Operation *child = *info.userIter;
++info.userIter;
if (visited.find(child) == visited.end())
if (!visited.contains(child))
stk.emplace_back(child);

} else { // All children of the node have been visited
// Any op is reachable from itself.
reachableMuxes[currentNode].insert(currentNode);
auto userIterator = info.userRange.begin();
while (info.getNextValid(userIterator)) {
Operation *childOp = *userIterator;

for (auto *childOp : llvm::make_filter_range(
info.op->getUsers(), OpUserInfo::opAllowsReachability)) {
reachableMuxes[currentNode].insert(childOp);
// Propagate the reachability backwards from m to currentNode.
auto iter = reachableMuxes.find(childOp);
assert(iter != reachableMuxes.end());

if (iter != reachableMuxes.end())
reachableMuxes[currentNode].insert(iter->getSecond().begin(),
iter->getSecond().end());

++userIterator;
// Add all the mux that was reachable from childOp, to currentNode.
reachableMuxes[currentNode].insert(iter->getSecond().begin(),
iter->getSecond().end());
}
stk.pop_back();
}
Expand Down
55 changes: 31 additions & 24 deletions lib/Conversion/SeqToSV/FirRegLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "circt/Support/Namespace.h"
#include "circt/Support/SymCache.h"
#include "mlir/IR/Visitors.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include <mlir/IR/ValueRange.h>
Expand All @@ -32,43 +33,49 @@ using namespace hw;
// array_create and array_get. All other ops block the reachability.
// This analysis is built lazily on every query.
// The query: is a mux is reachable from a reg, results in a DFS traversal
// of the IR rooted at the register. This traversal is completed and the result
// is cached in a Map, for faster retrieval on any future query of any op in
// this subgraph.
// of the IR rooted at the register. This traversal is completed and the
// result is cached in a Map, for faster retrieval on any future query of any
// op in this subgraph.
class ReachableMuxes {
public:
ReachableMuxes(HWModuleOp m) : module(m) {}

bool isMuxReachableFrom(seq::FirRegOp regOp, comb::MuxOp muxOp);

private:
static inline bool opBlocksReachability(Operation *op) {
return (!isa<comb::MuxOp, ArrayGetOp, ArrayCreateOp>(op));
}
void buildReachabilityFrom(Operation *startNode);
HWModuleOp module;
llvm::DenseMap<Operation *, llvm::SmallDenseSet<Operation *>> reachableMuxes;
llvm::SmallPtrSet<Operation *, 16> visited;
};

// The op and its users information that needs to be tracked on the stack
// for an iterative DFS traversal.
struct OpUserInfo {
Operation *op;
const mlir::ResultRange::user_range userRange;
mlir::ResultRange::user_iterator userIter;

OpUserInfo(Operation *op)
: op(op), userRange(op->getUsers()), userIter(userRange.begin()) {}

// Increments the itertor to the next valid user op and returns false if
// the iterator reaches the end of the range.
auto getNextValid(mlir::ResultRange::user_iterator &iter) const {
for (; iter != userRange.end(); ++iter)
if (!opBlocksReachability(*iter))
return true;
return false;
// The op and its users information that needs to be tracked on the stack
// for an iterative DFS traversal.
struct OpUserInfo {
Operation *op;
using ValidUsersIterator =
llvm::filter_iterator<Operation::user_iterator,
std::function<bool(const Operation *)>>;

ValidUsersIterator userIter, userEnd;
static std::function<bool(const Operation *op)> opAllowsReachability;

OpUserInfo(Operation *op)
: op(op), userIter(op->getUsers().begin(), op->getUsers().end(),
opAllowsReachability),
userEnd(op->getUsers().end(), op->getUsers().end(),
opAllowsReachability) {}

bool getAndSetUnvisited() {
if (unvisited) {
unvisited = false;
return true;
}
};
return false;
}

private:
bool unvisited = true;
};

/// Lower FirRegOp to `sv.reg` and `sv.always`.
Expand Down

0 comments on commit 12a9689

Please sign in to comment.