Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[HandshakeToFIRRTL] Lower handshake.instance operations #2067

Merged
merged 2 commits into from
Nov 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 40 additions & 24 deletions lib/Conversion/HandshakeToFIRRTL/HandshakeToFIRRTL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,9 @@ static std::string getTypeName(Operation *oldOp, Type type) {

/// Construct a name for creating FIRRTL sub-module.
static std::string getSubModuleName(Operation *oldOp) {
if (auto instanceOp = dyn_cast<handshake::InstanceOp>(oldOp); instanceOp)
return instanceOp.getModule().str();

// The dialect name is separated from the operation name by '.', which is not
// valid in SystemVerilog module names. In case this name is used in
// SystemVerilog output, replace '.' with '_'.
Expand Down Expand Up @@ -527,17 +530,17 @@ static FModuleOp createTopModuleOp(handshake::FuncOp funcOp, unsigned numClocks,
// FIRRTL Sub-module Related Functions
//===----------------------------------------------------------------------===//

/// Check whether a submodule with the same name has been created elsewhere.
/// Return the matched submodule if true, otherwise return nullptr.
static FModuleOp checkSubModuleOp(FModuleOp topModuleOp, Operation *oldOp) {
for (auto &op : topModuleOp->getParentRegion()->front()) {
if (auto subModuleOp = dyn_cast<FModuleOp>(op)) {
if (getSubModuleName(oldOp) == subModuleOp.getName()) {
return subModuleOp;
}
}
}
return FModuleOp(nullptr);
/// Check whether a submodule with the same name has been created elsewhere in
/// the FIRRTL circt. Return the matched submodule if true, otherwise return
/// nullptr.
static FModuleOp checkSubModuleOp(CircuitOp circuitOp, Operation *oldOp) {
auto moduleOp = circuitOp.lookupSymbol<FModuleOp>(getSubModuleName(oldOp));

if (isa<handshake::InstanceOp>(oldOp))
assert(moduleOp &&
"handshake.instance target modules should always have been lowered "
"before the modules that reference them!");
return moduleOp;
}

/// All standard expressions and handshake elastic components will be converted
Expand Down Expand Up @@ -2241,7 +2244,7 @@ struct HandshakeFuncOpLowering : public OpConversionPattern<handshake::FuncOp> {
// This branch takes care of all non-timing operations that require to
// be instantiated in the top-module.
else if (op.getDialect()->getNamespace() != "firrtl") {
FModuleOp subModuleOp = checkSubModuleOp(topModuleOp, &op);
FModuleOp subModuleOp = checkSubModuleOp(circuitOp, &op);
bool hasClock = op.hasTrait<mlir::OpTrait::HasClock>();

// Check if the sub-module already exists.
Expand Down Expand Up @@ -2285,10 +2288,12 @@ using InstanceGraph = std::map<std::string, std::set<std::string>>;

/// Iterates over the handshake::FuncOp's in the program to build an instance
/// graph. In doing so, we detect whether there are any cycles in this graph, as
/// well as infer a top module for the design.
static LogicalResult resolveInstanceGraph(ModuleOp moduleOp,
InstanceGraph &instanceGraph,
std::string &topLevel) {
/// well as infer a top module for the design by performing a topological sort
/// of the instance graph. The result of this sort is placed in sortedFuncs.
static LogicalResult
resolveInstanceGraph(ModuleOp moduleOp, InstanceGraph &instanceGraph,
std::string &topLevel,
SmallVectorImpl<std::string> &sortedFuncs) {
// Create use graph
auto walkFuncOps = [&](handshake::FuncOp funcOp) {
auto &funcUses = instanceGraph[funcOp.getName().str()];
Expand All @@ -2302,7 +2307,7 @@ static LogicalResult resolveInstanceGraph(ModuleOp moduleOp,
// instances as candidate top level modules; these will be pruned whenever
// they are referenced by another module.
std::set<std::string> visited, marked, candidateTopLevel;
SmallVector<std::string> sorted, cycleTrace;
SmallVector<std::string> cycleTrace;
bool cyclic = false;
llvm::transform(instanceGraph,
std::inserter(candidateTopLevel, candidateTopLevel.begin()),
Expand All @@ -2324,7 +2329,7 @@ static LogicalResult resolveInstanceGraph(ModuleOp moduleOp,
}
marked.erase(node);
visited.insert(node);
sorted.insert(sorted.begin(), node);
sortedFuncs.insert(sortedFuncs.begin(), node);
};
for (auto it : instanceGraph) {
if (visited.count(it.first) == 0)
Expand Down Expand Up @@ -2368,7 +2373,8 @@ class HandshakeToFIRRTLPass
// Resolve the instance graph to get a top-level module.
std::string topLevel;
InstanceGraph uses;
if (resolveInstanceGraph(op, uses, topLevel).failed()) {
SmallVector<std::string> sortedFuncs;
if (resolveInstanceGraph(op, uses, topLevel, sortedFuncs).failed()) {
signalPassFailure();
return;
}
Expand All @@ -2383,11 +2389,21 @@ class HandshakeToFIRRTLPass
target.addLegalDialect<FIRRTLDialect>();
target.addIllegalDialect<handshake::HandshakeDialect>();

RewritePatternSet patterns(op.getContext());
patterns.insert<HandshakeFuncOpLowering>(op.getContext(), circuitOp);

if (failed(applyPartialConversion(op, target, std::move(patterns))))
signalPassFailure();
// Convert the handshake.func operations in post-order wrt. the instance
// graph. This ensures that any referenced submodules (through
// handshake.instance) has already been lowered, and their FIRRTL module
// equivalents are available.
for (auto funcName : llvm::reverse(sortedFuncs)) {
RewritePatternSet patterns(op.getContext());
patterns.insert<HandshakeFuncOpLowering>(op.getContext(), circuitOp);
auto funcOp = op.lookupSymbol(funcName);
assert(funcOp && "Symbol not found in module!");
if (failed(applyPartialConversion(funcOp, target, std::move(patterns)))) {
signalPassFailure();
funcOp->emitOpError() << "error during conversion";
return;
}
}
}
};
} // end anonymous namespace
Expand Down
56 changes: 56 additions & 0 deletions test/Conversion/HandshakeToFIRRTL/test_instance.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// RUN: circt-opt -lower-handshake-to-firrtl %s | FileCheck %s

// CHECK: firrtl.module @foo(in %[[VAL_0:.*]]: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>, in %[[VAL_1:.*]]: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>, out %[[VAL_2:.*]]: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>, out %[[VAL_3:.*]]: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>, in %[[VAL_4:.*]]: !firrtl.clock, in %[[VAL_5:.*]]: !firrtl.uint<1>) {
// CHECK: %[[VAL_6:.*]], %[[VAL_7:.*]], %[[VAL_8:.*]], %[[VAL_9:.*]], %[[VAL_10:.*]], %[[VAL_11:.*]] = firrtl.instance "" @bar(in arg0: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>, in arg1: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>, out arg2: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>, out arg3: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>, in clock: !firrtl.clock, in reset: !firrtl.uint<1>)
// CHECK: %[[VAL_12:.*]] = firrtl.instance "" @handshake_sink_1ins_0outs_ctrl(in arg0: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>)
// CHECK: firrtl.connect %[[VAL_6]], %[[VAL_0]] : !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>, !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>
// CHECK: firrtl.connect %[[VAL_7]], %[[VAL_1]] : !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>, !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>
// CHECK: firrtl.connect %[[VAL_10]], %[[VAL_4]] : !firrtl.clock, !firrtl.clock
// CHECK: firrtl.connect %[[VAL_11]], %[[VAL_5]] : !firrtl.uint<1>, !firrtl.uint<1>
// CHECK: firrtl.connect %[[VAL_12]], %[[VAL_9]] : !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>, !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>
// CHECK: firrtl.connect %[[VAL_2]], %[[VAL_8]] : !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>, !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>
// CHECK: firrtl.connect %[[VAL_3]], %[[VAL_1]] : !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>, !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>
// CHECK: }

// CHECK: firrtl.module @handshake_sink_1ins_0outs_ctrl(in %[[VAL_13:.*]]: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>) {

// CHECK: firrtl.module @bar(in %[[VAL_16:.*]]: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>, in %[[VAL_17:.*]]: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>, out %[[VAL_18:.*]]: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>, out %[[VAL_19:.*]]: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>, in %[[VAL_20:.*]]: !firrtl.clock, in %[[VAL_21:.*]]: !firrtl.uint<1>) {
// CHECK: %[[VAL_22:.*]], %[[VAL_23:.*]], %[[VAL_24:.*]], %[[VAL_25:.*]], %[[VAL_26:.*]], %[[VAL_27:.*]] = firrtl.instance "" @baz(in arg0: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>, in arg1: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>, out arg2: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>, out arg3: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>, in clock: !firrtl.clock, in reset: !firrtl.uint<1>)
// CHECK: %[[VAL_28:.*]] = firrtl.instance "" @handshake_sink_1ins_0outs_ctrl(in arg0: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>)
// CHECK: firrtl.connect %[[VAL_22]], %[[VAL_16]] : !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>, !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>
// CHECK: firrtl.connect %[[VAL_23]], %[[VAL_17]] : !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>, !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>
// CHECK: firrtl.connect %[[VAL_26]], %[[VAL_20]] : !firrtl.clock, !firrtl.clock
// CHECK: firrtl.connect %[[VAL_27]], %[[VAL_21]] : !firrtl.uint<1>, !firrtl.uint<1>
// CHECK: firrtl.connect %[[VAL_28]], %[[VAL_25]] : !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>, !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>
// CHECK: firrtl.connect %[[VAL_18]], %[[VAL_24]] : !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>, !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>
// CHECK: firrtl.connect %[[VAL_19]], %[[VAL_17]] : !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>, !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>
// CHECK: }

// CHECK: firrtl.module @arith_addi_in_ui32_ui32_out_ui32(in %[[VAL_29:.*]]: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>, in %[[VAL_30:.*]]: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>, out %[[VAL_31:.*]]: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>) {

// CHECK: firrtl.module @baz(in %[[VAL_45:.*]]: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>, in %[[VAL_46:.*]]: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>, out %[[VAL_47:.*]]: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>, out %[[VAL_48:.*]]: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>, in %[[VAL_49:.*]]: !firrtl.clock, in %[[VAL_50:.*]]: !firrtl.uint<1>) {
// CHECK: %[[VAL_51:.*]], %[[VAL_52:.*]], %[[VAL_53:.*]] = firrtl.instance "" @arith_addi_in_ui32_ui32_out_ui32(in arg0: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>, in arg1: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>, out arg2: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>)
// CHECK: firrtl.connect %[[VAL_51]], %[[VAL_45]] : !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>, !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>
// CHECK: firrtl.connect %[[VAL_52]], %[[VAL_45]] : !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>, !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>
// CHECK: firrtl.connect %[[VAL_47]], %[[VAL_53]] : !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>, !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>
// CHECK: firrtl.connect %[[VAL_48]], %[[VAL_46]] : !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>, !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>
// CHECK: }

module {
handshake.func @baz(%a: i32, %ctrl : none) -> (i32, none) {
%0 = arith.addi %a, %a : i32
handshake.return %0, %ctrl : i32, none
}

handshake.func @bar(%a: i32, %ctrl : none) -> (i32, none) {
%c:2 = handshake.instance @baz(%a, %ctrl) : (i32, none) -> (i32, none)
"handshake.sink"(%c#1) {control = true} : (none) -> ()
handshake.return %c#0, %ctrl : i32, none
}

handshake.func @foo(%a: i32, %ctrl : none) -> (i32, none) {
%b:2 = handshake.instance @bar(%a, %ctrl) : (i32, none) -> (i32, none)
"handshake.sink"(%b#1) {control = true} : (none) -> ()
handshake.return %b#0, %ctrl : i32, none
}
}