From 3c12682a1aff12249d89fb01cfa0554c0ea28f1b Mon Sep 17 00:00:00 2001 From: Jiahan Xie <88367305+jiahanxie353@users.noreply.github.com> Date: Wed, 31 Jul 2024 11:47:59 -0400 Subject: [PATCH] Support `scf.if` Op Lowering to Calyx (#6256) * support lowering scf if op and add a corresponding test --- lib/Conversion/SCFToCalyx/SCFToCalyx.cpp | 234 ++++++++++++++++-- .../SCFToCalyx/convert_controlflow.mlir | 69 ++++++ 2 files changed, 284 insertions(+), 19 deletions(-) diff --git a/lib/Conversion/SCFToCalyx/SCFToCalyx.cpp b/lib/Conversion/SCFToCalyx/SCFToCalyx.cpp index c780492bc36a..7411c3a4dd3b 100644 --- a/lib/Conversion/SCFToCalyx/SCFToCalyx.cpp +++ b/lib/Conversion/SCFToCalyx/SCFToCalyx.cpp @@ -95,6 +95,10 @@ class ScfForOp : public calyx::RepeatOpInterface { // Lowering state classes //===----------------------------------------------------------------------===// +struct IfScheduleable { + scf::IfOp ifOp; +}; + struct WhileScheduleable { /// While operation to schedule. ScfWhileOp whileOp; @@ -115,8 +119,63 @@ struct CallScheduleable { }; /// A variant of types representing scheduleable operations. -using Scheduleable = std::variant; +using Scheduleable = + std::variant; + +class IfLoweringStateInterface { +public: + void setThenGroup(scf::IfOp op, calyx::GroupOp group) { + Operation *operation = op.getOperation(); + assert(thenGroup.count(operation) == 0 && + "A then group was already set for this scf::IfOp!\n"); + thenGroup[operation] = group; + } + + calyx::GroupOp getThenGroup(scf::IfOp op) { + auto it = thenGroup.find(op.getOperation()); + assert(it != thenGroup.end() && + "No then group was set for this scf::IfOp!\n"); + return it->second; + } + + void setElseGroup(scf::IfOp op, calyx::GroupOp group) { + Operation *operation = op.getOperation(); + assert(elseGroup.count(operation) == 0 && + "An else group was already set for this scf::IfOp!\n"); + elseGroup[operation] = group; + } + + calyx::GroupOp getElseGroup(scf::IfOp op) { + auto it = elseGroup.find(op.getOperation()); + assert(it != elseGroup.end() && + "No else group was set for this scf::IfOp!\n"); + return it->second; + } + + void setResultRegs(scf::IfOp op, calyx::RegisterOp reg, unsigned idx) { + assert(resultRegs[op.getOperation()].count(idx) == 0 && + "A register was already registered for the given yield result.\n"); + assert(idx < op->getNumOperands()); + resultRegs[op.getOperation()][idx] = reg; + } + + const DenseMap &getResultRegs(scf::IfOp op) { + return resultRegs[op.getOperation()]; + } + + calyx::RegisterOp getResultRegs(scf::IfOp op, unsigned idx) { + auto regs = getResultRegs(op); + auto it = regs.find(idx); + assert(it != regs.end() && "resultReg not found"); + return it->second; + } + +private: + DenseMap thenGroup; + DenseMap elseGroup; + DenseMap> resultRegs; +}; class WhileLoopLoweringStateInterface : calyx::LoopLoweringStateInterface { @@ -187,6 +246,7 @@ class ForLoopLoweringStateInterface class ComponentLoweringState : public calyx::ComponentLoweringStateInterface, public WhileLoopLoweringStateInterface, public ForLoopLoweringStateInterface, + public IfLoweringStateInterface, public calyx::SchedulerInterface { public: ComponentLoweringState(calyx::ComponentOp component) @@ -213,7 +273,7 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern { TypeSwitch(_op) .template Case(yieldOp->getParentOp()); - if (!whileOp) { - return yieldOp.getOperation()->emitError() - << "Currently only support yield operations inside for and while " - "loops."; - } - ScfWhileOp whileOpInterface(whileOp); - - auto assignGroup = - getState().buildWhileLoopIterArgAssignments( - rewriter, whileOpInterface, - getState().getComponentOp(), - getState().getUniqueName(whileOp) + "_latch", - yieldOp->getOpOperands()); - getState().setWhileLoopLatchGroup(whileOpInterface, - assignGroup); + if (auto whileOp = dyn_cast(yieldOp->getParentOp())) { + ScfWhileOp whileOpInterface(whileOp); + + auto assignGroup = + getState().buildWhileLoopIterArgAssignments( + rewriter, whileOpInterface, + getState().getComponentOp(), + getState().getUniqueName(whileOp) + + "_latch", + yieldOp->getOpOperands()); + getState().setWhileLoopLatchGroup(whileOpInterface, + assignGroup); + return success(); + } + + if (auto ifOp = dyn_cast(yieldOp->getParentOp())) { + auto resultRegs = getState().getResultRegs(ifOp); + + if (yieldOp->getParentRegion() == &ifOp.getThenRegion()) { + auto thenGroup = getState().getThenGroup(ifOp); + for (auto op : enumerate(yieldOp.getOperands())) { + auto resultReg = + getState().getResultRegs(ifOp, op.index()); + buildAssignmentsForRegisterWrite( + rewriter, thenGroup, + getState().getComponentOp(), resultReg, + op.value()); + getState().registerEvaluatingGroup( + ifOp.getResult(op.index()), thenGroup); + } + } + + if (!ifOp.getElseRegion().empty() && + (yieldOp->getParentRegion() == &ifOp.getElseRegion())) { + auto elseGroup = getState().getElseGroup(ifOp); + for (auto op : enumerate(yieldOp.getOperands())) { + auto resultReg = + getState().getResultRegs(ifOp, op.index()); + buildAssignmentsForRegisterWrite( + rewriter, elseGroup, + getState().getComponentOp(), resultReg, + op.value()); + getState().registerEvaluatingGroup( + ifOp.getResult(op.index()), elseGroup); + } + } + } return success(); } @@ -945,6 +1037,13 @@ LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter, return success(); } +LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter, + scf::IfOp ifOp) const { + getState().addBlockScheduleable( + ifOp.getOperation()->getBlock(), IfScheduleable{ifOp}); + return success(); +} + LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter, CallOp callOp) const { std::string instanceName = calyx::getInstanceName(callOp); @@ -1291,6 +1390,51 @@ class BuildForGroups : public calyx::FuncOpPartialLoweringPattern { } }; +class BuildIfGroups : public calyx::FuncOpPartialLoweringPattern { + using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern; + + LogicalResult + partiallyLowerFuncToComp(FuncOp funcOp, + PatternRewriter &rewriter) const override { + LogicalResult res = success(); + funcOp.walk([&](Operation *op) { + if (!isa(op)) + return WalkResult::advance(); + + auto scfIfOp = cast(op); + + calyx::ComponentOp componentOp = + getState().getComponentOp(); + + std::string thenGroupName = + getState().getUniqueName("then_br"); + auto thenGroupOp = calyx::createGroup( + rewriter, componentOp, scfIfOp.getLoc(), thenGroupName); + getState().setThenGroup(scfIfOp, thenGroupOp); + + if (!scfIfOp.getElseRegion().empty()) { + std::string elseGroupName = + getState().getUniqueName("else_br"); + auto elseGroupOp = calyx::createGroup( + rewriter, componentOp, scfIfOp.getLoc(), elseGroupName); + getState().setElseGroup(scfIfOp, elseGroupOp); + } + + for (auto ifOpRes : scfIfOp.getResults()) { + auto reg = createRegister( + scfIfOp.getLoc(), rewriter, getComponent(), + ifOpRes.getType().getIntOrFloatBitWidth(), + getState().getUniqueName("if_res")); + getState().setResultRegs( + scfIfOp, reg, ifOpRes.getResultNumber()); + } + + return WalkResult::advance(); + }); + return res; + } +}; + /// Builds a control schedule by traversing the CFG of the function and /// associating this with the previously created groups. /// For simplicity, the generated control flow is expanded for all possible @@ -1384,6 +1528,50 @@ class BuildControl : public calyx::FuncOpPartialLoweringPattern { forLatchGroup.getName()); if (res.failed()) return res; + } else if (auto *ifSchedPtr = std::get_if(&group); + ifSchedPtr) { + auto ifOp = ifSchedPtr->ifOp; + + Location loc = ifOp->getLoc(); + + auto cond = ifOp.getCondition(); + auto condGroup = getState() + .getEvaluatingGroup(cond); + + auto symbolAttr = FlatSymbolRefAttr::get( + StringAttr::get(getContext(), condGroup.getSymName())); + + bool initElse = !ifOp.getElseRegion().empty(); + auto ifCtrlOp = rewriter.create( + loc, cond, symbolAttr, /*initializeElseBody=*/initElse); + + rewriter.setInsertionPointToEnd(ifCtrlOp.getBodyBlock()); + + auto thenSeqOp = + rewriter.create(ifOp.getThenRegion().getLoc()); + auto *thenSeqOpBlock = thenSeqOp.getBodyBlock(); + + rewriter.setInsertionPointToEnd(thenSeqOpBlock); + + calyx::GroupOp thenGroup = + getState().getThenGroup(ifOp); + rewriter.create(thenGroup.getLoc(), + thenGroup.getName()); + + if (!ifOp.getElseRegion().empty()) { + rewriter.setInsertionPointToEnd(ifCtrlOp.getElseBody()); + + auto elseSeqOp = + rewriter.create(ifOp.getElseRegion().getLoc()); + auto *elseSeqOpBlock = elseSeqOp.getBodyBlock(); + + rewriter.setInsertionPointToEnd(elseSeqOpBlock); + + calyx::GroupOp elseGroup = + getState().getElseGroup(ifOp); + rewriter.create(elseGroup.getLoc(), + elseGroup.getName()); + } } else if (auto *callSchedPtr = std::get_if(&group)) { auto instanceOp = callSchedPtr->instanceOp; OpBuilder::InsertionGuard g(rewriter); @@ -1540,6 +1728,12 @@ class LateSSAReplacement : public calyx::FuncOpPartialLoweringPattern { LogicalResult partiallyLowerFuncToComp(FuncOp funcOp, PatternRewriter &) const override { + funcOp.walk([&](scf::IfOp op) { + for (auto res : getState().getResultRegs(op)) + op.getOperation()->getResults()[res.first].replaceAllUsesWith( + res.second.getOut()); + }); + funcOp.walk([&](scf::WhileOp op) { /// The yielded values returned from the while op will be present in the /// iterargs registers post execution of the loop. @@ -1790,6 +1984,8 @@ void SCFToCalyxPass::runOnOperation() { addOncePattern(loweringPatterns, patternState, funcMap, *loweringState); + addOncePattern(loweringPatterns, patternState, funcMap, + *loweringState); /// This pattern converts operations within basic blocks to Calyx library /// operators. Combinational operations are assigned inside a /// calyx::CombGroupOp, and sequential inside calyx::GroupOps. diff --git a/test/Conversion/SCFToCalyx/convert_controlflow.mlir b/test/Conversion/SCFToCalyx/convert_controlflow.mlir index 09e7ef4214fb..d4a87139f621 100644 --- a/test/Conversion/SCFToCalyx/convert_controlflow.mlir +++ b/test/Conversion/SCFToCalyx/convert_controlflow.mlir @@ -572,3 +572,72 @@ module { return } } + +// ----- + +// Test if op with else branch. + +module { +// CHECK-LABEL: calyx.component @main( +// CHECK-SAME: %[[VAL_0:in0]]: i32, +// CHECK-SAME: %[[VAL_1:in1]]: i32, +// CHECK-SAME: %[[VAL_2:.*]]: i1 {clk}, +// CHECK-SAME: %[[VAL_3:.*]]: i1 {reset}, +// CHECK-SAME: %[[VAL_4:.*]]: i1 {go}) -> ( +// CHECK-SAME: %[[VAL_5:out0]]: i32, +// CHECK-SAME: %[[VAL_6:.*]]: i1 {done}) { +// CHECK: %[[VAL_7:.*]] = hw.constant true +// CHECK: %[[VAL_8:.*]], %[[VAL_9:.*]], %[[VAL_10:.*]] = calyx.std_add @std_add_0 : i32, i32, i32 +// CHECK: %[[VAL_11:.*]], %[[VAL_12:.*]], %[[VAL_13:.*]] = calyx.std_slt @std_slt_0 : i32, i32, i1 +// CHECK: %[[VAL_14:.*]], %[[VAL_15:.*]], %[[VAL_16:.*]], %[[VAL_17:.*]], %[[VAL_18:.*]], %[[VAL_19:.*]] = calyx.register @if_res_0_reg : i32, i1, i1, i1, i32, i1 +// CHECK: %[[VAL_20:.*]], %[[VAL_21:.*]], %[[VAL_22:.*]], %[[VAL_23:.*]], %[[VAL_24:.*]], %[[VAL_25:.*]] = calyx.register @ret_arg0_reg : i32, i1, i1, i1, i32, i1 +// CHECK: calyx.wires { +// CHECK: calyx.assign %[[VAL_5]] = %[[VAL_24]] : i32 +// CHECK: calyx.group @then_br_0 { +// CHECK: calyx.assign %[[VAL_14]] = %[[VAL_10]] : i32 +// CHECK: calyx.assign %[[VAL_15]] = %[[VAL_7]] : i1 +// CHECK: calyx.assign %[[VAL_8]] = %[[VAL_0]] : i32 +// CHECK: calyx.assign %[[VAL_9]] = %[[VAL_1]] : i32 +// CHECK: calyx.group_done %[[VAL_19]] : i1 +// CHECK: } +// CHECK: calyx.group @else_br_0 { +// CHECK: calyx.assign %[[VAL_14]] = %[[VAL_1]] : i32 +// CHECK: calyx.assign %[[VAL_15]] = %[[VAL_7]] : i1 +// CHECK: calyx.group_done %[[VAL_19]] : i1 +// CHECK: } +// CHECK: calyx.comb_group @bb0_0 { +// CHECK: calyx.assign %[[VAL_11]] = %[[VAL_0]] : i32 +// CHECK: calyx.assign %[[VAL_12]] = %[[VAL_1]] : i32 +// CHECK: } +// CHECK: calyx.group @ret_assign_0 { +// CHECK: calyx.assign %[[VAL_20]] = %[[VAL_18]] : i32 +// CHECK: calyx.assign %[[VAL_21]] = %[[VAL_7]] : i1 +// CHECK: calyx.group_done %[[VAL_25]] : i1 +// CHECK: } +// CHECK: } +// CHECK: calyx.control { +// CHECK: calyx.seq { +// CHECK: calyx.if %[[VAL_13]] with @bb0_0 { +// CHECK: calyx.seq { +// CHECK: calyx.enable @then_br_0 +// CHECK: } +// CHECK: } else { +// CHECK: calyx.seq { +// CHECK: calyx.enable @else_br_0 +// CHECK: } +// CHECK: } +// CHECK: calyx.enable @ret_assign_0 +// CHECK: } +// CHECK: } +// CHECK: } {toplevel} + func.func @main(%arg0 : i32, %arg1 : i32) -> i32 { + %0 = arith.cmpi slt, %arg0, %arg1 : i32 + %1 = scf.if %0 -> i32 { + %3 = arith.addi %arg0, %arg1 : i32 + scf.yield %3 : i32 + } else { + scf.yield %arg1 : i32 + } + return %1 : i32 + } +}