Skip to content

Commit

Permalink
support lowering scf if op
Browse files Browse the repository at this point in the history
  • Loading branch information
jiahanxie353 committed May 27, 2024
1 parent b0dd65a commit 2564b93
Showing 1 changed file with 217 additions and 19 deletions.
236 changes: 217 additions & 19 deletions lib/Conversion/SCFToCalyx/SCFToCalyx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ class ScfForOp : public calyx::RepeatOpInterface<scf::ForOp> {
// Lowering state classes
//===----------------------------------------------------------------------===//

struct IfScheduleable {
scf::IfOp ifOp;
};

struct WhileScheduleable {
/// While operation to schedule.
ScfWhileOp whileOp;
Expand All @@ -111,8 +115,63 @@ struct CallScheduleable {
};

/// A variant of types representing scheduleable operations.
using Scheduleable = std::variant<calyx::GroupOp, WhileScheduleable,
ForScheduleable, CallScheduleable>;
using Scheduleable =
std::variant<calyx::GroupOp, WhileScheduleable, ForScheduleable,
IfScheduleable, CallScheduleable>;

class IfLoweringStateInterface {
public:
void setThenGroup(scf::IfOp op, calyx::GroupOp group) {
Operation *operation = op.getOperation();
assert(thenGroup.count(operation) == 0 &&
"A then group was already set for this scf::IfOp!\n");
thenGroup[operation] = group;
}

calyx::GroupOp getThenGroup(scf::IfOp op) {
auto it = thenGroup.find(op.getOperation());
assert(it != thenGroup.end() &&
"No then group was set for this scf::IfOp!\n");
return it->second;
}

void setElseGroup(scf::IfOp op, calyx::GroupOp group) {
Operation *operation = op.getOperation();
assert(elseGroup.count(operation) == 0 &&
"An else group was already set for this scf::IfOp!\n");
elseGroup[operation] = group;
}

calyx::GroupOp getElseGroup(scf::IfOp op) {
auto it = elseGroup.find(op.getOperation());
assert(it != elseGroup.end() &&
"No else group was set for this scf::IfOp!\n");
return it->second;
}

void setResultRegs(scf::IfOp op, calyx::RegisterOp reg, unsigned idx) {
assert(resultRegs[op.getOperation()].count(idx) == 0 &&
"A register was already registered for the given yield result.\n");
assert(idx < op->getNumOperands());
resultRegs[op.getOperation()][idx] = reg;
}

const DenseMap<unsigned, calyx::RegisterOp> &getResultRegs(scf::IfOp op) {
return resultRegs[op.getOperation()];
}

calyx::RegisterOp getResultRegs(scf::IfOp op, unsigned idx) {
auto regs = getResultRegs(op);
auto it = regs.find(idx);
assert(it != regs.end() && "resultReg not found");
return it->second;
}

private:
DenseMap<Operation *, calyx::GroupOp> thenGroup;
DenseMap<Operation *, calyx::GroupOp> elseGroup;
DenseMap<Operation *, DenseMap<unsigned, calyx::RegisterOp>> resultRegs;
};

class WhileLoopLoweringStateInterface
: calyx::LoopLoweringStateInterface<ScfWhileOp> {
Expand Down Expand Up @@ -183,6 +242,7 @@ class ForLoopLoweringStateInterface
class ComponentLoweringState : public calyx::ComponentLoweringStateInterface,
public WhileLoopLoweringStateInterface,
public ForLoopLoweringStateInterface,
public IfLoweringStateInterface,
public calyx::SchedulerInterface<Scheduleable> {
public:
ComponentLoweringState(calyx::ComponentOp component)
Expand All @@ -209,7 +269,7 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern {
TypeSwitch<mlir::Operation *, bool>(_op)
.template Case<arith::ConstantOp, ReturnOp, BranchOpInterface,
/// SCF
scf::YieldOp, scf::WhileOp, scf::ForOp,
scf::YieldOp, scf::WhileOp, scf::ForOp, scf::IfOp,
/// memref
memref::AllocOp, memref::AllocaOp, memref::LoadOp,
memref::StoreOp,
Expand Down Expand Up @@ -268,6 +328,7 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern {
LogicalResult buildOp(PatternRewriter &rewriter, memref::StoreOp op) const;
LogicalResult buildOp(PatternRewriter &rewriter, scf::WhileOp whileOp) const;
LogicalResult buildOp(PatternRewriter &rewriter, scf::ForOp forOp) const;
LogicalResult buildOp(PatternRewriter &rewriter, scf::IfOp ifOp) const;
LogicalResult buildOp(PatternRewriter &rewriter, CallOp callOp) const;

/// buildLibraryOp will build a TCalyxLibOp inside a TGroupOp based on the
Expand Down Expand Up @@ -716,22 +777,54 @@ LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
"loops. Run --scf-for-to-while before running --scf-to-calyx.";
}

auto whileOp = dyn_cast<scf::WhileOp>(yieldOp->getParentOp());
if (!whileOp) {
return yieldOp.getOperation()->emitError()
<< "Currently only support yield operations inside for and while "
"loops.";
}
ScfWhileOp whileOpInterface(whileOp);

auto assignGroup =
getState<ComponentLoweringState>().buildWhileLoopIterArgAssignments(
rewriter, whileOpInterface,
getState<ComponentLoweringState>().getComponentOp(),
getState<ComponentLoweringState>().getUniqueName(whileOp) + "_latch",
yieldOp->getOpOperands());
getState<ComponentLoweringState>().setWhileLoopLatchGroup(whileOpInterface,
assignGroup);
if (auto whileOp = dyn_cast<scf::WhileOp>(yieldOp->getParentOp())) {
ScfWhileOp whileOpInterface(whileOp);

auto assignGroup =
getState<ComponentLoweringState>().buildWhileLoopIterArgAssignments(
rewriter, whileOpInterface,
getState<ComponentLoweringState>().getComponentOp(),
getState<ComponentLoweringState>().getUniqueName(whileOp) +
"_latch",
yieldOp->getOpOperands());
getState<ComponentLoweringState>().setWhileLoopLatchGroup(whileOpInterface,
assignGroup);
return success();
}

if (auto ifOp = dyn_cast<scf::IfOp>(yieldOp->getParentOp())) {
auto resultRegs = getState<ComponentLoweringState>().getResultRegs(ifOp);

if (yieldOp->getParentRegion() == &ifOp.getThenRegion()) {
auto thenGroup = getState<ComponentLoweringState>().getThenGroup(ifOp);
for (auto op : enumerate(yieldOp.getOperands())) {
auto resultReg =
getState<ComponentLoweringState>().getResultRegs(ifOp, op.index());
buildAssignmentsForRegisterWrite(
rewriter, thenGroup,
getState<ComponentLoweringState>().getComponentOp(), resultReg,
op.value());
getState<ComponentLoweringState>().registerEvaluatingGroup(
ifOp.getResult(op.index()), thenGroup);
}
}

if (!ifOp.getElseRegion().empty() &&
(yieldOp->getParentRegion() == &ifOp.getElseRegion())) {
auto elseGroup = getState<ComponentLoweringState>().getElseGroup(ifOp);
// rewriter.setInsertionPointToEnd(elseGroup.getBodyBlock());
for (auto op : enumerate(yieldOp.getOperands())) {
auto resultReg =
getState<ComponentLoweringState>().getResultRegs(ifOp, op.index());
buildAssignmentsForRegisterWrite(
rewriter, elseGroup,
getState<ComponentLoweringState>().getComponentOp(), resultReg,
op.value());
getState<ComponentLoweringState>().registerEvaluatingGroup(
ifOp.getResult(op.index()), elseGroup);
}
}
}
return success();
}

Expand Down Expand Up @@ -941,6 +1034,13 @@ LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
return success();
}

LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
scf::IfOp ifOp) const {
getState<ComponentLoweringState>().addBlockScheduleable(
ifOp.getOperation()->getBlock(), IfScheduleable{ifOp});
return success();
}

LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
CallOp callOp) const {
std::string instanceName = calyx::getInstanceName(callOp);
Expand Down Expand Up @@ -1287,6 +1387,52 @@ class BuildForGroups : public calyx::FuncOpPartialLoweringPattern {
}
};

class BuildIfGroups : public calyx::FuncOpPartialLoweringPattern {
using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;

LogicalResult
partiallyLowerFuncToComp(FuncOp funcOp,
PatternRewriter &rewriter) const override {
LogicalResult res = success();
funcOp.walk([&](Operation *op) {
if (!isa<scf::IfOp>(op))
return WalkResult::advance();

auto scfIfOp = cast<scf::IfOp>(op);

calyx::ComponentOp componentOp =
getState<ComponentLoweringState>().getComponentOp();

std::string thenGroupName =
getState<ComponentLoweringState>().getUniqueName("then_br");
auto thenGroupOp = calyx::createGroup<calyx::GroupOp>(
rewriter, componentOp, scfIfOp.getLoc(), thenGroupName);
getState<ComponentLoweringState>().setThenGroup(scfIfOp, thenGroupOp);

if (!scfIfOp.getElseRegion().empty()) {
std::string elseGroupName =
getState<ComponentLoweringState>().getUniqueName("else_br");
auto elseGroupOp = calyx::createGroup<calyx::GroupOp>(
rewriter, componentOp, scfIfOp.getLoc(), elseGroupName);
getState<ComponentLoweringState>().setElseGroup(scfIfOp, elseGroupOp);
}

for (auto res : scfIfOp.getResults()) {
auto reg = createRegister(
scfIfOp.getLoc(), rewriter, getComponent(), res.getType(),
getState<ComponentLoweringState>().getUniqueName(
"if_res" + res.getResultNumber()));
reg.dump();
getState<ComponentLoweringState>().setResultRegs(scfIfOp, reg,
res.getResultNumber());
}

return WalkResult::advance();
});
return res;
}
};

/// Builds a control schedule by traversing the CFG of the function and
/// associating this with the previously created groups.
/// For simplicity, the generated control flow is expanded for all possible
Expand Down Expand Up @@ -1380,6 +1526,50 @@ class BuildControl : public calyx::FuncOpPartialLoweringPattern {
forLatchGroup.getName());
if (res.failed())
return res;
} else if (auto ifSchedPtr = std::get_if<IfScheduleable>(&group);
ifSchedPtr) {
auto ifOp = ifSchedPtr->ifOp;

Location loc = ifOp->getLoc();

auto cond = ifOp.getCondition();
auto condGroup = getState<ComponentLoweringState>()
.getEvaluatingGroup<calyx::CombGroupOp>(cond);

auto symbolAttr = FlatSymbolRefAttr::get(
StringAttr::get(getContext(), condGroup.getSymName()));

bool initElse = !ifOp.getElseRegion().empty();
auto ifCtrlOp = rewriter.create<calyx::IfOp>(
loc, cond, symbolAttr, /*initializeElseBody=*/initElse);

rewriter.setInsertionPointToEnd(ifCtrlOp.getBodyBlock());

auto thenSeqOp =
rewriter.create<calyx::SeqOp>(ifOp.getThenRegion().getLoc());
auto *thenSeqOpBlock = thenSeqOp.getBodyBlock();

rewriter.setInsertionPointToEnd(thenSeqOpBlock);

calyx::GroupOp thenGroup =
getState<ComponentLoweringState>().getThenGroup(ifOp);
rewriter.create<calyx::EnableOp>(thenGroup.getLoc(),
thenGroup.getName());

if (!ifOp.getElseRegion().empty()) {
rewriter.setInsertionPointToEnd(ifCtrlOp.getElseBody());

auto elseSeqOp =
rewriter.create<calyx::SeqOp>(ifOp.getElseRegion().getLoc());
auto *elseSeqOpBlock = elseSeqOp.getBodyBlock();

rewriter.setInsertionPointToEnd(elseSeqOpBlock);

calyx::GroupOp elseGroup =
getState<ComponentLoweringState>().getElseGroup(ifOp);
rewriter.create<calyx::EnableOp>(elseGroup.getLoc(),
elseGroup.getName());
}
} else if (auto *callSchedPtr = std::get_if<CallScheduleable>(&group)) {
auto instanceOp = callSchedPtr->instanceOp;
OpBuilder::InsertionGuard g(rewriter);
Expand Down Expand Up @@ -1536,6 +1726,12 @@ class LateSSAReplacement : public calyx::FuncOpPartialLoweringPattern {

LogicalResult partiallyLowerFuncToComp(FuncOp funcOp,
PatternRewriter &) const override {
funcOp.walk([&](scf::IfOp op) {
for (auto res : getState<ComponentLoweringState>().getResultRegs(op))
op.getOperation()->getResults()[res.first].replaceAllUsesWith(
res.second.getOut());
});

funcOp.walk([&](scf::WhileOp op) {
/// The yielded values returned from the while op will be present in the
/// iterargs registers post execution of the loop.
Expand Down Expand Up @@ -1779,6 +1975,8 @@ void SCFToCalyxPass::runOnOperation() {
addOncePattern<BuildForGroups>(loweringPatterns, patternState, funcMap,
*loweringState);

addOncePattern<BuildIfGroups>(loweringPatterns, patternState, funcMap,
*loweringState);
/// This pattern converts operations within basic blocks to Calyx library
/// operators. Combinational operations are assigned inside a
/// calyx::CombGroupOp, and sequential inside calyx::GroupOps.
Expand Down

0 comments on commit 2564b93

Please sign in to comment.