Skip to content

Commit

Permalink
[IR] Include PrefetchIR (apache#189)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored Jun 19, 2017
1 parent eaf0fde commit 1400eda
Show file tree
Hide file tree
Showing 16 changed files with 77 additions and 10 deletions.
2 changes: 1 addition & 1 deletion HalideIR
6 changes: 6 additions & 0 deletions include/tvm/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,11 @@ constexpr const char* device_context_type = "device_context_type";
constexpr const char* loop_scope = "loop_scope";
/*! \brief Mark of reduce scope */
constexpr const char* reduce_scope = "reduce_scope";
/*!
* \brief Mark of prefetch scope, value=offset,
* run prefetch of Tensor on the current loop scope
*/
constexpr const char* prefetch_scope = "prefetch_scope";
/*! \brief Mark of scan update scope */
constexpr const char* scan_update_scope = "scan_update_scope";
/*! \brief Mark of scan init scope */
Expand Down Expand Up @@ -371,6 +376,7 @@ using Halide::Internal::Provide;
using Halide::Internal::Allocate;
using Halide::Internal::Free;
using Halide::Internal::Realize;
using Halide::Internal::Prefetch;
using Halide::Internal::Block;
using Halide::Internal::IfThenElse;
using Halide::Internal::Evaluate;
Expand Down
5 changes: 5 additions & 0 deletions include/tvm/ir_functor_ext.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ namespace ir {
* You can use this as a more powerful Visitor, since it allows you to
* define function signatures of Visit Function.
*
* This helps you to avoid to book-keep return value of Visitor via state,
* which can cause bugs easily when state is incorrectly maintained.
*
* \code
* // A functor that set variable to b. and calculate results.
* class MyExprFunctor
Expand Down Expand Up @@ -223,6 +226,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
virtual R VisitStmt_(const ProducerConsumer* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const Provide* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const Realize* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const Prefetch* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const Block* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const Evaluate* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmtDefault_(const Node* op, Args ...) {
Expand All @@ -245,6 +249,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
IR_STMT_FUNCTOR_DISPATCH(ProducerConsumer);
IR_STMT_FUNCTOR_DISPATCH(Provide);
IR_STMT_FUNCTOR_DISPATCH(Realize);
IR_STMT_FUNCTOR_DISPATCH(Prefetch);
IR_STMT_FUNCTOR_DISPATCH(Block);
IR_STMT_FUNCTOR_DISPATCH(Evaluate);
return vtable;
Expand Down
1 change: 1 addition & 0 deletions include/tvm/ir_mutator.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ class IRMutator {
virtual Stmt Mutate_(const ProducerConsumer* op, const Stmt& s);
virtual Stmt Mutate_(const Provide* op, const Stmt& s);
virtual Stmt Mutate_(const Realize* op, const Stmt& s);
virtual Stmt Mutate_(const Prefetch* op, const Stmt& s);
virtual Stmt Mutate_(const Block* op, const Stmt& s);
virtual Stmt Mutate_(const Evaluate* op, const Stmt& s);

Expand Down
1 change: 1 addition & 0 deletions include/tvm/ir_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ class IRVisitor {
virtual void Visit_(const ProducerConsumer* op);
virtual void Visit_(const Provide* op);
virtual void Visit_(const Realize* op);
virtual void Visit_(const Prefetch* op);
virtual void Visit_(const Block* op);
virtual void Visit_(const Evaluate* op);
virtual void Visit_(const IntImm* op);
Expand Down
6 changes: 6 additions & 0 deletions include/tvm/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -461,10 +461,16 @@ class IterVarAttrNode : public Node {
IterVarType iter_type{kDataPar};
/*! \brief The thread this iter Var binds, can be null */
IterVar bind_thread;
/*! \brief List of tensor to be prefetched in this loop */
Array<Tensor> prefetch_data;
/*! \brief The offset used in each prefetch */
Array<Expr> prefetch_offset;

void VisitAttrs(AttrVisitor* v) final {
v->Visit("iter_type", &iter_type);
v->Visit("bind_thread", &bind_thread);
v->Visit("prefetch_data", &prefetch_data);
v->Visit("prefetch_offset", &prefetch_offset);
}

static constexpr const char* _type_key = "IterVarAttr";
Expand Down
1 change: 1 addition & 0 deletions src/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ There can be internal header files within each module that sit in src.
- pass The optimization pass on the IR structure
- codegen The code generator.
- runtime Minimum runtime related codes
- contrib Contrib extension libraries
2 changes: 1 addition & 1 deletion src/arithmetic/bound_deducer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ void BoundDeducer::Deduce() {
success = false;
return;
}
// get the sign of every subexpr

expr_map_ = EvalSetForEachSubExpr(expr_, hint_map_);

Visit(expr_);
Expand Down
20 changes: 18 additions & 2 deletions src/op/op_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,18 @@ MakeLoopNest(const Stage& stage,
// Mark the iter var in the IR, to remember the point
if (bind_iv->thread_tag.length() == 0) {
ForType for_type = ForType::Serial;
IterVarAttr it_attr;
if (stage->iter_var_attrs.count(iv)) {
switch (stage->iter_var_attrs[iv]->iter_type) {
it_attr = stage->iter_var_attrs[iv];
}
if (it_attr.defined()) {
switch (it_attr->iter_type) {
case kUnrolled: for_type = ForType::Unrolled; break;
case kVectorized: for_type = ForType::Vectorized; break;
case kParallelized: for_type = ForType::Parallel; break;
case kDataPar: break;
default: LOG(FATAL) << "Unknown iter type"
<< stage->iter_var_attrs[iv]->iter_type
<< it_attr->iter_type
<< " in the iter_var_attrs";
}
}
Expand All @@ -85,6 +89,18 @@ MakeLoopNest(const Stage& stage,
nest[i + 1].emplace_back(
LetStmt::make(var, new_value, no_op));
}
if (it_attr.defined() && it_attr->prefetch_data.size() != 0) {
CHECK(!is_one(dom->extent))
<< "Cannot prefetch on trivial loop with extent=1";
CHECK_EQ(it_attr->prefetch_data.size(),
it_attr->prefetch_offset.size());
for (size_t i = 0; i < it_attr->prefetch_data.size(); ++i) {
nest[i + 1].emplace_back(
AttrStmt::make(it_attr->prefetch_data[i],
ir::attr::prefetch_scope,
it_attr->prefetch_offset[i], no_op));
}
}
} else if (bind_iv->thread_tag == "vthread") {
// virtual thread
// Always restrict threaded IterVar to starts from 0.
Expand Down
2 changes: 1 addition & 1 deletion src/pass/inject_virtual_thread.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ namespace tvm {
namespace ir {

// If expression is touched by var.
class ExprTouched : public IRVisitor {
class ExprTouched final : public IRVisitor {
public:
explicit ExprTouched(const std::unordered_set<const Variable*> &touched)
: touched_var_(touched) {}
Expand Down
2 changes: 1 addition & 1 deletion src/pass/inline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ namespace ir {
// inliner to inline a function
// the result may not be SSA,
// ConvertSSA need to be applied after this pass
class IRInline : public IRMutator {
class IRInline final : public IRMutator {
public:
IRInline(FunctionRef f, Array<Var> args, Expr body)
: f_(f), args_(args), body_(body) {}
Expand Down
25 changes: 25 additions & 0 deletions src/pass/ir_mutator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,31 @@ Stmt IRMutator::Mutate_(const Realize* op, const Stmt& s) {
}
}

Stmt IRMutator::Mutate_(const Prefetch* op, const Stmt& s) {
IRMutator* m = this;
Halide::Internal::Region new_bounds;
bool bounds_changed = false;

// Mutate the bounds
for (size_t i = 0; i < op->bounds.size(); i++) {
Expr old_min = op->bounds[i]->min;
Expr old_extent = op->bounds[i]->extent;
Expr new_min = m->Mutate(old_min);
Expr new_extent = m->Mutate(old_extent);
if (!new_min.same_as(old_min)) bounds_changed = true;
if (!new_extent.same_as(old_extent)) bounds_changed = true;
new_bounds.push_back(
Range::make_by_min_extent(new_min, new_extent));
}

if (!bounds_changed) {
return s;
} else {
return Prefetch::make(op->func, op->value_index,
op->type, new_bounds);
}
}

Stmt IRMutator::Mutate_(const Block* op, const Stmt& s) {
Stmt first = this->Mutate(op->first);
Stmt rest = this->Mutate(op->rest);
Expand Down
8 changes: 7 additions & 1 deletion src/pass/ir_visitor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,6 @@ void IRVisitor::Visit_(const Provide *op) {
}

void IRVisitor::Visit_(const Realize *op) {
// Mutate the bounds
for (size_t i = 0; i < op->bounds.size(); i++) {
this->Visit(op->bounds[i]->min);
this->Visit(op->bounds[i]->extent);
Expand All @@ -184,6 +183,13 @@ void IRVisitor::Visit_(const Realize *op) {
this->Visit(op->condition);
}

void IRVisitor::Visit_(const Prefetch *op) {
for (size_t i = 0; i < op->bounds.size(); i++) {
this->Visit(op->bounds[i]->min);
this->Visit(op->bounds[i]->extent);
}
}

void IRVisitor::Visit_(const Block *op) {
this->Visit(op->first);
this->Visit(op->rest);
Expand Down
2 changes: 1 addition & 1 deletion src/pass/loop_partition.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ bool ExprUseVars(Expr expr, const std::unordered_set<const Variable*>& vars) {
// Rule:
// - the range should not be const
// - there exist a condition expression in the scope that use the var
class CandidateSelector : public IRVisitor {
class CandidateSelector final : public IRVisitor {
public:
using VarIsUsed = bool;
CandidateSelector() {}
Expand Down
2 changes: 1 addition & 1 deletion src/pass/lower_thread_allreduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
namespace tvm {
namespace ir {

class ThreadAllreduceBuilder : public IRMutator {
class ThreadAllreduceBuilder final : public IRMutator {
public:
explicit ThreadAllreduceBuilder(int warp_size)
: warp_size_(warp_size) {}
Expand Down
2 changes: 1 addition & 1 deletion src/pass/storage_rewrite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ using namespace storage;
// The storage need to be kept alive between allocate and last access.
// The free point is only inserted at the same scope of allocate.
//
class StorageAccessPatternFinder : public IRVisitor {
class StorageAccessPatternFinder final : public IRVisitor {
public:
// Get linear access pattern.
std::vector<StmtEntry> GetLinearSeq(const Stmt& s) {
Expand Down

0 comments on commit 1400eda

Please sign in to comment.