Skip to content

[IR] Add easy-build for "if" and "for“ for general structured control flow Ops #62

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: yijie/easy-build
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
finish if
  • Loading branch information
Mei, Yijie committed May 10, 2024
commit 24e36519f9da0501c4d46b7b15a34b948de4910b
14 changes: 14 additions & 0 deletions include/gc/IR/EasyBuild.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
#include <stddef.h>

namespace mlir {
namespace scf {
class YieldOp;
}

namespace easybuild {

namespace impl {
Expand Down Expand Up @@ -58,6 +62,10 @@ struct EasyBuilder {
: builder{builder} {}
void setLoc(const Location &l) { builder->loc = l; }

Operation *getLastOperaion() {
return &*(--builder->builder.getInsertionPoint());
}

template <typename W, typename V> auto wrapOrFail(V &&v) {
return W::wrapOrFail(builder, std::forward<V>(v));
}
Expand Down Expand Up @@ -91,6 +99,12 @@ struct EasyBuilder {
builder->builder.create<OP>(builder->loc, std::forward<Args>(v)...));
}
}

template <typename OP = scf::YieldOp, typename... Args>
auto yield(Args &&...v) {
builder->builder.create<OP>(builder->loc,
ValueRange{std::forward<Args>(v)...});
}
};

} // namespace easybuild
Expand Down
84 changes: 79 additions & 5 deletions include/gc/IR/EasyBuildSCF.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
#include "mlir/Interfaces/LoopLikeInterface.h"

namespace mlir {
namespace scf {
class IfOp;
}

namespace easybuild {
namespace impl {

Expand All @@ -28,9 +32,7 @@ struct ForRangeSimulatorImpl {
: s{s}, op{op} {
s->builder.setInsertionPointToStart(&op.getLoopRegions().front()->front());
}
~ForRangeSimulatorImpl() {
s->builder.setInsertionPointAfter(op);
}
~ForRangeSimulatorImpl() { s->builder.setInsertionPointAfter(op); }
};

template <int N, typename... Ts>
Expand Down Expand Up @@ -84,8 +86,7 @@ template <typename... TArgs> struct ForRangeSimulator : ForRangeSimulatorImpl {
return consumed != other.consumed;
}

ForRangeIterator(ForRangeSimulator *ptr)
: ptr{ptr}, consumed{false} {}
ForRangeIterator(ForRangeSimulator *ptr) : ptr{ptr}, consumed{false} {}
ForRangeIterator() : ptr{nullptr}, consumed{true} {}
};

Expand All @@ -107,6 +108,79 @@ auto forRangeIn(const EasyBuilder &s, LoopLikeOpInterface op) {

#define EB_for for

namespace impl {
struct IfSimulator;
struct IfIterator {
IfSimulator *ptr;
int index;
int operator*() const;

IfIterator &operator++() {
index++;
return *this;
}

bool operator!=(IfIterator &other) const { return index != other.index; }

IfIterator(IfSimulator *ptr) : ptr{ptr}, index{0} {}
IfIterator(int numRegions) : ptr{nullptr}, index{numRegions} {}
};

struct IfSimulator {
StatePtr s;
Operation *op;
IfIterator begin() { return IfIterator(this); }
IfIterator end() {
int nonEmptyRegions = 0;
for (auto &reg : op->getRegions()) {
if (reg.begin() != reg.end()) {
nonEmptyRegions++;
}
}
return IfIterator(nonEmptyRegions);
}
~IfSimulator() { s->builder.setInsertionPointAfter(op); }
};
inline int IfIterator::operator*() const {
auto &blocks = ptr->op->getRegion(index);
ptr->s->builder.setInsertionPointToStart(&blocks.back());
return index;
}

} // namespace impl

impl::IfSimulator makeIfRange(const EasyBuilder &s, Operation *op) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing an inline here~

return impl::IfSimulator{s.builder, op};
}

template <typename T = scf::IfOp>
impl::IfSimulator makeScfIfLikeRange(EBValue cond, TypeRange resultTypes) {
auto &s = cond.builder;
auto op = s->builder.create<T>(s->loc, resultTypes, cond, true);
return impl::IfSimulator{s, op};
}

template <typename T = scf::IfOp>
impl::IfSimulator makeScfIfLikeRange(EBValue cond, bool hasElse = true) {
auto &s = cond.builder;
auto op = s->builder.create<T>(s->loc, TypeRange{}, cond, hasElse);
return impl::IfSimulator{s, op};
}

#define EB_if(BUILDER, ...) \
for (auto &&eb_mlir_if_scope__ : \
::mlir::easybuild::makeIfRange(BUILDER, __VA_ARGS__)) \
if (eb_mlir_if_scope__ == 0)

// EB_scf_if(COND)
// EB_scf_if(COND, HAS_ELSE)
// EB_scf_if(COND, RESULT_TYPES)
#define EB_scf_if(...) \
for (auto &&eb_mlir_if_scope__ : \
::mlir::easybuild::makeScfIfLikeRange(__VA_ARGS__)) \
if (eb_mlir_if_scope__ == 0)
#define EB_else else

} // namespace easybuild
} // namespace mlir
#endif
50 changes: 46 additions & 4 deletions unittests/Dialect/SCF/EasyBuildTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class SCFTest : public ::testing::Test {
protected:
SCFTest() {
context.getOrLoadDialect<scf::SCFDialect>();
context.getOrLoadDialect<arith::ArithDialect>();
context.getOrLoadDialect<func::FuncDialect>();
}

Expand All @@ -44,9 +45,31 @@ TEST_F(SCFTest, EasyBuild) {
/*ind_var*/ ValueRange{init_val});
EB_for(auto &&[idx, redu] : forRangeIn<EBUnsigned, EBUnsigned>(b, loop)) {
auto idx2 = idx + b.toIndex(1);
b.F<scf::YieldOp, void>(ValueRange{idx2 + redu});
EB_scf_if(idx2 == b.toIndex(10), false) {
// if without else
b.toIndex(1123);
}
EB_scf_if(idx2 == b.toIndex(12)) {
// if-else, no return value
b.toIndex(1123);
}
EB_else {
// else
b.toIndex(11234);
}
EB_scf_if(idx2 == b.toIndex(14), {builder.getIndexType()}) {
// if-else with return value
b.yield(idx);
}
EB_else {
// else with return value
b.yield(idx2);
}
auto ifResult = b.wrap<EBUnsigned>(b.getLastOperaion()->getResult(0));
b.yield(ifResult + redu);
}
builder.create<func::ReturnOp>(loc, loop.getResult(0));
b.yield<func::ReturnOp>(loop.getResult(0));

std::string out;
llvm::raw_string_ostream os{out};
os << func;
Expand All @@ -61,8 +84,27 @@ TEST_F(SCFTest, EasyBuild) {
%1 = scf.for %arg1 = %c0 to %c10_0 step %c1 iter_args(%arg2 = %0) -> (index) {
%c1_1 = arith.constant 1 : index
%2 = arith.addi %arg1, %c1_1 : index
%3 = arith.addi %2, %arg2 : index
scf.yield %3 : index
%c10_2 = arith.constant 10 : index
%3 = arith.cmpi eq, %2, %c10_2 : index
scf.if %3 {
%c1123 = arith.constant 1123 : index
}
%c12 = arith.constant 12 : index
%4 = arith.cmpi eq, %2, %c12 : index
scf.if %4 {
%c1123 = arith.constant 1123 : index
} else {
%c11234 = arith.constant 11234 : index
}
%c14 = arith.constant 14 : index
%5 = arith.cmpi eq, %2, %c14 : index
%6 = scf.if %5 -> (index) {
scf.yield %arg1 : index
} else {
scf.yield %2 : index
}
%7 = arith.addi %6, %arg2 : index
scf.yield %7 : index
}
return %1 : index
})mlir";
Expand Down