Skip to content

Commit

Permalink
[PASS] copy intrin (#536)
Browse files Browse the repository at this point in the history
* [PASS] copy intrin

* update comment thanks to derisavi
  • Loading branch information
tqchen authored Oct 11, 2017
1 parent 33a80e4 commit 581509a
Show file tree
Hide file tree
Showing 8 changed files with 286 additions and 6 deletions.
4 changes: 2 additions & 2 deletions include/tvm/arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,8 @@ struct IntSetNode : public Node {
};

/*!
* \brief Detect if e can be rewritten as e = sum_{i=0}^n var[i] * coeff[i] + coeff[n]
* Where coeff and base are invariant of var.
* \brief Detect if e can be rewritten as e = sum_{i=0}^{n-1} var[i] * coeff[i] + coeff[n]
* Where coeff[i] and base are invariant of var[j] for all i and j.
*
* \param e The expression to be detected.
* \param vars List of variables to be used in detection.
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ class BufferNode : public Node {
Type dtype,
Array<Expr> shape,
Array<Expr> strides,
Expr byte_offset,
Expr elem_offset,
std::string name,
std::string scope,
int data_alignment,
Expand Down
18 changes: 18 additions & 0 deletions include/tvm/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,24 @@ Stmt InjectPrefetch(Stmt stmt);
*/
Stmt InjectDoubleBuffer(Stmt stmt, int split_loop);

/*!
* \brief Inject copy intrinsics with optional pad.
*
* \param stmt The statment to be transformed.
* \param pragma_key The pragma key for hint of copy.
* \param fintrin The function with signature
*
* Stmt fintrin(Buffer src,
* Buffer dst,
* Array<Expr> pad_before,
* Array<Expr> pad_after,
* Expr pad_value)
* \return Transformed stmt.
*/
Stmt InjectCopyIntrin(Stmt stmt,
const std::string& pragma_key,
const runtime::PackedFunc& fintrin);

/*!
* \brief Rewrite storage allocation pattern.
* Moves the allocation to outer most possible scope.
Expand Down
8 changes: 6 additions & 2 deletions include/tvm/packed_func_ext.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,12 @@ inline TNodeRef TVMRetValue::AsNodeRef() const {
}

inline void TVMArgsSetter::operator()(size_t i, const NodeRef& other) const { // NOLINT(*)
values_[i].v_handle = const_cast<std::shared_ptr<Node>*>(&(other.node_));
type_codes_[i] = kNodeHandle;
if (other.defined()) {
values_[i].v_handle = const_cast<std::shared_ptr<Node>*>(&(other.node_));
type_codes_[i] = kNodeHandle;
} else {
type_codes_[i] = kNull;
}
}

// type related stuffs
Expand Down
1 change: 1 addition & 0 deletions src/api/api_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ REGISTER_PASS3(StorageFlatten);
REGISTER_PASS4(IRTransform);
REGISTER_PASS1(VectorizeLoop);
REGISTER_PASS4(UnrollLoop);
REGISTER_PASS3(InjectCopyIntrin);
REGISTER_PASS2(ThreadSync);
REGISTER_PASS5(MakeAPI);
REGISTER_PASS2(BindDeviceType);
Expand Down
15 changes: 14 additions & 1 deletion src/arithmetic/canonical.cc
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,13 @@ class Canonical::Internal : public IRMutator {
if (!op->is_pure()) {
stack_.back().has_side_effect = true;
}
return IRMutator::Mutate_(op, e);
Expr expr = IRMutator::Mutate_(op, e);
op = expr.as<Call>();
if (op->is_intrinsic(Call::likely) && is_const(op->args[0])) {
return op->args[0];
} else {
return expr;
}
}
// For
Stmt Mutate_(const For* op, const Stmt& s) {
Expand All @@ -320,6 +326,13 @@ class Canonical::Internal : public IRMutator {
--level_counter_;
return stmt;
}
// IfThenElse
Stmt Mutate_(const IfThenElse* op, const Stmt& s) {
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<IfThenElse>();
if (is_one(op->condition)) return op->then_case;
return stmt;
}
// AttrStmt
Stmt Mutate_(const AttrStmt* op, const Stmt& s) {
if (op->attr_key == attr::thread_extent ||
Expand Down
162 changes: 162 additions & 0 deletions src/pass/inject_copy_intrin.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
/*!
* Copyright (c) 2017 by Contributors
* \brief Replace certain copy with copy intrinsics.
* \file copy_intrin_rewrite.cc
*/
#include <tvm/ir.h>
#include <tvm/packed_func_ext.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>

namespace tvm {
namespace ir {

using runtime::PackedFunc;

class CopyIntrinInjector : public IRMutator {
public:
CopyIntrinInjector(const std::string& pragma_key,
const PackedFunc& flower_copy_fromto)
: pragma_key_(pragma_key),
flower_copy_fromto_(flower_copy_fromto) {
}

Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
if (op->attr_key == attr::storage_scope) {
const Variable* buf = op->node.as<Variable>();
storage_scope_[buf] = op->value.as<StringImm>()->value;
} else if (op->attr_key == ir::attr::pragma_scope) {
const std::string& pname = op->value.as<StringImm>()->value;
if (pname == pragma_key_) {
Stmt ret;
CHECK(MatchCopyPattern(op->body, &ret))
<< "Cannot match copy pattern of " << op->body;
return ret;
}
}
return IRMutator::Mutate_(op, s);
}

private:
bool MatchCopyPattern(Stmt stmt, Stmt *out) {
Stmt body = stmt;

// strip the loops
std::vector<const For*> loops;
while (const For* op = body.as<For>()) {
if (!is_zero(op->min)) return false;
loops.push_back(op);
body = op->body;
}
const Store* store = body.as<Store>();
if (store == nullptr) return false;
const Select* select = store->value.as<Select>();
const Load* load = store->value.as<Load>();

// for now only support true condition matching
if (select != nullptr) {
load = select->true_value.as<Load>();
}
if (load == nullptr) return false;
if (load->type.lanes() != 1) return false;
Array<Var> loop_vars;
for (const For* op : loops) {
loop_vars.push_back(Var(op->loop_var.node_));
}
Array<Expr> store_strides =
arith::DetectLinearEquation(store->index, loop_vars);
Array<Expr> load_strides =
arith::DetectLinearEquation(load->index, loop_vars);
if (load_strides.size() == 0 || store_strides.size() == 0) return false;
Array<Expr> dst_shape;
for (const For* op : loops) {
dst_shape.push_back(op->extent);
}
Array<Expr> src_shape = dst_shape;
Array<Expr> pad_before, pad_after;
Expr pad_value;
Expr src_elem_offset = load_strides[loop_vars.size()];
if (select != nullptr) {
Array<Expr> clip_bound =
arith::DetectClipBound(select->condition, loop_vars);
pad_value = select->false_value;
if (clip_bound.size() == 0) return false;
CHECK_EQ(src_shape.size(), loop_vars.size());
CHECK_EQ(clip_bound.size(), loop_vars.size() * 2);
for (size_t i = 0; i < src_shape.size(); ++i) {
Expr min_value = clip_bound[2 * i];
Expr max_value = clip_bound[2 * i + 1];
Type t = loop_vars[i].type();
Expr svalue = src_shape[i];
if (min_value.defined()) {
Expr pbefore = Simplify(Max::make(min_value, make_zero(t)));
src_elem_offset = src_elem_offset + pbefore * load_strides[i];
svalue = svalue - pbefore;
pad_before.push_back(pbefore);
} else {
pad_before.push_back(make_zero(t));
}
if (max_value.defined()) {
Expr pafter = Simplify(Max::make(loops[i]->extent - max_value - make_const(t, 1),
make_zero(t)));
svalue = svalue - pafter;
pad_after.push_back(pafter);
} else {
pad_after.push_back(make_zero(t));
}
src_shape.Set(i, Simplify(svalue));
}
src_elem_offset = Simplify(src_elem_offset);
}
CHECK_EQ(load_strides.size(), store_strides.size());
CHECK_EQ(load_strides.size(), loop_vars.size() + 1);
Array<Expr> src_strides(load_strides.begin(), load_strides.begin() + loop_vars.size());
Array<Expr> dst_strides(store_strides.begin(), store_strides.begin() + loop_vars.size());
Buffer dst = BufferNode::make(
Var(store->buffer_var.node_),
load->type,
dst_shape,
dst_strides,
store_strides[loop_vars.size()],
store->buffer_var->name_hint,
GetStorageScope(store->buffer_var.get()),
0, 0);
Buffer src = BufferNode::make(
Var(load->buffer_var.node_),
load->type,
src_shape,
src_strides,
src_elem_offset,
load->buffer_var->name_hint,
GetStorageScope(load->buffer_var.get()),
0, 0);
*out = flower_copy_fromto_(src, dst, pad_before, pad_after, pad_value);
CHECK(out->defined()) << "flower function did not return correct stmt";
return true;
}
// Get storage scope
std::string GetStorageScope(const Variable* var) const {
auto it = storage_scope_.find(var);
if (it != storage_scope_.end()) {
return it->second;
} else {
return "";
}
}
// pragma key
const std::string& pragma_key_;
// function to lower copy intrinsics.
const PackedFunc& flower_copy_fromto_;
// Storage scope
std::unordered_map<const Variable*, std::string> storage_scope_;
};

Stmt InjectCopyIntrin(Stmt stmt,
const std::string& pragma_key,
const PackedFunc& flower_copy_fromto) {
return CopyIntrinInjector(pragma_key, flower_copy_fromto)
.Mutate(stmt);
}

} // namespace ir
} // namespace tvm
82 changes: 82 additions & 0 deletions tests/python/unittest/test_pass_inject_copy_intrin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import tvm

def test_copy2d():
m = tvm.var('m')
l = tvm.var('l')
A = tvm.placeholder((m, l), name='A')
B = tvm.compute((m, l), lambda i, j: A[i, j], name='B')
s = tvm.create_schedule(B.op)
s[B].pragma(B.op.axis[0], "memcpy")
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
Bb = tvm.decl_buffer(B.shape, B.dtype, name='B')
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
def cb(src, dst, pad_before, pad_after, pad_value):
assert dst.strides[0] == l
assert dst.strides[1].value == 1
assert src.strides[0] == l
assert tuple(src.shape) == (m, l)
return tvm.make.Evaluate(0)
stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)

def test_copy_pad():
m = tvm.var('m')
l = tvm.var('l')
A = tvm.placeholder((m, l), name='A')
B = tvm.compute((m + 2, l), lambda i, j:
tvm.select(tvm.all(i >= 1, i < m + 1),
A[i - 1, j], 1.0), name='B')
s = tvm.create_schedule(B.op)
s[B].pragma(B.op.axis[0], "memcpy")
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
Bb = tvm.decl_buffer(B.shape, B.dtype, name='B')
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
def cb(src, dst, pad_before, pad_after, pad_value):
assert tvm.ir_pass.Simplify(src.elem_offset).value == 0
assert pad_before[0].value == 1
assert pad_before[1].value == 0
assert pad_after[0].value == 1
assert pad_after[1].value == 0
assert pad_value.value == 1.0
return tvm.make.Evaluate(0)
stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)

def assert_expr_equal(a, b):
assert tvm.ir_pass.Simplify(a - b).value == 0

def test_copy_pad_split():
m = 4 * 3
A = tvm.placeholder((m, ), name="A")
Apad = tvm.compute((m + 2,), lambda i:
tvm.select(tvm.all(i >= 1, i <= m),
A[i - 1], 0.0), "Apad")
B = tvm.compute((m,), lambda i: Apad[i] + Apad[i + 1] + Apad[i + 2])
s = tvm.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=4)
s[Apad].compute_at(s[B], xo)
s[Apad].pragma(s[Apad].op.axis[0], "memcpy")
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
Bb = tvm.decl_buffer(B.shape, B.dtype, name='B')
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
stmt = tvm.ir_pass.CanonicalSimplify(stmt)
def cb(src, dst, pad_before, pad_after, pad_value):
assert(dst.elem_offset.value == 0)
assert_expr_equal(src.elem_offset, tvm.max(xo * 4, 1) - 1)
rpad_before = tvm.max(1 - xo * 4, 0)
rpad_after = tvm.max(xo * 4 - 7, 0)
assert_expr_equal(pad_before[0], rpad_before)
assert_expr_equal(pad_after[0], rpad_after)
assert_expr_equal(src.shape[0], 6 - rpad_before - rpad_after)
return tvm.make.Evaluate(0)
stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)


if __name__ == "__main__":
test_copy2d()
test_copy_pad()
test_copy_pad_split()

0 comments on commit 581509a

Please sign in to comment.