[LANG] Introduce Scan, Bugfix Canonical
tqchen committed Feb 13, 2017
1 parent f8f0282 commit 595dc94
Showing 19 changed files with 776 additions and 117 deletions.
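For orientation, the scan primitive introduced by this commit computes a recurrence along one axis: the first time steps come from init, and each later step is produced by update from the previous state. A plain-numpy sketch of the intended semantics (illustrative only, not part of the commit):

import numpy as np

# Semantic model of scan: res[0] = init, res[t] = update(res[t-1], input[t]).
# With addition as the update rule this reproduces numpy.cumsum along axis 0.
def scan_semantics(X):
    res = np.empty_like(X)
    res[0] = X[0]                    # init covers the first time step
    for t in range(1, X.shape[0]):   # t is the scan axis
        res[t] = res[t - 1] + X[t]   # update reads the previous state
    return res

X = np.arange(12.0).reshape(4, 3)
assert np.allclose(scan_semantics(X), np.cumsum(X, axis=0))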
35 changes: 31 additions & 4 deletions include/tvm/ir.h
@@ -49,12 +49,27 @@ struct Reduce : public ExprNode<Reduce> {
static constexpr const char* Min = "Min";
};

-/*! \brief namespace of possible attribute sin AttrStmt.type_key */
-namespace attr {
/*!
- * \brief Mark scope of iteration variable, used by Schedule.
+ * \brief Auxiliary data structure used in IR Pass to indicate a tensor.
*/
-constexpr const char* scope = "scope";
+struct TensorKey {
FunctionRef f;
int value_index;

inline bool operator==(const TensorKey& other) const {
return f == other.f && value_index == other.value_index;
}
inline std::string GetName() const {
if (f->num_outputs() == 1) return f->func_name();
std::ostringstream os;
os << f->func_name() << ".v" << value_index;
return os.str();
}
};

/*! \brief namespace of possible attributes in AttrStmt.type_key */
namespace attr {
// The above attr does not pass to ir stage.
/*!
* \brief Mark launching extent of thread, used by device API.
*/
@@ -189,4 +204,16 @@ using Halide::Internal::Evaluate;
} // namespace ir
} // namespace tvm

namespace std {
template <>
struct hash<::tvm::ir::TensorKey> {
std::size_t operator()(const ::tvm::ir::TensorKey& k) const {
size_t lhs = k.f.hash();
size_t rhs = static_cast<size_t>(k.value_index);
lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
return lhs;
}
};
} // namespace std

#endif // TVM_IR_H_
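A note on the std::hash specialization above: it is the boost-style hash_combine pattern, where 0x9e3779b9 (the 32-bit golden-ratio constant) plus two shifted copies of the running hash spread entropy before the output index is folded in. A small Python model of the mixing step (masked to 32 bits for illustration; the C++ code operates on size_t):

def hash_combine(lhs, rhs, mask=0xFFFFFFFF):
    # golden-ratio constant plus shifted copies of lhs, folded in with XOR
    return (lhs ^ ((rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2)) & mask)) & mask

print(hash_combine(0x12345678, 0))
print(hash_combine(0x12345678, 1))   # adjacent value_index values hash far apart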
64 changes: 64 additions & 0 deletions include/tvm/operation.h
@@ -77,6 +77,55 @@ class ComputeOpNode : public OperationNode {
TVM_DECLARE_NODE_TYPE_INFO(ComputeOpNode);
};

/*!
* \brief Symbolic scan.
*/
class ScanOpNode : public OperationNode {
public:
/*! \brief IterVar to scan over */
IterVar scan_axis;
/*! \brief the initialization tensors */
Array<Tensor> init;
/*! \brief the update rules, represented by tensors */
Array<Tensor> update;
/*! \brief The placeholders referring to the states in update. */
Array<Tensor> state_placeholder;
/*!
* \brief Spatial axes indicating the spatial dimension of each output.
*  They correspond to the flattened spatial axes of the outputs:
*
*    [output[0].axis[1], output[0].axis[2], ... output[k].axis[j], ...]
*
*  This is an auxiliary data structure storing the result of bound
*  inference. These axes do not correspond to splittable iterations,
*  hence the trailing underscore in the name.
*/
Array<IterVar> spatial_axis_;
/*! \brief constructor */
ScanOpNode() {}
// override behavior.
int num_outputs() const final;
Array<IterVar> root_iter_vars() const final;
Type output_dtype(size_t i) const final;
Array<Expr> output_shape(size_t i) const final;

void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
v->Visit("scan_axis", &scan_axis);
v->Visit("init", &init);
v->Visit("update", &update);
v->Visit("state_placeholder", &state_placeholder);
v->Visit("spatial_axis_", &spatial_axis_);
}
static Operation make(std::string name,
IterVar axis,
Array<Tensor> init,
Array<Tensor> update,
Array<Tensor> state_placeholder);

static constexpr const char* _type_key = "ScanOp";
TVM_DECLARE_NODE_TYPE_INFO(ScanOpNode);
};
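The fields listed in VisitAttrs are registered with TVM's reflection machinery, which is what makes them readable by name from the Python side. A sketch of the effect, reusing the cumsum setup from the Python docstring added later in this commit (a sketch under the assumption that Tensor exposes its source operation as .op and that reflected fields are plain attributes):

import tvm

m = tvm.Var("m")
n = tvm.Var("n")
t = tvm.IterVar((1, m), name="t")
X = tvm.placeholder((m, n), name="X")
s_state = tvm.placeholder((m, n))
s_init = tvm.compute((1, n), lambda _, i: X[0, i])
s_update = tvm.compute((n,), lambda i: s_state[t - 1, i] + X[t, i])
res = tvm.scan(t, s_init, s_update, s_state)

op = res.op              # the ScanOp node behind the output tensor
print(op.scan_axis)      # fields registered in VisitAttrs, e.g. the IterVar t
print(len(op.update))    # -> 1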


/*! \brief The compute function to specify the input source of a Tensor */
using FCompute = std::function<Expr (const Array<Var>& i)>;
@@ -100,6 +149,21 @@ Tensor Placeholder(Array<Expr> shape,
*/
Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor");

/*!
* \brief Construct new tensors by scanning over scan_axis.
*
* \param scan_axis The iteration variable representing the scan.
* \param init The init tensors covering the first K time steps.
* \param update The update tensors describing the scan result at each time step.
* \param state_placeholder The placeholders for the states.
* \param name The optional name of the tensors.
* \return The created scan output tensors.
*/
Array<Tensor> Scan(IterVar scan_axis,
Array<Tensor> init,
Array<Tensor> update,
Array<Tensor> state_placeholder,
std::string name = "scan");

// same as compute, specialized for different fcompute function
inline Tensor Compute(Array<Expr> shape,
std::function<Expr(Var)> f,
55 changes: 52 additions & 3 deletions python/tvm/api.py
Expand Up @@ -14,6 +14,7 @@
from . import _api_internal
from . import make as _make
from . import expr as _expr
from . import tensor as _tensor
from . import collections as _collections

int32 = "int32"
@@ -111,7 +112,6 @@ def compute(shape, fcompute, name="compute"):
shape: Tuple of Expr
The shape of the tensor
fcompute: lambda function of *indices -> value
Specifies the input source expression
@@ -137,8 +137,57 @@ def compute(shape, fcompute, name="compute"):
body = convert(body)
op_node = _api_internal._ComputeOp(
name, dim_var, body)
-    return _api_internal._Tensor(
-        shape, body.dtype, op_node, 0)
+    return op_node.output(0)


def scan(axis, init, update, state_placeholder, name="scan"):
    """Construct new tensors by scanning over axis.

    Parameters
    ----------
    axis: IterVar
        The scanning axis.

    init: Tensor or list of Tensor
        The initial condition covering the first init.shape[0] time steps.

    update: Tensor or list of Tensor
        The update rule of the scan, given by a symbolic tensor.

    state_placeholder: Tensor or list of Tensor
        The placeholder variables used by update.

    name: str, optional
        The name hint of the tensor.

    Returns
    -------
    tensor: Tensor or list of Tensors
        The created tensor(s); a list is returned when multiple states
        are scanned together.

    Example
    -------
    # The following code is equivalent to numpy.cumsum
    m = tvm.Var("m")
    n = tvm.Var("n")
    t = tvm.IterVar((1, m), name="t")
    X = tvm.placeholder((m, n), name="X")
    s_state = tvm.placeholder((m, n))
    s_init = tvm.compute((1, n), lambda _, i: X[0, i])
    s_update = tvm.compute((n,), lambda i: s_state[t-1, i] + X[t, i])
    res = tvm.scan(t, s_init, s_update, s_state)
    """
    if isinstance(init, _tensor.Tensor):
        init = [init]
    if isinstance(update, _tensor.Tensor):
        update = [update]
    if isinstance(state_placeholder, _tensor.Tensor):
        state_placeholder = [state_placeholder]
    if len(init) != len(update) or len(init) != len(state_placeholder):
        raise ValueError("init, update, state_placeholder must have the same length")
    op = _api_internal._ScanOp(name, axis, init, update, state_placeholder)
    res = [op.output(i) for i in range(len(update))]
    return res[0] if len(res) == 1 else res
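A short note on the return convention implemented above: a single state yields a bare Tensor, while lists of states yield a list. A sketch reusing the docstring example's names (the two-state variant is hypothetical):

res = tvm.scan(t, s_init, s_update, s_state)    # single state -> Tensor

# coupled states: pass equal-length lists and a list comes back
# res_a, res_b = tvm.scan(t, [init_a, init_b], [upd_a, upd_b], [st_a, st_b])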


def Buffer(shape, dtype=None,
9 changes: 7 additions & 2 deletions python/tvm/tensor.py
@@ -74,12 +74,17 @@ def output(self, index):
"""
return _api_internal._OpGetOutput(self, index)

@register_node
class PlaceholderOp(Operation):
    """Placeholder operation."""
    pass

@register_node
class ComputeOp(Operation):
    """Compute operation."""
    pass

@register_node
class ScanOp(Operation):
    """Scan operation."""
    pass
9 changes: 9 additions & 0 deletions src/api/api_lang.cc
@@ -173,6 +173,15 @@ TVM_REGISTER_API(_ComputeOp)
args[2]);
});

TVM_REGISTER_API(_ScanOp)
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = ScanOpNode::make(args[0],
args[1],
args[2],
args[3],
args[4]);
});

TVM_REGISTER_API(_OpGetOutput)
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Operation().output(
11 changes: 9 additions & 2 deletions src/arithmetic/canonical.cc
@@ -365,7 +365,7 @@ class Canonical::Internal : public IRMutator {
const ComExpr& sumb,
int bscale) {
std::shared_ptr<ComExprNode> n = std::make_shared<ComExprNode>();
-n->base = suma->base + sumb->base;
+n->base = suma->base + sumb->base * bscale;
// merge of suma and sumb;
size_t i = 0, j = 0;
while (i < suma->elem.size() && j < sumb->elem.size()) {
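The one-line change in the hunk above matters whenever SumAdd merges a sum whose constant base carries a scale. A worked instance with the ComExpr bases reduced to plain numbers (illustrative):

# SumAdd(suma, sumb, bscale) canonicalizes  suma + bscale * sumb.
# Simplifying x - (y + 2) merges {base: 0, elems: [x]} with
# {base: 2, elems: [y]} at bscale = -1:
suma_base, sumb_base, bscale = 0, 2, -1
fixed = suma_base + sumb_base * bscale    # -2, the correct constant term
broken = suma_base + sumb_base            #  2, what the old code produced
assert (fixed, broken) == (-2, 2)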
@@ -417,7 +417,7 @@
// convert sum to expr
Expr Sum2Expr(const ComExpr& com, Type t) {
Expr vsum;
-if (com->base != 0) {
+if (com->base > 0) {
vsum = make_const(t, com->base);
}
for (const ComExprEntry& e : com->elem) {
@@ -433,6 +433,13 @@
}
}
}
if (com->base < 0) {
if (vsum.defined()) {
vsum = Sub::make(vsum, make_const(t, -com->base));
} else {
vsum = make_const(t, com->base);
}
}
for (const ComExprEntry& e : com->elem) {
if (e.scale < 0) {
Expr v = e.value;
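The second canonical fix changes how the constant base is re-emitted: a positive base still leads the rebuilt expression, while a negative base is now appended as a subtraction, so the result prints as (x - 2) rather than starting from -2. A simplified Python model of the emission order (illustrative; negative-scale terms, handled by the loop that follows, are elided):

def sum_to_expr(base, terms):      # terms: list of (name, positive scale)
    expr = str(base) if base > 0 else None
    for name, scale in terms:
        part = name if scale == 1 else "%s*%d" % (name, scale)
        expr = part if expr is None else "(%s + %s)" % (expr, part)
    if base < 0:                   # negative base becomes a trailing Sub
        expr = str(base) if expr is None else "(%s - %d)" % (expr, -base)
    return expr

assert sum_to_expr(-2, [("x", 1)]) == "(x - 2)"
assert sum_to_expr(3, [("x", 1)]) == "(3 + x)"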
2 changes: 1 addition & 1 deletion src/codegen/codegen_cuda.cc
@@ -168,7 +168,7 @@ MakeNVRTC(Array<LoweredFunc> funcs) {
const auto& f = PackedFunc::GetGlobal("tvm_callback_cuda_postproc");
code = f(code).operator std::string();
}
-LOG(INFO) << code;

std::string ptx;
if (PackedFunc::GlobalExist("tvm_callback_cuda_compile")) {
const auto& f = PackedFunc::GetGlobal("tvm_callback_cuda_compile");
2 changes: 1 addition & 1 deletion src/codegen/codegen_cuda.h
@@ -42,7 +42,7 @@ class CodeGenCUDA : public CodeGenC {
private:
// magic number to add pragma unroll to it.
// used to generate code that is compact but still unrolls.
-int max_auto_unroll_{8};
+int max_auto_unroll_{1025};
};

} // namespace codegen
87 changes: 87 additions & 0 deletions src/lang/operation.cc
@@ -5,6 +5,7 @@
#include <tvm/operation.h>
#include <tvm/tensor.h>
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <memory>

namespace tvm {
@@ -120,4 +121,90 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)

TVM_REGISTER_NODE_TYPE(ComputeOpNode);

// Scan
inline bool prove_equal(Expr lhs, Expr rhs) {
return is_zero(ir::Simplify(lhs - rhs));
}
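prove_equal reduces symbolic equality of two shape expressions to a zero test on their simplified difference. The same idea at the Python level (assumes the tvm.ir_pass.Simplify binding was already exposed in this snapshot, which is an assumption):

import tvm

m = tvm.Var("m")
diff = (m + 1) - (1 + m)
print(tvm.ir_pass.Simplify(diff))   # expect the constant 0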

int ScanOpNode::num_outputs() const {
return update.size();
}
Array<IterVar> ScanOpNode::root_iter_vars() const {
return Array<IterVar>{scan_axis};
}

Type ScanOpNode::output_dtype(size_t i) const {
return update[i]->dtype;
}

Array<Expr> ScanOpNode::output_shape(size_t i) const {
CHECK_LT(i, state_placeholder.size());
return state_placeholder[i]->shape;
}

Operation ScanOpNode::make(std::string name,
IterVar axis,
Array<Tensor> init,
Array<Tensor> update,
Array<Tensor> state_placeholder) {
auto n = std::make_shared<ScanOpNode>();
CHECK_EQ(init.size(), update.size());
CHECK_EQ(init.size(), state_placeholder.size());

for (size_t i = 0; i < init.size(); ++i) {
CHECK_EQ(init[i]->dtype, state_placeholder[i]->dtype);
CHECK_EQ(init[i]->dtype, update[i]->dtype);
CHECK(can_prove(init[i]->shape[0] == axis->dom->min))
<< "init.shape[0] needs to match scan_axis.dom.min";
CHECK(prove_equal(
state_placeholder[i]->shape[0], axis->dom->min + axis->dom->extent))
<< "state_placeholder.shape[0] needs to match"
<< " scan_axis.dom.min + scan_axis.dom.extent";
CHECK_EQ(state_placeholder[i].ndim(), init[i].ndim())
<< "The dimension of init needs to match state_placeholder";
CHECK_EQ(update[i].ndim() + 1, state_placeholder[i].ndim())
<< "update.ndim needs to be state_placeholder.ndim - 1";
for (size_t k = 0; k < update[i].ndim(); ++k) {
CHECK(prove_equal(
update[i]->shape[k], state_placeholder[i]->shape[k + 1]));
// setup spatial axis
std::ostringstream spatial_name;
spatial_name << name << ".out" << i << ".i" << k + 1;
n->spatial_axis_.push_back(
IterVar(Range::make_with_min_extent(0, update[i]->shape[k]),
spatial_name.str()));
}
for (size_t k = 1; k < init[i].ndim(); ++k) {
CHECK(prove_equal(
init[i]->shape[k], state_placeholder[i]->shape[k]));
}
}

n->name = name;
n->scan_axis = axis;
n->init = init;
n->update = update;
n->state_placeholder = state_placeholder;
return Operation(n);
}

Array<Tensor> Scan(IterVar scan_axis,
Array<Tensor> init,
Array<Tensor> update,
Array<Tensor> state_placeholder,
std::string name) {
Operation op = ScanOpNode::make(
name, scan_axis, init, update, state_placeholder);
Array<Tensor> res;
for (int i = 0; i < op->num_outputs(); ++i) {
res.push_back(op.output(i));
}
return res;
}

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ScanOpNode>([](const ScanOpNode *op, IRPrinter *p) {
p->stream << "scan(" << op->name << ", " << op << ")";
});

} // namespace tvm
