-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* [PASS] copy intrin * update comment thanks to derisavi
- Loading branch information
Showing
8 changed files
with
286 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,162 @@ | ||
/*! | ||
* Copyright (c) 2017 by Contributors | ||
* \brief Replace certain copy with copy intrinsics. | ||
* \file copy_intrin_rewrite.cc | ||
*/ | ||
#include <tvm/ir.h> | ||
#include <tvm/packed_func_ext.h> | ||
#include <tvm/ir_mutator.h> | ||
#include <tvm/ir_pass.h> | ||
|
||
namespace tvm { | ||
namespace ir { | ||
|
||
using runtime::PackedFunc; | ||
|
||
class CopyIntrinInjector : public IRMutator { | ||
public: | ||
CopyIntrinInjector(const std::string& pragma_key, | ||
const PackedFunc& flower_copy_fromto) | ||
: pragma_key_(pragma_key), | ||
flower_copy_fromto_(flower_copy_fromto) { | ||
} | ||
|
||
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { | ||
if (op->attr_key == attr::storage_scope) { | ||
const Variable* buf = op->node.as<Variable>(); | ||
storage_scope_[buf] = op->value.as<StringImm>()->value; | ||
} else if (op->attr_key == ir::attr::pragma_scope) { | ||
const std::string& pname = op->value.as<StringImm>()->value; | ||
if (pname == pragma_key_) { | ||
Stmt ret; | ||
CHECK(MatchCopyPattern(op->body, &ret)) | ||
<< "Cannot match copy pattern of " << op->body; | ||
return ret; | ||
} | ||
} | ||
return IRMutator::Mutate_(op, s); | ||
} | ||
|
||
private: | ||
bool MatchCopyPattern(Stmt stmt, Stmt *out) { | ||
Stmt body = stmt; | ||
|
||
// strip the loops | ||
std::vector<const For*> loops; | ||
while (const For* op = body.as<For>()) { | ||
if (!is_zero(op->min)) return false; | ||
loops.push_back(op); | ||
body = op->body; | ||
} | ||
const Store* store = body.as<Store>(); | ||
if (store == nullptr) return false; | ||
const Select* select = store->value.as<Select>(); | ||
const Load* load = store->value.as<Load>(); | ||
|
||
// for now only support true condition matching | ||
if (select != nullptr) { | ||
load = select->true_value.as<Load>(); | ||
} | ||
if (load == nullptr) return false; | ||
if (load->type.lanes() != 1) return false; | ||
Array<Var> loop_vars; | ||
for (const For* op : loops) { | ||
loop_vars.push_back(Var(op->loop_var.node_)); | ||
} | ||
Array<Expr> store_strides = | ||
arith::DetectLinearEquation(store->index, loop_vars); | ||
Array<Expr> load_strides = | ||
arith::DetectLinearEquation(load->index, loop_vars); | ||
if (load_strides.size() == 0 || store_strides.size() == 0) return false; | ||
Array<Expr> dst_shape; | ||
for (const For* op : loops) { | ||
dst_shape.push_back(op->extent); | ||
} | ||
Array<Expr> src_shape = dst_shape; | ||
Array<Expr> pad_before, pad_after; | ||
Expr pad_value; | ||
Expr src_elem_offset = load_strides[loop_vars.size()]; | ||
if (select != nullptr) { | ||
Array<Expr> clip_bound = | ||
arith::DetectClipBound(select->condition, loop_vars); | ||
pad_value = select->false_value; | ||
if (clip_bound.size() == 0) return false; | ||
CHECK_EQ(src_shape.size(), loop_vars.size()); | ||
CHECK_EQ(clip_bound.size(), loop_vars.size() * 2); | ||
for (size_t i = 0; i < src_shape.size(); ++i) { | ||
Expr min_value = clip_bound[2 * i]; | ||
Expr max_value = clip_bound[2 * i + 1]; | ||
Type t = loop_vars[i].type(); | ||
Expr svalue = src_shape[i]; | ||
if (min_value.defined()) { | ||
Expr pbefore = Simplify(Max::make(min_value, make_zero(t))); | ||
src_elem_offset = src_elem_offset + pbefore * load_strides[i]; | ||
svalue = svalue - pbefore; | ||
pad_before.push_back(pbefore); | ||
} else { | ||
pad_before.push_back(make_zero(t)); | ||
} | ||
if (max_value.defined()) { | ||
Expr pafter = Simplify(Max::make(loops[i]->extent - max_value - make_const(t, 1), | ||
make_zero(t))); | ||
svalue = svalue - pafter; | ||
pad_after.push_back(pafter); | ||
} else { | ||
pad_after.push_back(make_zero(t)); | ||
} | ||
src_shape.Set(i, Simplify(svalue)); | ||
} | ||
src_elem_offset = Simplify(src_elem_offset); | ||
} | ||
CHECK_EQ(load_strides.size(), store_strides.size()); | ||
CHECK_EQ(load_strides.size(), loop_vars.size() + 1); | ||
Array<Expr> src_strides(load_strides.begin(), load_strides.begin() + loop_vars.size()); | ||
Array<Expr> dst_strides(store_strides.begin(), store_strides.begin() + loop_vars.size()); | ||
Buffer dst = BufferNode::make( | ||
Var(store->buffer_var.node_), | ||
load->type, | ||
dst_shape, | ||
dst_strides, | ||
store_strides[loop_vars.size()], | ||
store->buffer_var->name_hint, | ||
GetStorageScope(store->buffer_var.get()), | ||
0, 0); | ||
Buffer src = BufferNode::make( | ||
Var(load->buffer_var.node_), | ||
load->type, | ||
src_shape, | ||
src_strides, | ||
src_elem_offset, | ||
load->buffer_var->name_hint, | ||
GetStorageScope(load->buffer_var.get()), | ||
0, 0); | ||
*out = flower_copy_fromto_(src, dst, pad_before, pad_after, pad_value); | ||
CHECK(out->defined()) << "flower function did not return correct stmt"; | ||
return true; | ||
} | ||
// Get storage scope | ||
std::string GetStorageScope(const Variable* var) const { | ||
auto it = storage_scope_.find(var); | ||
if (it != storage_scope_.end()) { | ||
return it->second; | ||
} else { | ||
return ""; | ||
} | ||
} | ||
// pragma key | ||
const std::string& pragma_key_; | ||
// function to lower copy intrinsics. | ||
const PackedFunc& flower_copy_fromto_; | ||
// Storage scope | ||
std::unordered_map<const Variable*, std::string> storage_scope_; | ||
}; | ||
|
||
/*!
 * \brief Rewrite statements tagged with the given pragma into copy intrinsics.
 * \param stmt The statement to transform.
 * \param pragma_key The pragma value that marks a copy to be rewritten.
 * \param flower_copy_fromto The function used to lower each matched copy.
 * \return The transformed statement.
 */
Stmt InjectCopyIntrin(Stmt stmt,
                      const std::string& pragma_key,
                      const PackedFunc& flower_copy_fromto) {
  CopyIntrinInjector injector(pragma_key, flower_copy_fromto);
  return injector.Mutate(stmt);
}
|
||
} // namespace ir | ||
} // namespace tvm |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
import tvm | ||
|
||
def test_copy2d():
    """InjectCopyIntrin on a plain 2-D copy B[i, j] = A[i, j].

    The lowering callback checks the strides and source shape it receives.
    """
    m = tvm.var('m')
    l = tvm.var('l')
    A = tvm.placeholder((m, l), name='A')
    B = tvm.compute((m, l), lambda i, j: A[i, j], name='B')
    s = tvm.create_schedule(B.op)
    s[B].pragma(B.op.axis[0], "memcpy")
    bounds = tvm.schedule.InferBound(s)
    stmt = tvm.schedule.ScheduleOps(s, bounds)
    Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
    Bb = tvm.decl_buffer(B.shape, B.dtype, name='B')
    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
    # Records callback invocations; without it the test would pass
    # vacuously if the pragma were never matched.
    calls = []

    def cb(src, dst, pad_before, pad_after, pad_value):
        calls.append(1)
        assert dst.strides[0] == l
        assert dst.strides[1].value == 1
        assert src.strides[0] == l
        assert tuple(src.shape) == (m, l)
        return tvm.make.Evaluate(0)
    stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)
    assert calls, "copy lowering callback was never invoked"
|
||
def test_copy_pad():
    """InjectCopyIntrin on a copy padded by one row before and after.

    B has shape (m + 2, l); rows 1..m come from A, the rest take the value 1.0.
    The lowering callback checks the padding and source offset it receives.
    """
    m = tvm.var('m')
    l = tvm.var('l')
    A = tvm.placeholder((m, l), name='A')
    B = tvm.compute((m + 2, l), lambda i, j:
                    tvm.select(tvm.all(i >= 1, i < m + 1),
                               A[i - 1, j], 1.0), name='B')
    s = tvm.create_schedule(B.op)
    s[B].pragma(B.op.axis[0], "memcpy")
    bounds = tvm.schedule.InferBound(s)
    stmt = tvm.schedule.ScheduleOps(s, bounds)
    Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
    Bb = tvm.decl_buffer(B.shape, B.dtype, name='B')
    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
    # Records callback invocations; without it the test would pass
    # vacuously if the pragma were never matched.
    calls = []

    def cb(src, dst, pad_before, pad_after, pad_value):
        calls.append(1)
        assert tvm.ir_pass.Simplify(src.elem_offset).value == 0
        assert pad_before[0].value == 1
        assert pad_before[1].value == 0
        assert pad_after[0].value == 1
        assert pad_after[1].value == 0
        assert pad_value.value == 1.0
        return tvm.make.Evaluate(0)
    stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)
    assert calls, "copy lowering callback was never invoked"
|
||
def assert_expr_equal(a, b):
    """Assert that a and b are equal: a - b must simplify to the constant 0."""
    difference = tvm.ir_pass.Simplify(a - b)
    assert difference.value == 0
|
||
def test_copy_pad_split():
    """InjectCopyIntrin on a padded copy computed inside a split loop.

    Apad pads A by one element on each side and is computed at the outer
    axis of B's split (factor 4), so the padding passed to the callback
    depends on the outer loop variable xo.
    """
    m = 4 * 3
    A = tvm.placeholder((m, ), name="A")
    Apad = tvm.compute((m + 2,), lambda i:
                       tvm.select(tvm.all(i >= 1, i <= m),
                                  A[i - 1], 0.0), "Apad")
    B = tvm.compute((m,), lambda i: Apad[i] + Apad[i + 1] + Apad[i + 2])
    s = tvm.create_schedule(B.op)
    xo, xi = s[B].split(B.op.axis[0], factor=4)
    s[Apad].compute_at(s[B], xo)
    s[Apad].pragma(s[Apad].op.axis[0], "memcpy")
    bounds = tvm.schedule.InferBound(s)
    stmt = tvm.schedule.ScheduleOps(s, bounds)
    Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
    Bb = tvm.decl_buffer(B.shape, B.dtype, name='B')
    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
    stmt = tvm.ir_pass.CanonicalSimplify(stmt)
    # Records callback invocations; without it the test would pass
    # vacuously if the pragma were never matched.
    calls = []

    def cb(src, dst, pad_before, pad_after, pad_value):
        calls.append(1)
        assert(dst.elem_offset.value == 0)
        assert_expr_equal(src.elem_offset, tvm.max(xo * 4, 1) - 1)
        rpad_before = tvm.max(1 - xo * 4, 0)
        rpad_after = tvm.max(xo * 4 - 7, 0)
        assert_expr_equal(pad_before[0], rpad_before)
        assert_expr_equal(pad_after[0], rpad_after)
        assert_expr_equal(src.shape[0], 6 - rpad_before - rpad_after)
        return tvm.make.Evaluate(0)
    stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)
    assert calls, "copy lowering callback was never invoked"
|
||
|
||
if __name__ == "__main__":
    # Run every test in this file, in definition order.
    for run_test in (test_copy2d, test_copy_pad, test_copy_pad_split):
        run_test()