Skip to content

Commit

Permalink
[HandshakeToFIRRTL] Lower handshake.instance operations (#2067)
Browse files Browse the repository at this point in the history
Since we're already relying on instantiating FIRRTL modules for the lowered handshake operations, instantiating an already lowered `handshake.func` module is straight-forward. The important part of this commit is that it ensures modules are lowered in post-order wrt. the instance graph (by iterating through the reversed topological order of the instance graph).
  • Loading branch information
mortbopet authored Nov 1, 2021
1 parent c01e972 commit f8b1161
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 24 deletions.
64 changes: 40 additions & 24 deletions lib/Conversion/HandshakeToFIRRTL/HandshakeToFIRRTL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,9 @@ static std::string getTypeName(Operation *oldOp, Type type) {

/// Construct a name for creating FIRRTL sub-module.
static std::string getSubModuleName(Operation *oldOp) {
if (auto instanceOp = dyn_cast<handshake::InstanceOp>(oldOp); instanceOp)
return instanceOp.getModule().str();

// The dialect name is separated from the operation name by '.', which is not
// valid in SystemVerilog module names. In case this name is used in
// SystemVerilog output, replace '.' with '_'.
Expand Down Expand Up @@ -527,17 +530,17 @@ static FModuleOp createTopModuleOp(handshake::FuncOp funcOp, unsigned numClocks,
// FIRRTL Sub-module Related Functions
//===----------------------------------------------------------------------===//

/// Check whether a submodule with the same name has been created elsewhere.
/// Return the matched submodule if true, otherwise return nullptr.
static FModuleOp checkSubModuleOp(FModuleOp topModuleOp, Operation *oldOp) {
for (auto &op : topModuleOp->getParentRegion()->front()) {
if (auto subModuleOp = dyn_cast<FModuleOp>(op)) {
if (getSubModuleName(oldOp) == subModuleOp.getName()) {
return subModuleOp;
}
}
}
return FModuleOp(nullptr);
/// Check whether a submodule with the same name has been created elsewhere in
/// the FIRRTL circt. Return the matched submodule if true, otherwise return
/// nullptr.
static FModuleOp checkSubModuleOp(CircuitOp circuitOp, Operation *oldOp) {
auto moduleOp = circuitOp.lookupSymbol<FModuleOp>(getSubModuleName(oldOp));

if (isa<handshake::InstanceOp>(oldOp))
assert(moduleOp &&
"handshake.instance target modules should always have been lowered "
"before the modules that reference them!");
return moduleOp;
}

/// All standard expressions and handshake elastic components will be converted
Expand Down Expand Up @@ -2241,7 +2244,7 @@ struct HandshakeFuncOpLowering : public OpConversionPattern<handshake::FuncOp> {
// This branch takes care of all non-timing operations that require to
// be instantiated in the top-module.
else if (op.getDialect()->getNamespace() != "firrtl") {
FModuleOp subModuleOp = checkSubModuleOp(topModuleOp, &op);
FModuleOp subModuleOp = checkSubModuleOp(circuitOp, &op);
bool hasClock = op.hasTrait<mlir::OpTrait::HasClock>();

// Check if the sub-module already exists.
Expand Down Expand Up @@ -2285,10 +2288,12 @@ using InstanceGraph = std::map<std::string, std::set<std::string>>;

/// Iterates over the handshake::FuncOp's in the program to build an instance
/// graph. In doing so, we detect whether there are any cycles in this graph, as
/// well as infer a top module for the design.
static LogicalResult resolveInstanceGraph(ModuleOp moduleOp,
InstanceGraph &instanceGraph,
std::string &topLevel) {
/// well as infer a top module for the design by performing a topological sort
/// of the instance graph. The result of this sort is placed in sortedFuncs.
static LogicalResult
resolveInstanceGraph(ModuleOp moduleOp, InstanceGraph &instanceGraph,
std::string &topLevel,
SmallVectorImpl<std::string> &sortedFuncs) {
// Create use graph
auto walkFuncOps = [&](handshake::FuncOp funcOp) {
auto &funcUses = instanceGraph[funcOp.getName().str()];
Expand All @@ -2302,7 +2307,7 @@ static LogicalResult resolveInstanceGraph(ModuleOp moduleOp,
// instances as candidate top level modules; these will be pruned whenever
// they are referenced by another module.
std::set<std::string> visited, marked, candidateTopLevel;
SmallVector<std::string> sorted, cycleTrace;
SmallVector<std::string> cycleTrace;
bool cyclic = false;
llvm::transform(instanceGraph,
std::inserter(candidateTopLevel, candidateTopLevel.begin()),
Expand All @@ -2324,7 +2329,7 @@ static LogicalResult resolveInstanceGraph(ModuleOp moduleOp,
}
marked.erase(node);
visited.insert(node);
sorted.insert(sorted.begin(), node);
sortedFuncs.insert(sortedFuncs.begin(), node);
};
for (auto it : instanceGraph) {
if (visited.count(it.first) == 0)
Expand Down Expand Up @@ -2368,7 +2373,8 @@ class HandshakeToFIRRTLPass
// Resolve the instance graph to get a top-level module.
std::string topLevel;
InstanceGraph uses;
if (resolveInstanceGraph(op, uses, topLevel).failed()) {
SmallVector<std::string> sortedFuncs;
if (resolveInstanceGraph(op, uses, topLevel, sortedFuncs).failed()) {
signalPassFailure();
return;
}
Expand All @@ -2383,11 +2389,21 @@ class HandshakeToFIRRTLPass
target.addLegalDialect<FIRRTLDialect>();
target.addIllegalDialect<handshake::HandshakeDialect>();

RewritePatternSet patterns(op.getContext());
patterns.insert<HandshakeFuncOpLowering>(op.getContext(), circuitOp);

if (failed(applyPartialConversion(op, target, std::move(patterns))))
signalPassFailure();
// Convert the handshake.func operations in post-order wrt. the instance
// graph. This ensures that any referenced submodules (through
// handshake.instance) has already been lowered, and their FIRRTL module
// equivalents are available.
for (auto funcName : llvm::reverse(sortedFuncs)) {
RewritePatternSet patterns(op.getContext());
patterns.insert<HandshakeFuncOpLowering>(op.getContext(), circuitOp);
auto funcOp = op.lookupSymbol(funcName);
assert(funcOp && "Symbol not found in module!");
if (failed(applyPartialConversion(funcOp, target, std::move(patterns)))) {
signalPassFailure();
funcOp->emitOpError() << "error during conversion";
return;
}
}
}
};
} // end anonymous namespace
Expand Down
56 changes: 56 additions & 0 deletions test/Conversion/HandshakeToFIRRTL/test_instance.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// RUN: circt-opt -lower-handshake-to-firrtl %s | FileCheck %s

// CHECK: firrtl.module @foo(in %[[VAL_0:.*]]: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>, in %[[VAL_1:.*]]: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>, out %[[VAL_2:.*]]: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>, out %[[VAL_3:.*]]: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>, in %[[VAL_4:.*]]: !firrtl.clock, in %[[VAL_5:.*]]: !firrtl.uint<1>) {
// CHECK: %[[VAL_6:.*]], %[[VAL_7:.*]], %[[VAL_8:.*]], %[[VAL_9:.*]], %[[VAL_10:.*]], %[[VAL_11:.*]] = firrtl.instance "" @bar(in arg0: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>, in arg1: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>, out arg2: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>, out arg3: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>, in clock: !firrtl.clock, in reset: !firrtl.uint<1>)
// CHECK: %[[VAL_12:.*]] = firrtl.instance "" @handshake_sink_1ins_0outs_ctrl(in arg0: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>)
// CHECK: firrtl.connect %[[VAL_6]], %[[VAL_0]] : !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>, !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>
// CHECK: firrtl.connect %[[VAL_7]], %[[VAL_1]] : !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>, !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>
// CHECK: firrtl.connect %[[VAL_10]], %[[VAL_4]] : !firrtl.clock, !firrtl.clock
// CHECK: firrtl.connect %[[VAL_11]], %[[VAL_5]] : !firrtl.uint<1>, !firrtl.uint<1>
// CHECK: firrtl.connect %[[VAL_12]], %[[VAL_9]] : !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>, !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>
// CHECK: firrtl.connect %[[VAL_2]], %[[VAL_8]] : !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>, !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>
// CHECK: firrtl.connect %[[VAL_3]], %[[VAL_1]] : !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>, !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>
// CHECK: }

// CHECK: firrtl.module @handshake_sink_1ins_0outs_ctrl(in %[[VAL_13:.*]]: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>) {

// CHECK: firrtl.module @bar(in %[[VAL_16:.*]]: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>, in %[[VAL_17:.*]]: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>, out %[[VAL_18:.*]]: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>, out %[[VAL_19:.*]]: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>, in %[[VAL_20:.*]]: !firrtl.clock, in %[[VAL_21:.*]]: !firrtl.uint<1>) {
// CHECK: %[[VAL_22:.*]], %[[VAL_23:.*]], %[[VAL_24:.*]], %[[VAL_25:.*]], %[[VAL_26:.*]], %[[VAL_27:.*]] = firrtl.instance "" @baz(in arg0: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>, in arg1: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>, out arg2: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>, out arg3: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>, in clock: !firrtl.clock, in reset: !firrtl.uint<1>)
// CHECK: %[[VAL_28:.*]] = firrtl.instance "" @handshake_sink_1ins_0outs_ctrl(in arg0: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>)
// CHECK: firrtl.connect %[[VAL_22]], %[[VAL_16]] : !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>, !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>
// CHECK: firrtl.connect %[[VAL_23]], %[[VAL_17]] : !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>, !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>
// CHECK: firrtl.connect %[[VAL_26]], %[[VAL_20]] : !firrtl.clock, !firrtl.clock
// CHECK: firrtl.connect %[[VAL_27]], %[[VAL_21]] : !firrtl.uint<1>, !firrtl.uint<1>
// CHECK: firrtl.connect %[[VAL_28]], %[[VAL_25]] : !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>, !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>
// CHECK: firrtl.connect %[[VAL_18]], %[[VAL_24]] : !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>, !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>
// CHECK: firrtl.connect %[[VAL_19]], %[[VAL_17]] : !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>, !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>
// CHECK: }

// CHECK: firrtl.module @arith_addi_in_ui32_ui32_out_ui32(in %[[VAL_29:.*]]: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>, in %[[VAL_30:.*]]: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>, out %[[VAL_31:.*]]: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>) {

// CHECK: firrtl.module @baz(in %[[VAL_45:.*]]: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>, in %[[VAL_46:.*]]: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>, out %[[VAL_47:.*]]: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>, out %[[VAL_48:.*]]: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>, in %[[VAL_49:.*]]: !firrtl.clock, in %[[VAL_50:.*]]: !firrtl.uint<1>) {
// CHECK: %[[VAL_51:.*]], %[[VAL_52:.*]], %[[VAL_53:.*]] = firrtl.instance "" @arith_addi_in_ui32_ui32_out_ui32(in arg0: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>, in arg1: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>, out arg2: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>)
// CHECK: firrtl.connect %[[VAL_51]], %[[VAL_45]] : !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>, !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>
// CHECK: firrtl.connect %[[VAL_52]], %[[VAL_45]] : !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>, !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>
// CHECK: firrtl.connect %[[VAL_47]], %[[VAL_53]] : !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>, !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>
// CHECK: firrtl.connect %[[VAL_48]], %[[VAL_46]] : !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>, !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>
// CHECK: }

module {
handshake.func @baz(%a: i32, %ctrl : none) -> (i32, none) {
%0 = arith.addi %a, %a : i32
handshake.return %0, %ctrl : i32, none
}

handshake.func @bar(%a: i32, %ctrl : none) -> (i32, none) {
%c:2 = handshake.instance @baz(%a, %ctrl) : (i32, none) -> (i32, none)
"handshake.sink"(%c#1) {control = true} : (none) -> ()
handshake.return %c#0, %ctrl : i32, none
}

handshake.func @foo(%a: i32, %ctrl : none) -> (i32, none) {
%b:2 = handshake.instance @bar(%a, %ctrl) : (i32, none) -> (i32, none)
"handshake.sink"(%b#1) {control = true} : (none) -> ()
handshake.return %b#0, %ctrl : i32, none
}
}

0 comments on commit f8b1161

Please sign in to comment.