Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[HandshakeToFIRRTL] Add top module inference and instance cycle detection #2056

Merged
merged 4 commits into from
Nov 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 101 additions & 4 deletions lib/Conversion/HandshakeToFIRRTL/HandshakeToFIRRTL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/MathExtras.h"

#include <set>

using namespace mlir;
using namespace circt;
using namespace circt::handshake;
Expand Down Expand Up @@ -2208,13 +2210,12 @@ static void convertReturnOp(Operation *oldOp, FModuleOp topModuleOp,
/// Please refer to test_addi.mlir test case.
struct HandshakeFuncOpLowering : public OpConversionPattern<handshake::FuncOp> {
using OpConversionPattern<handshake::FuncOp>::OpConversionPattern;
HandshakeFuncOpLowering(MLIRContext *context, CircuitOp circuitOp)
: OpConversionPattern<handshake::FuncOp>(context), circuitOp(circuitOp) {}

LogicalResult
matchAndRewrite(handshake::FuncOp funcOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
// Create FIRRTL circuit and top-module operation.
auto circuitOp = rewriter.create<CircuitOp>(
funcOp.getLoc(), rewriter.getStringAttr(funcOp.getName()));
rewriter.setInsertionPointToStart(circuitOp.getBody());
auto topModuleOp = createTopModuleOp(funcOp, /*numClocks=*/1, rewriter);

Expand Down Expand Up @@ -2259,21 +2260,117 @@ struct HandshakeFuncOpLowering : public OpConversionPattern<handshake::FuncOp> {

return success();
}

private:
/// Top level FIRRTL circuit operation, which we'll emit into. Marked as
/// mutable due to circuitOp.getBody() being non-const.
mutable CircuitOp circuitOp;
};

using InstanceGraph = std::map<std::string, std::set<std::string>>;

/// Iterates over the handshake::FuncOp's in the program to build an instance
/// graph. In doing so, we detect whether there are any cycles in this graph, as
/// well as infer a top module for the design.
static LogicalResult resolveInstanceGraph(ModuleOp moduleOp,
InstanceGraph &instanceGraph,
std::string &topLevel) {
// Create use graph
auto walkFuncOps = [&](handshake::FuncOp funcOp) {
auto &funcUses = instanceGraph[funcOp.getName().str()];
funcOp.walk([&](handshake::InstanceOp instanceOp) {
funcUses.insert(instanceOp.getModule().str());
});
};
moduleOp.walk(walkFuncOps);

// find top-level (and cycles) using a topological sort. Initialize all
// instances as candidate top level modules; these will be pruned whenever
// they are referenced by another module.
std::set<std::string> visited, marked, candidateTopLevel;
SmallVector<std::string> sorted, cycleTrace;
bool cyclic = false;
llvm::transform(instanceGraph,
std::inserter(candidateTopLevel, candidateTopLevel.begin()),
[](auto it) { return it.first; });
std::function<void(const std::string &, SmallVector<std::string>)> cycleUtil =
[&](const std::string &node, SmallVector<std::string> trace) {
if (cyclic || visited.count(node))
return;
trace.push_back(node);
if (marked.count(node)) {
cyclic = true;
cycleTrace = trace;
return;
}
marked.insert(node);
for (auto use : instanceGraph[node]) {
candidateTopLevel.erase(use);
cycleUtil(use, trace);
}
marked.erase(node);
visited.insert(node);
sorted.insert(sorted.begin(), node);
};
for (auto it : instanceGraph) {
if (visited.count(it.first) == 0)
cycleUtil(it.first, {});
if (cyclic)
break;
}

if (cyclic) {
auto err = moduleOp.emitOpError();
err << "cannot lower handshake program - cycle "
"detected in instance graph (";
llvm::interleave(
cycleTrace, err, [&](auto node) { err << node; }, "->");
err << ").";
return err;
}
assert(!candidateTopLevel.empty() &&
"if non-cyclic, there should be at least 1 candidate top level");

if (candidateTopLevel.size() > 1) {
auto err = moduleOp.emitOpError();
err << "multiple candidate top-level modules detected (";
llvm::interleaveComma(candidateTopLevel, err,
[&](auto topLevel) { err << topLevel; });
err << "). Please remove one of these from the source program.";
return err;
}
topLevel = *candidateTopLevel.begin();
return success();
}

namespace {
class HandshakeToFIRRTLPass
: public HandshakeToFIRRTLBase<HandshakeToFIRRTLPass> {
public:
void runOnOperation() override {
auto op = getOperation();
auto *ctx = op.getContext();

// Resolve the instance graph to get a top-level module.
std::string topLevel;
InstanceGraph uses;
if (resolveInstanceGraph(op, uses, topLevel).failed()) {
signalPassFailure();
return;
}

// Create FIRRTL circuit op.
OpBuilder builder(ctx);
builder.setInsertionPointToStart(op.getBody());
auto circuitOp =
builder.create<CircuitOp>(op.getLoc(), builder.getStringAttr(topLevel));

ConversionTarget target(getContext());
target.addLegalDialect<FIRRTLDialect>();
target.addIllegalDialect<handshake::HandshakeDialect>();

RewritePatternSet patterns(op.getContext());
patterns.insert<HandshakeFuncOpLowering>(op.getContext());
patterns.insert<HandshakeFuncOpLowering>(op.getContext(), circuitOp);

if (failed(applyPartialConversion(op, target, std::move(patterns))))
signalPassFailure();
Expand Down
40 changes: 40 additions & 0 deletions test/Conversion/HandshakeToFIRRTL/errors.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// RUN: circt-opt -lower-handshake-to-firrtl -verify-diagnostics -split-input-file %s

// Test cycle through a component
// expected-error @+1 {{'builtin.module' op cannot lower handshake program - cycle detected in instance graph (bar->baz->foo->bar).}}
module {
handshake.func @bar(%ctrl : none) -> (none) {
%0 = handshake.instance @baz(%ctrl) : (none) -> (none)
handshake.return %0: none
}

handshake.func @foo(%ctrl : none) -> (none) {
%0 = handshake.instance @bar(%ctrl) : (none) -> (none)
handshake.return %0: none
}

handshake.func @baz(%ctrl : none) -> (none) {
%0 = handshake.instance @foo(%ctrl) : (none) -> (none)
handshake.return %0: none
}
}

// -----

// test multiple candidate top components
// expected-error @+1 {{'builtin.module' op multiple candidate top-level modules detected (bar, foo). Please remove one of these from the source program.}}
module {
handshake.func @bar(%ctrl : none) -> (none) {
%0 = handshake.instance @baz(%ctrl) : (none) -> (none)
handshake.return %0: none
}

handshake.func @foo(%ctrl : none) -> (none) {
%0 = handshake.instance @baz(%ctrl) : (none) -> (none)
handshake.return %0: none
}

handshake.func @baz(%ctrl : none) -> (none) {
handshake.return %ctrl: none
}
}