Skip to content

[MLIR][IR] Add easy-builder to simplify IR building in C++ #2

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
495 changes: 495 additions & 0 deletions mlir/docs/EasyBuilder.md

Large diffs are not rendered by default.

453 changes: 453 additions & 0 deletions mlir/include/mlir/Dialect/Arith/Utils/EasyBuild.h

Large diffs are not rendered by default.

112 changes: 112 additions & 0 deletions mlir/include/mlir/IR/EasyBuild.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
//===- EasyBuild.h - Easy IR Builder utilities ------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This header file defines the easy-build utilities core data structures for
// building IR.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_IR_EASYBUILD_H
#define MLIR_IR_EASYBUILD_H
#include "mlir/IR/Builders.h"
#include <cstdint>
#include <memory>
#include <stddef.h>

namespace mlir {
namespace scf {
class YieldOp;
}

namespace easybuild {

namespace impl {
struct EasyBuildState {
OpBuilder &builder;
Location loc;
bool u64AsIndex;
EasyBuildState(OpBuilder &builder, Location loc, bool u64AsIndex)
: builder{builder}, loc{loc}, u64AsIndex{u64AsIndex} {}
};

using StatePtr = std::shared_ptr<impl::EasyBuildState>;

} // namespace impl

struct EBValue {
std::shared_ptr<impl::EasyBuildState> builder;
Value v;
EBValue() = default;
EBValue(const impl::StatePtr &builder, Value v) : builder{builder}, v{v} {}
Value get() const { return v; }
operator Value() const { return v; }

static FailureOr<EBValue> wrapOrFail(const impl::StatePtr &state, Value v) {
return EBValue{state, v};
}
};

struct EBArithValue;

struct EasyBuilder {
std::shared_ptr<impl::EasyBuildState> builder;
EasyBuilder(OpBuilder &builder, Location loc, bool u64AsIndex = false)
: builder{
std::make_shared<impl::EasyBuildState>(builder, loc, u64AsIndex)} {}
EasyBuilder(const std::shared_ptr<impl::EasyBuildState> &builder)
: builder{builder} {}
void setLoc(const Location &l) { builder->loc = l; }

Operation *getLastOperaion() {
return &*(--builder->builder.getInsertionPoint());
}

template <typename W, typename V> auto wrapOrFail(V &&v) {
return W::wrapOrFail(builder, std::forward<V>(v));
}

template <typename W, typename V> auto wrap(V &&v) {
auto ret = wrapOrFail<W>(std::forward<V>(v));
if (failed(ret)) {
llvm_unreachable("wrap failed!");
}
return *ret;
}

template <typename V> auto operator()(V &&v) {
if constexpr (std::is_convertible_v<V, Value>) {
return EBValue{builder, std::forward<V>(v)};
} else {
return wrap<EBArithValue>(std::forward<V>(v));
}
}

template <typename W = EBArithValue> auto toIndex(uint64_t v) const {
return W::toIndex(builder, v);
}

template <typename OP, typename OutT = EBValue, typename... Args>
auto F(Args &&...v) {
if constexpr (std::is_same_v<OutT, void>) {
builder->builder.create<OP>(builder->loc, std::forward<Args>(v)...);
} else {
return wrap<OutT>(
builder->builder.create<OP>(builder->loc, std::forward<Args>(v)...));
}
}

template <typename OP = scf::YieldOp, typename... Args>
auto yield(Args &&...v) {
builder->builder.create<OP>(builder->loc,
ValueRange{std::forward<Args>(v)...});
}
};

} // namespace easybuild
} // namespace mlir
#endif
190 changes: 190 additions & 0 deletions mlir/include/mlir/IR/EasyBuildSCF.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
//===- EasyBuildSCF.h - Easy IR Builder for general control flow *- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This header file defines the helper classes, functions and macros to help to
// build general structured control flow. Developers can use the utilities in
// this header to easily compose control flow IR.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_IR_EASYBUILDSCF_H
#define MLIR_IR_EASYBUILDSCF_H
#include "mlir/IR/EasyBuild.h"
#include "mlir/Interfaces/LoopLikeInterface.h"

namespace mlir {
namespace scf {
class IfOp;
}

namespace easybuild {
namespace impl {

struct ForRangeSimulatorImpl {
StatePtr s;
LoopLikeOpInterface op;
ForRangeSimulatorImpl(const StatePtr &s, LoopLikeOpInterface op)
: s{s}, op{op} {
s->builder.setInsertionPointToStart(&op.getLoopRegions().front()->front());
}
~ForRangeSimulatorImpl() { s->builder.setInsertionPointAfter(op); }
};

template <int N, typename... Ts>
using NthTypeOf = typename std::tuple_element<N, std::tuple<Ts...>>::type;

template <typename... TArgs>
struct ForVarBinding {
ForRangeSimulatorImpl *impl;
template <int I>
auto get() {
using TOut = NthTypeOf<I, TArgs...>;
if (auto wrapped = TOut::wrapOrFail(
impl->s, impl->op.getLoopRegions().front()->front().getArgument(I));
succeeded(wrapped)) {
return *wrapped;
}
llvm_unreachable("Bad cast for the loop iterator");
}
};
} // namespace impl
} // namespace easybuild
} // namespace mlir

namespace std {
template <typename... TArgs>
struct tuple_size<mlir::easybuild::impl::ForVarBinding<TArgs...>>
: std::integral_constant<std::size_t, sizeof...(TArgs)> {};

template <std::size_t I, typename... TArgs>
struct tuple_element<I, mlir::easybuild::impl::ForVarBinding<TArgs...>> {
using type = mlir::easybuild::impl::NthTypeOf<I, TArgs...>;
};
} // namespace std

namespace mlir {
namespace easybuild {

namespace impl {

template <typename... TArgs>
struct ForRangeSimulator : ForRangeSimulatorImpl {
using ForRangeSimulatorImpl::ForRangeSimulatorImpl;
struct ForRangeIterator {
ForRangeSimulatorImpl *ptr;
bool consumed;
auto operator*() const { return ForVarBinding<TArgs...>{ptr}; }

ForRangeIterator &operator++() {
consumed = true;
return *this;
}

bool operator!=(ForRangeIterator &other) const {
return consumed != other.consumed;
}

ForRangeIterator(ForRangeSimulator *ptr) : ptr{ptr}, consumed{false} {}
ForRangeIterator() : ptr{nullptr}, consumed{true} {}
};

ForRangeIterator begin() { return ForRangeIterator(this); }

ForRangeIterator end() { return ForRangeIterator(); }
};
} // namespace impl

template <typename... TArgs>
auto forRangeIn(const impl::StatePtr &s, LoopLikeOpInterface op) {
return impl::ForRangeSimulator<TArgs...>{s, op};
}

template <typename... TArgs>
auto forRangeIn(const EasyBuilder &s, LoopLikeOpInterface op) {
return impl::ForRangeSimulator<TArgs...>{s.builder, op};
}

#define EB_for for

namespace impl {
struct IfSimulator;
struct IfIterator {
IfSimulator *ptr;
int index;
int operator*() const;

IfIterator &operator++() {
index++;
return *this;
}

bool operator!=(IfIterator &other) const { return index != other.index; }

IfIterator(IfSimulator *ptr) : ptr{ptr}, index{0} {}
IfIterator(int numRegions) : ptr{nullptr}, index{numRegions} {}
};

struct IfSimulator {
StatePtr s;
Operation *op;
IfIterator begin() { return IfIterator(this); }
IfIterator end() {
int nonEmptyRegions = 0;
for (auto &reg : op->getRegions()) {
if (reg.begin() != reg.end()) {
nonEmptyRegions++;
}
}
return IfIterator(nonEmptyRegions);
}
~IfSimulator() { s->builder.setInsertionPointAfter(op); }
};
inline int IfIterator::operator*() const {
auto &blocks = ptr->op->getRegion(index);
ptr->s->builder.setInsertionPointToStart(&blocks.back());
return index;
}

} // namespace impl

inline impl::IfSimulator makeIfRange(const EasyBuilder &s, Operation *op) {
return impl::IfSimulator{s.builder, op};
}

template <typename T = scf::IfOp>
inline impl::IfSimulator makeScfIfLikeRange(EBValue cond,
TypeRange resultTypes) {
auto &s = cond.builder;
auto op = s->builder.create<T>(s->loc, resultTypes, cond, true);
return impl::IfSimulator{s, op};
}

template <typename T = scf::IfOp>
inline impl::IfSimulator makeScfIfLikeRange(EBValue cond, bool hasElse = true) {
auto &s = cond.builder;
auto op = s->builder.create<T>(s->loc, TypeRange{}, cond, hasElse);
return impl::IfSimulator{s, op};
}

#define EB_if(BUILDER, ...) \
for (auto &&eb_mlir_if_scope__ : \
::mlir::easybuild::makeIfRange(BUILDER, __VA_ARGS__)) \
if (eb_mlir_if_scope__ == 0)

// EB_scf_if(COND)
// EB_scf_if(COND, HAS_ELSE)
// EB_scf_if(COND, RESULT_TYPES)
#define EB_scf_if(...) \
for (auto &&eb_mlir_if_scope__ : \
::mlir::easybuild::makeScfIfLikeRange(__VA_ARGS__)) \
if (eb_mlir_if_scope__ == 0)
#define EB_else else

} // namespace easybuild
} // namespace mlir
#endif
6 changes: 6 additions & 0 deletions mlir/unittests/Dialect/Arith/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
add_mlir_unittest(MLIRArithTests
EasyBuildTest.cpp)
target_link_libraries(MLIRArithTests
PRIVATE
MLIRFuncDialect
MLIRArithDialect)
Loading