Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PASS] Schedule Ops init working version #6

Merged
merged 2 commits into from
Jan 10, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion HalideIR
Submodule HalideIR updated from 5d1bd1 to 1ec478
3 changes: 3 additions & 0 deletions include/tvm/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ namespace tvm {

using Halide::Type;
using Halide::Float;
using Halide::Bool;
using Halide::Int;
using Halide::UInt;
using Halide::Handle;
Expand All @@ -29,6 +30,8 @@ using Halide::Internal::Stmt;
using Halide::Internal::IRPrinter;
using Halide::Internal::Variable;

using Halide::Internal::make_const;

/*! \brief a named variable in TVM */
class Var : public Halide::VarExpr {
public:
Expand Down
18 changes: 10 additions & 8 deletions include/tvm/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,16 @@
namespace tvm {
namespace ir {


/*!
 * \brief Lower the schedule s into a statement implementing its dependent operations.
 *
 * \param s The schedule to be realized
 * \param dom_map The inferred iteration Range for each IterVar in the schedule.
 * \return the result Stmt implementing the schedule
 */
Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map);

/*!
* \brief verifies whether the IR stmt or Expr is in SSA form.
* That is: each VarExpr is defined and assigned once(in Let/For)
Expand Down Expand Up @@ -51,14 +61,6 @@ Stmt Inline(FunctionRef f,
Expr body,
Stmt stmt);

/*!
* \brief Schedule s' dependent operations.
*
* \param s The schedule to be realized
* \return the result Stmt
*/
Stmt ScheduelOps(Schedule s);

} // namespace ir
} // namespace tvm

Expand Down
43 changes: 41 additions & 2 deletions include/tvm/operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,36 @@

namespace tvm {

/*!
 * \brief A placeholder op represents an input placeholder:
 *  a tensor whose contents are supplied externally rather than computed.
 */
class PlaceholderOpNode : public OperationNode {
public:
/*! \brief The shape of the input */
Array<Expr> shape;
/*! \brief The data type of the input. */
Type dtype;

// A placeholder always produces exactly one output tensor.
int num_outputs() const final {
return 1;
}
// NOTE(review): a placeholder performs no computation, so this presumably
// returns an empty array -- confirm against the .cc implementation.
Array<IterVar> root_iter_vars() const final;
/*! \return the data type of the i-th output */
Type output_dtype(size_t i) const final;
/*! \return the shape of the i-th output */
Array<Expr> output_shape(size_t i) const final;

// Reflection hook: exposes fields to TVM's visitor-based
// serialization/inspection machinery.
void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
v->Visit("shape", &shape);
v->Visit("dtype", &dtype);
}
/*!
 * \brief Construct a PlaceholderOp.
 * \param name The name of the operation.
 * \param shape The shape of the placeholder.
 * \param dtype The data type of the placeholder.
 * \return The created Operation.
 */
static Operation make(std::string name,
Array<Expr> shape,
Type dtype);

static constexpr const char* _type_key = "PlaceholderOp";
TVM_DECLARE_NODE_TYPE_INFO(PlaceholderOpNode);
};

/*!
* \brief A Compute op that compute a tensor on certain domain.
*/
Expand All @@ -24,11 +54,10 @@ class ComputeOpNode : public OperationNode {
/*! \brief constructor */
ComputeOpNode() {}

size_t num_outputs() const final {
int num_outputs() const final {
return 1;
}
Array<IterVar> root_iter_vars() const final;
std::string output_name(size_t i) const final;
Type output_dtype(size_t i) const final;
Array<Expr> output_shape(size_t i) const final;

Expand All @@ -49,6 +78,16 @@ class ComputeOpNode : public OperationNode {
/*! \brief The compute function to specify the input source of a Tensor */
using FCompute = std::function<Expr (const Array<Var>& i)>;

/*!
 * \brief Create a placeholder tensor: an input whose contents
 *  are supplied externally rather than computed.
 * \param shape The shape of the tensor.
 * \param dtype The data type of the tensor.
 * \param name The name of the Tensor.
 * \return The created placeholder Tensor.
 */
Tensor Placeholder(Array<Expr> shape,
Type dtype = Float(32),
std::string name = "placeholder");

/*!
* \brief Construct a new tensor by computing over shape,
* using the computation rule: result_tensor[axis] = fcompute(axis)
Expand Down
20 changes: 11 additions & 9 deletions src/schedule/bound.h → include/tvm/schedule_pass.h
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
/*!
* Copyright (c) 2016 by Contributors
* \file bound.h
* \brief The bound inference logics on the schedule.
* \file schedule_pass.h
* \brief Collection of Schedule pass functions.
*
* These passes works on the schedule hyper-graph
* and infers information such as bounds, check conditions
* read/write dependencies between the IterVar
*/
#ifndef TVM_SCHEDULE_BOUND_H_
#define TVM_SCHEDULE_BOUND_H_
#ifndef TVM_SCHEDULE_PASS_H_
#define TVM_SCHEDULE_PASS_H_

#include <tvm/expr.h>
#include <tvm/schedule.h>
#include <unordered_map>
#include "./base.h"
#include "./schedule.h"

namespace tvm {
namespace schedule {
Expand All @@ -23,5 +26,4 @@ Map<IterVar, Range> InferBound(Schedule sch);

} // namespace schedule
} // namespace tvm

#endif // TVM_SCHEDULE_BOUND_H_
#endif // TVM_SCHEDULE_PASS_H_
41 changes: 12 additions & 29 deletions include/tvm/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,11 @@ using Halide::IR::FunctionRef;
* \brief Tensor structure representing a possible input,
* or intermediate computation result.
*/
class Tensor : public FunctionRef {
class Tensor : public NodeRef {
public:
/*! \brief default constructor, used internally */
Tensor() {}
explicit Tensor(std::shared_ptr<Node> n) : FunctionRef(n) {}
/*!
* \brief constructor of input tensor
* \param shape Shape of the tensor.
* \param name optional name of the Tensor.
* \param dtype The data type of the input tensor.
*/
explicit Tensor(Array<Expr> shape,
std::string name = "tensor",
Type dtype = Float(32));
explicit Tensor(std::shared_ptr<Node> n) : NodeRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
Expand Down Expand Up @@ -116,11 +107,11 @@ class Tensor : public FunctionRef {
};

/*! \brief Operation that produces tensors */
class Operation : public NodeRef {
class Operation : public FunctionRef {
public:
/*! \brief default constructor */
Operation() {}
explicit Operation(std::shared_ptr<Node> n) : NodeRef(n) {}
explicit Operation(std::shared_ptr<Node> n) : FunctionRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
Expand All @@ -137,12 +128,10 @@ class Operation : public NodeRef {
};

/*! \brief Node to represent a tensor */
class TensorNode : public FunctionBaseNode {
class TensorNode : public Node {
public:
/*! \brief The shape of the tensor */
Array<Expr> shape;
/*! \brief optional name of the tensor */
std::string name;
/*! \brief data type in the content of the tensor */
Type dtype;
/*! \brief the source operation, can be None */
Expand All @@ -154,19 +143,11 @@ class TensorNode : public FunctionBaseNode {

void VisitAttrs(AttrVisitor* v) final {
v->Visit("shape", &shape);
v->Visit("name", &name);
v->Visit("dtype", &dtype);
v->Visit("op", &op);
v->Visit("value_index", &value_index);
}
const std::string& func_name() const final {
return name;
}
int outputs() const final {
return 1;
}
static Tensor make(Array<Expr> shape,
std::string name,
Type dtype,
Operation op,
int value_index);
Expand All @@ -178,16 +159,18 @@ class TensorNode : public FunctionBaseNode {
/*!
* \brief base class of operation node.
*/
class OperationNode : public Node {
class OperationNode : public FunctionBaseNode {
public:
/*! \brief optional name of the operation */
std::string name;
/*! \return name of the operation */
const std::string& func_name() const final {
return name;
}
/*! \return number of outputs of this op */
virtual int num_outputs() const = 0;
/*! \return the list of iteration variable at root */
virtual Array<IterVar> root_iter_vars() const = 0;
/*! \return number of outputs of this op */
virtual size_t num_outputs() const = 0;
/*! \return name of i-th output */
virtual std::string output_name(size_t i) const = 0;
/*! \return type of i-th output */
virtual Type output_dtype(size_t i) const = 0;
/*! \return shape of i-th output */
Expand Down
8 changes: 4 additions & 4 deletions python/tvm/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def Var(name="tindex", dtype=int32):
return _function_internal._Var(name, dtype)


def placeholder(shape, dtype = None, name="TensorObj"):
def placeholder(shape, dtype = None, name="placeholder"):
"""Construct an empty tensor object.

Parameters
Expand All @@ -53,11 +53,11 @@ def placeholder(shape, dtype = None, name="TensorObj"):
The created tensor
"""
dtype = float32 if dtype is None else dtype
return _function_internal._Tensor(
shape, name, dtype, None, 0)
return _function_internal._Placeholder(
shape, dtype, name)


def compute(shape, fcompute, name="TensorCompute"):
def compute(shape, fcompute, name="compute"):
"""Construct a new tensor by computing over the shape domain.

The compute rule is result[axis] = fcompute(axis)
Expand Down
8 changes: 7 additions & 1 deletion python/tvm/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ def __call__(self, *indices):
else:
raise ValueError("The indices must be expression")

return _make.Call(self.dtype, self.name, args, _expr.Call.Halide, self, 0)
return _make.Call(self.dtype, self.op.name,
args, _expr.Call.Halide,
self.op, self.value_index)

def __getitem__(self, indices):
return TensorSlice(self, indices)
Expand Down Expand Up @@ -71,3 +73,7 @@ def output(self, index):
@register_node
class ComputeOp(Operation):
pass

@register_node
class PlaceholderOp(Operation):
pass
14 changes: 12 additions & 2 deletions src/c_api/c_api_ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,17 @@ TVM_REGISTER_API(_make_For)
args.at(5));
});

// Expose Realize::make to the Python frontend as _make_Realize.
// Forwards the six positional arguments straight through to Realize::make.
TVM_REGISTER_API(_make_Realize)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = Realize::make(args.at(0),
args.at(1),
args.at(2),
args.at(3),
args.at(4),
args.at(5));
});


TVM_REGISTER_API(_make_Call)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = Call::make(args.at(0),
Expand Down Expand Up @@ -113,9 +124,8 @@ REGISTER_MAKE3(LetStmt);
REGISTER_MAKE2(AssertStmt);
REGISTER_MAKE3(ProducerConsumer);
REGISTER_MAKE3(Store);
REGISTER_MAKE3(Provide);
REGISTER_MAKE4(Provide);
REGISTER_MAKE1(Free);
// TODO(tqchen) Realize;
REGISTER_MAKE2(Block);
REGISTER_MAKE3(IfThenElse);
REGISTER_MAKE1(Evaluate);
Expand Down
8 changes: 7 additions & 1 deletion src/c_api/c_api_lang.cc
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,6 @@ TVM_REGISTER_API(Range)
TVM_REGISTER_API(_Tensor)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = TensorNode::make(args.at(0),
args.at(1),
args.at(2),
args.at(3),
args.at(4));
Expand All @@ -160,6 +159,13 @@ TVM_REGISTER_API(_TensorHash)
std::hash<Tensor>()(args.at(0).operator Tensor()));
});

// Expose the Placeholder factory to the Python frontend as _Placeholder.
// Arguments are (shape, dtype, name), matching python/tvm/function.py.
TVM_REGISTER_API(_Placeholder)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = Placeholder(args.at(0),
args.at(1),
args.at(2));
});

TVM_REGISTER_API(_ComputeOp)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = ComputeOpNode::make(args.at(0),
Expand Down
2 changes: 1 addition & 1 deletion src/c_api/c_api_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include "./c_api_registry.h"
#include "../schedule/bound.h"

namespace tvm {
namespace ir {
Expand Down Expand Up @@ -36,6 +35,7 @@ using RetValue = APIVariantValue;
REGISTER_PASS1(ConvertSSA);
REGISTER_PASS1(VerifySSA);
REGISTER_PASS4(Inline);
REGISTER_PASS2(ScheduleOps);

} // namespace ir
} // namespace tvm
2 changes: 1 addition & 1 deletion src/c_api/c_api_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
#include <tvm/expr.h>
#include <tvm/tensor.h>
#include <tvm/schedule.h>
#include <tvm/schedule_pass.h>
#include "./c_api_registry.h"
#include "../schedule/bound.h"
#include "../schedule/graph.h"

namespace tvm {
Expand Down
1 change: 0 additions & 1 deletion src/lang/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,5 +73,4 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)

TVM_REGISTER_NODE_TYPE(IterVarNode);


} // namespace tvm
3 changes: 2 additions & 1 deletion src/lang/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<AttrStmt>([](const AttrStmt *op, IRPrinter *p) {
p->stream << "attr " << op->type_key << " = ";
p->do_indent();
p->stream << "// attr " << op->type_key << " = ";
p->print(op->value);
p->stream << '\n';
p->print(op->body);
Expand Down
Loading