Skip to content

Commit

Permalink
[CODEGEN/EXEC] CUDA, NVRTC pipeline complete (apache#27)
Browse files Browse the repository at this point in the history
* [CODEGEN] CUDA/OPENCL pipeline complete

* Hide TVMType by str in frontend
  • Loading branch information
tqchen authored Jan 31, 2017
1 parent ff06917 commit 891630e
Show file tree
Hide file tree
Showing 50 changed files with 1,751 additions and 520 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ language: cpp

os:
- linux
# - osx
- osx

env:
# code analysis
Expand Down
2 changes: 1 addition & 1 deletion HalideIR
Submodule HalideIR updated from 30bf0f to 642ae5
5 changes: 3 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ all: lib/libtvm.a lib/libtvm.so

LIB_HALIDE_IR = HalideIR/lib/libHalideIR.a

SRC = $(wildcard src/*.cc src/*/*.cc)
SRC = $(wildcard src/*.cc src/*/*.cc src/*/*/*.cc)
ALL_OBJ = $(patsubst src/%.cc, build/%.o, $(SRC))
ALL_DEP = $(ALL_OBJ) $(LIB_HALIDE_IR)

Expand All @@ -39,7 +39,7 @@ endif

ifeq ($(USE_CUDA), 1)
CFLAGS += -DTVM_CUDA_RUNTIME=1
LDFLAGS += -lcuda -lcudart
LDFLAGS += -lcuda -lcudart -lnvrtc
else
CFLAGS += -DTVM_CUDA_RUNTIME=0
endif
Expand Down Expand Up @@ -92,3 +92,4 @@ clean:

-include build/*.d
-include build/*/*.d
-include build/*/*/*.d
16 changes: 0 additions & 16 deletions include/tvm/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,9 @@
#include <string>
#include <memory>
#include <functional>
#include <typeinfo>
#include <type_traits>

namespace tvm {

/*!
*\brief whether to use CUDA runtime
*/
#ifndef TVM_CUDA_RUNTIME
#define TVM_CUDA_RUNTIME 1
#endif

/*!
*\brief whether to use opencl runtime
*/
#ifndef TVM_OPENCL_RUNTIME
#define TVM_OPENCL_RUNTIME 0
#endif

using ::tvm::Node;
using ::tvm::NodeRef;
using ::tvm::AttrVisitor;
Expand Down
36 changes: 34 additions & 2 deletions include/tvm/codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,18 @@
#include <string>
#include "./base.h"
#include "./expr.h"
#include "./module.h"
#include "./lowered_func.h"
#include "./runtime/packed_func.h"


namespace tvm {
/*! \brief namespace for lowlevel IR pass and codegen */
namespace codegen {
// use packed function from runtime.
using runtime::PackedFunc;
using runtime::TVMArgs;
using runtime::TVMRetValue;

/*!
* \brief Make an user callable API LoweredFunc.
*
Expand Down Expand Up @@ -64,8 +69,35 @@ Array<Var> UndefinedVars(const LoweredFunc& f);
*/
Array<LoweredFunc> SplitHostDevice(LoweredFunc func);

/*!
* \brief Build a stack VM function.
* \param func The LoweredFunc to be build
* \param device_funcs The additional device functions
* \return A packed function representing the func.
*/
PackedFunc BuildStackVM(
LoweredFunc func,
const std::unordered_map<LoweredFunc, PackedFunc>& device_funcs);

/*!
* \brief Build a CUDA function with NVRTC
*
* \param fsplits The LoweredFuncs to be build (after SplitHostDevice)
* The first element is the host function, followed by device functions.
* \param host_mode The host side compilation mode:
* - "stackvm": use stack vm to interpret host side code.
*/
PackedFunc BuildNVRTC(Array<LoweredFunc> fsplits, std::string host_mode);

runtime::PackedFunc BuildStackVM(LoweredFunc func);
/*!
* \brief Build a OpenCL function.
*
* \param fsplits The LoweredFuncs to be build (after SplitHostDevice)
* The first element is the host function, followed by device functions.
* \param host_mode The host side compilation mode:
* - "stackvm": use stack vm to interpret host side code.
*/
PackedFunc BuildOpenCL(Array<LoweredFunc> fsplits, std::string host_mode);

} // namespace codegen
} // namespace tvm
Expand Down
15 changes: 14 additions & 1 deletion include/tvm/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
#include <string>
#include <algorithm>
#include "./base.h"
#include "./runtime/packed_func.h"
#include "./runtime/c_runtime_api.h"

namespace tvm {

Expand All @@ -33,6 +33,19 @@ using Halide::Internal::Variable;

using Halide::Internal::make_const;


inline Type TVMType2Type(TVMType t) {
return Type(static_cast<halide_type_code_t>(t.code), t.bits, t.lanes);
}

inline TVMType Type2TVMType(Type t) {
TVMType ret;
ret.code = static_cast<uint8_t>(t.code());
ret.bits = static_cast<uint8_t>(t.bits());
ret.lanes = static_cast<uint16_t>(t.lanes());
return ret;
}

/*! \brief a named variable in TVM */
class Var : public Halide::VarExpr {
public:
Expand Down
23 changes: 16 additions & 7 deletions include/tvm/module.h → include/tvm/lowered_func.h
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
/*!
* Copyright (c) 2016 by Contributors
* \file module.h
* \brief Low level IR module,
* Contains lowered function information.
* Copyright (c) 2017 by Contributors
* \file lowered_func.h
* \brief Information about a lowered TVM function.
* This data structure is final step toward codegen.
*/
#ifndef TVM_MODULE_H_
#define TVM_MODULE_H_
#ifndef TVM_LOWERED_FUNC_H_
#define TVM_LOWERED_FUNC_H_

#include <tvm/container.h>
#include <ir/FunctionBase.h>
Expand Down Expand Up @@ -102,4 +102,13 @@ inline const LoweredFuncNode* LoweredFunc::operator->() const {

} // namespace tvm

#endif // TVM_MODULE_H_
namespace std {
template <>
struct hash<::tvm::LoweredFunc> {
std::size_t operator()(const ::tvm::LoweredFunc& k) const {
return k.hash();
}
};
}

#endif // TVM_LOWERED_FUNC_H_
15 changes: 2 additions & 13 deletions include/tvm/packed_func_ext.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "./base.h"
#include "./expr.h"
#include "./runtime/packed_func.h"

namespace tvm {
using runtime::TVMArgs;
Expand Down Expand Up @@ -162,19 +163,7 @@ inline void TVMArgsSetter::operator()(size_t i, NodeRef& other) const { // NOLI
type_codes_[i] = kNodeHandle;
}

// Type related stuffs
inline Type TVMType2Type(TVMType t) {
return Type(static_cast<halide_type_code_t>(t.code), t.bits, t.lanes);
}

inline TVMType Type2TVMType(Type t) {
TVMType ret;
ret.code = static_cast<uint8_t>(t.code());
ret.bits = static_cast<uint8_t>(t.bits());
ret.lanes = static_cast<uint16_t>(t.lanes());
return ret;
}

// type related stuffs
inline TVMRetValue& TVMRetValue::operator=(const Halide::Type& t) {
return this->operator=(Type2TVMType(t));
}
Expand Down
23 changes: 23 additions & 0 deletions include/tvm/runtime/config.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*!
* Copyright (c) 2017 by Contributors
* \file config.h
* \brief Runtime library related configurations.
*/
#ifndef TVM_RUNTIME_CONFIG_H_
#define TVM_RUNTIME_CONFIG_H_

/*!
*\brief whether to use CUDA runtime
*/
#ifndef TVM_CUDA_RUNTIME
#define TVM_CUDA_RUNTIME 1
#endif

/*!
*\brief whether to use opencl runtime
*/
#ifndef TVM_OPENCL_RUNTIME
#define TVM_OPENCL_RUNTIME 0
#endif

#endif // TVM_RUNTIME_CONFIG_H_
38 changes: 36 additions & 2 deletions include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,13 @@ inline const char* TypeCode2Str(int type_code);
*/
inline TVMType String2TVMType(std::string s);

/*!
* \brief convert a TVM type to string.
* \param t The type to be converted.
* \return The corresponding tvm type in string.
*/
inline std::string TVMType2String(TVMType t);

// macro to check type code.
#define TVM_CHECK_TYPE_CODE(CODE, T) \
CHECK_EQ(CODE, T) << " expected " \
Expand Down Expand Up @@ -258,6 +265,9 @@ class TVMArgValue : public TVMPODValue_ {
using TVMPODValue_::operator TVMArray*;
// conversion operator.
operator std::string() const {
if (type_code_ == kTVMType) {
return TVMType2String(operator TVMType());
}
TVM_CHECK_TYPE_CODE(type_code_, kStr);
return std::string(value_.v_str);
}
Expand Down Expand Up @@ -308,7 +318,6 @@ class TVMRetValue : public TVMPODValue_ {
*/
TVMRetValue(TVMRetValue&& other)
: TVMPODValue_(other.value_, other.type_code_) {
other.type_code_ = kNull;
}
/*! \brief destructor */
~TVMRetValue() {
Expand All @@ -328,6 +337,9 @@ class TVMRetValue : public TVMPODValue_ {
}
// conversion operators
operator std::string() const {
if (type_code_ == kTVMType) {
return TVMType2String(operator TVMType());
}
TVM_CHECK_TYPE_CODE(type_code_, kStr);
return *ptr<std::string>();
}
Expand Down Expand Up @@ -418,6 +430,13 @@ class TVMRetValue : public TVMPODValue_ {
*ret_type_code = type_code_;
type_code_ = kNull;
}
/*! \return The value field, if the data is POD */
const TVMValue& value() const {
CHECK(type_code_ != kNodeHandle &&
type_code_ != kFuncHandle &&
type_code_ != kStr) << "TVMRetValue.value can only be used for POD data";
return value_;
}
// NodeRef related extenstions: in tvm/packed_func_ext.h
inline TVMRetValue& operator=(const NodeRef& other);
inline TVMRetValue& operator=(const std::shared_ptr<Node>& other);
Expand Down Expand Up @@ -488,7 +507,7 @@ inline const char* TypeCode2Str(int type_code) {
case kInt: return "int";
case kFloat: return "float";
case kStr: return "str";
case kHandle: return "Handle";
case kHandle: return "handle";
case kNull: return "NULL";
case kNodeHandle: return "NodeHandle";
case kArrayHandle: return "ArrayHandle";
Expand All @@ -499,6 +518,21 @@ inline const char* TypeCode2Str(int type_code) {
}
}

inline std::ostream& operator<<(std::ostream& os, TVMType t) { // NOLINT(*)
os << TypeCode2Str(t.code)
<< static_cast<int>(t.bits);
if (t.lanes != 1) {
os << 'x' << static_cast<int>(t.lanes);
}
return os;
}

inline std::string TVMType2String(TVMType t) {
std::ostringstream os;
os << t;
return os.str();
}

inline TVMType String2TVMType(std::string s) {
TVMType t;
t.bits = 32; t.lanes = 1;
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from . import schedule

from . import ndarray as nd
from .ndarray import cpu, gpu, opencl, init_opencl
from .ndarray import cpu, gpu, opencl, init_opencl, cl

from ._base import TVMError
from .api import *
4 changes: 2 additions & 2 deletions python/tvm/_ctypes/_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ def _make_tvm_args(args, temp_args):
values[i].v_float64 = arg
type_codes[i] = TypeCode.FLOAT
elif isinstance(arg, TVMType):
values[i].v_type = arg
type_codes[i] = TypeCode.TVM_TYPE
values[i].v_str = c_str(str(arg))
type_codes[i] = TypeCode.STR
elif isinstance(arg, string_types):
values[i].v_str = c_str(arg)
type_codes[i] = TypeCode.STR
Expand Down
5 changes: 1 addition & 4 deletions python/tvm/_ctypes/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,7 @@ class TVMValue(ctypes.Union):
_fields_ = [("v_int64", ctypes.c_int64),
("v_float64", ctypes.c_double),
("v_handle", ctypes.c_void_p),
("v_str", ctypes.c_char_p),
("v_type", TVMType)]
("v_str", ctypes.c_char_p)]


TVMPackedCFunc = ctypes.CFUNCTYPE(
Expand Down Expand Up @@ -117,7 +116,6 @@ def _return_handle(x):
TypeCode.FLOAT: lambda x: x.v_float64,
TypeCode.HANDLE: _return_handle,
TypeCode.NULL: lambda x: None,
TypeCode.TVM_TYPE: lambda x: x.v_type,
TypeCode.STR: lambda x: py_str(x.v_str)
}

Expand All @@ -127,6 +125,5 @@ def _return_handle(x):
TypeCode.FLOAT: lambda x: x.v_float64,
TypeCode.HANDLE: _return_handle,
TypeCode.NULL: lambda x: None,
TypeCode.TVM_TYPE: lambda x: x.v_type,
TypeCode.STR: lambda x: py_str(x.v_str)
}
6 changes: 3 additions & 3 deletions python/tvm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
from . import expr as _expr
from . import collections as _collections

int32 = TVMType("int32")
float32 = TVMType("float32")
handle = TVMType("handle")
int32 = "int32"
float32 = "float32"
handle = "handle"

def const(value, dtype=None):
"""construct a constant"""
Expand Down
6 changes: 6 additions & 0 deletions python/tvm/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@
class Array(NodeBase):
"""Array container of TVM"""
def __getitem__(self, i):
if isinstance(i, slice):
start = i.start if i.start is not None else 0
stop = i.stop if i.stop is not None else len(self)
step = i.step if i.step is not None else 1
return [self[idx] for idx in range(start, stop, step)]

if i >= len(self):
raise IndexError("array index out ot range")
return _api_internal._ArrayGetItem(self, i)
Expand Down
Loading

0 comments on commit 891630e

Please sign in to comment.