Skip to content

Commit 891630e

Browse files
authored
[CODEGEN/EXEC] CUDA, NVRTC pipeline complete (#27)
* [CODEGEN] CUDA/OPENCL pipeline complete * Hide TVMType by str in frontend
1 parent ff06917 commit 891630e

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

50 files changed

+1751
-520
lines changed

.travis.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ language: cpp
44

55
os:
66
- linux
7-
# - osx
7+
- osx
88

99
env:
1010
# code analysis

HalideIR

Submodule HalideIR updated from 30bf0f0 to 642ae50

Makefile

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ all: lib/libtvm.a lib/libtvm.so
1515

1616
LIB_HALIDE_IR = HalideIR/lib/libHalideIR.a
1717

18-
SRC = $(wildcard src/*.cc src/*/*.cc)
18+
SRC = $(wildcard src/*.cc src/*/*.cc src/*/*/*.cc)
1919
ALL_OBJ = $(patsubst src/%.cc, build/%.o, $(SRC))
2020
ALL_DEP = $(ALL_OBJ) $(LIB_HALIDE_IR)
2121

@@ -39,7 +39,7 @@ endif
3939

4040
ifeq ($(USE_CUDA), 1)
4141
CFLAGS += -DTVM_CUDA_RUNTIME=1
42-
LDFLAGS += -lcuda -lcudart
42+
LDFLAGS += -lcuda -lcudart -lnvrtc
4343
else
4444
CFLAGS += -DTVM_CUDA_RUNTIME=0
4545
endif
@@ -92,3 +92,4 @@ clean:
9292

9393
-include build/*.d
9494
-include build/*/*.d
95+
-include build/*/*/*.d

include/tvm/base.h

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,25 +12,9 @@
1212
#include <string>
1313
#include <memory>
1414
#include <functional>
15-
#include <typeinfo>
16-
#include <type_traits>
1715

1816
namespace tvm {
1917

20-
/*!
21-
*\brief whether to use CUDA runtime
22-
*/
23-
#ifndef TVM_CUDA_RUNTIME
24-
#define TVM_CUDA_RUNTIME 1
25-
#endif
26-
27-
/*!
28-
*\brief whether to use opencl runtime
29-
*/
30-
#ifndef TVM_OPENCL_RUNTIME
31-
#define TVM_OPENCL_RUNTIME 0
32-
#endif
33-
3418
using ::tvm::Node;
3519
using ::tvm::NodeRef;
3620
using ::tvm::AttrVisitor;

include/tvm/codegen.h

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,18 @@
99
#include <string>
1010
#include "./base.h"
1111
#include "./expr.h"
12-
#include "./module.h"
12+
#include "./lowered_func.h"
1313
#include "./runtime/packed_func.h"
1414

1515

1616
namespace tvm {
1717
/*! \brief namespace for lowlevel IR pass and codegen */
1818
namespace codegen {
19+
// use packed function from runtime.
20+
using runtime::PackedFunc;
21+
using runtime::TVMArgs;
22+
using runtime::TVMRetValue;
23+
1924
/*!
2025
* \brief Make an user callable API LoweredFunc.
2126
*
@@ -64,8 +69,35 @@ Array<Var> UndefinedVars(const LoweredFunc& f);
6469
*/
6570
Array<LoweredFunc> SplitHostDevice(LoweredFunc func);
6671

72+
/*!
73+
* \brief Build a stack VM function.
74+
* \param func The LoweredFunc to be build
75+
* \param device_funcs The additional device functions
76+
* \return A packed function representing the func.
77+
*/
78+
PackedFunc BuildStackVM(
79+
LoweredFunc func,
80+
const std::unordered_map<LoweredFunc, PackedFunc>& device_funcs);
81+
82+
/*!
83+
* \brief Build a CUDA function with NVRTC
84+
*
85+
* \param fsplits The LoweredFuncs to be build (after SplitHostDevice)
86+
* The first element is the host function, followed by device functions.
87+
* \param host_mode The host side compilation mode:
88+
* - "stackvm": use stack vm to interpret host side code.
89+
*/
90+
PackedFunc BuildNVRTC(Array<LoweredFunc> fsplits, std::string host_mode);
6791

68-
runtime::PackedFunc BuildStackVM(LoweredFunc func);
92+
/*!
93+
* \brief Build a OpenCL function.
94+
*
95+
* \param fsplits The LoweredFuncs to be build (after SplitHostDevice)
96+
* The first element is the host function, followed by device functions.
97+
* \param host_mode The host side compilation mode:
98+
* - "stackvm": use stack vm to interpret host side code.
99+
*/
100+
PackedFunc BuildOpenCL(Array<LoweredFunc> fsplits, std::string host_mode);
69101

70102
} // namespace codegen
71103
} // namespace tvm

include/tvm/expr.h

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
#include <string>
1313
#include <algorithm>
1414
#include "./base.h"
15-
#include "./runtime/packed_func.h"
15+
#include "./runtime/c_runtime_api.h"
1616

1717
namespace tvm {
1818

@@ -33,6 +33,19 @@ using Halide::Internal::Variable;
3333

3434
using Halide::Internal::make_const;
3535

36+
37+
inline Type TVMType2Type(TVMType t) {
38+
return Type(static_cast<halide_type_code_t>(t.code), t.bits, t.lanes);
39+
}
40+
41+
inline TVMType Type2TVMType(Type t) {
42+
TVMType ret;
43+
ret.code = static_cast<uint8_t>(t.code());
44+
ret.bits = static_cast<uint8_t>(t.bits());
45+
ret.lanes = static_cast<uint16_t>(t.lanes());
46+
return ret;
47+
}
48+
3649
/*! \brief a named variable in TVM */
3750
class Var : public Halide::VarExpr {
3851
public:

include/tvm/module.h renamed to include/tvm/lowered_func.h

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
/*!
2-
* Copyright (c) 2016 by Contributors
3-
* \file module.h
4-
* \brief Low level IR module,
5-
* Contains lowered function information.
2+
* Copyright (c) 2017 by Contributors
3+
* \file lowered_func.h
4+
* \brief Information about a lowered TVM function.
5+
* This data structure is final step toward codegen.
66
*/
7-
#ifndef TVM_MODULE_H_
8-
#define TVM_MODULE_H_
7+
#ifndef TVM_LOWERED_FUNC_H_
8+
#define TVM_LOWERED_FUNC_H_
99

1010
#include <tvm/container.h>
1111
#include <ir/FunctionBase.h>
@@ -102,4 +102,13 @@ inline const LoweredFuncNode* LoweredFunc::operator->() const {
102102

103103
} // namespace tvm
104104

105-
#endif // TVM_MODULE_H_
105+
namespace std {
106+
template <>
107+
struct hash<::tvm::LoweredFunc> {
108+
std::size_t operator()(const ::tvm::LoweredFunc& k) const {
109+
return k.hash();
110+
}
111+
};
112+
}
113+
114+
#endif // TVM_LOWERED_FUNC_H_

include/tvm/packed_func_ext.h

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include "./base.h"
1616
#include "./expr.h"
17+
#include "./runtime/packed_func.h"
1718

1819
namespace tvm {
1920
using runtime::TVMArgs;
@@ -162,19 +163,7 @@ inline void TVMArgsSetter::operator()(size_t i, NodeRef& other) const { // NOLI
162163
type_codes_[i] = kNodeHandle;
163164
}
164165

165-
// Type related stuffs
166-
inline Type TVMType2Type(TVMType t) {
167-
return Type(static_cast<halide_type_code_t>(t.code), t.bits, t.lanes);
168-
}
169-
170-
inline TVMType Type2TVMType(Type t) {
171-
TVMType ret;
172-
ret.code = static_cast<uint8_t>(t.code());
173-
ret.bits = static_cast<uint8_t>(t.bits());
174-
ret.lanes = static_cast<uint16_t>(t.lanes());
175-
return ret;
176-
}
177-
166+
// type related stuffs
178167
inline TVMRetValue& TVMRetValue::operator=(const Halide::Type& t) {
179168
return this->operator=(Type2TVMType(t));
180169
}

include/tvm/runtime/config.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
/*!
2+
* Copyright (c) 2017 by Contributors
3+
* \file config.h
4+
* \brief Runtime library related configurations.
5+
*/
6+
#ifndef TVM_RUNTIME_CONFIG_H_
7+
#define TVM_RUNTIME_CONFIG_H_
8+
9+
/*!
10+
*\brief whether to use CUDA runtime
11+
*/
12+
#ifndef TVM_CUDA_RUNTIME
13+
#define TVM_CUDA_RUNTIME 1
14+
#endif
15+
16+
/*!
17+
*\brief whether to use opencl runtime
18+
*/
19+
#ifndef TVM_OPENCL_RUNTIME
20+
#define TVM_OPENCL_RUNTIME 0
21+
#endif
22+
23+
#endif // TVM_RUNTIME_CONFIG_H_

include/tvm/runtime/packed_func.h

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,13 @@ inline const char* TypeCode2Str(int type_code);
163163
*/
164164
inline TVMType String2TVMType(std::string s);
165165

166+
/*!
167+
* \brief convert a TVM type to string.
168+
* \param t The type to be converted.
169+
* \return The corresponding tvm type in string.
170+
*/
171+
inline std::string TVMType2String(TVMType t);
172+
166173
// macro to check type code.
167174
#define TVM_CHECK_TYPE_CODE(CODE, T) \
168175
CHECK_EQ(CODE, T) << " expected " \
@@ -258,6 +265,9 @@ class TVMArgValue : public TVMPODValue_ {
258265
using TVMPODValue_::operator TVMArray*;
259266
// conversion operator.
260267
operator std::string() const {
268+
if (type_code_ == kTVMType) {
269+
return TVMType2String(operator TVMType());
270+
}
261271
TVM_CHECK_TYPE_CODE(type_code_, kStr);
262272
return std::string(value_.v_str);
263273
}
@@ -308,7 +318,6 @@ class TVMRetValue : public TVMPODValue_ {
308318
*/
309319
TVMRetValue(TVMRetValue&& other)
310320
: TVMPODValue_(other.value_, other.type_code_) {
311-
other.type_code_ = kNull;
312321
}
313322
/*! \brief destructor */
314323
~TVMRetValue() {
@@ -328,6 +337,9 @@ class TVMRetValue : public TVMPODValue_ {
328337
}
329338
// conversion operators
330339
operator std::string() const {
340+
if (type_code_ == kTVMType) {
341+
return TVMType2String(operator TVMType());
342+
}
331343
TVM_CHECK_TYPE_CODE(type_code_, kStr);
332344
return *ptr<std::string>();
333345
}
@@ -418,6 +430,13 @@ class TVMRetValue : public TVMPODValue_ {
418430
*ret_type_code = type_code_;
419431
type_code_ = kNull;
420432
}
433+
/*! \return The value field, if the data is POD */
434+
const TVMValue& value() const {
435+
CHECK(type_code_ != kNodeHandle &&
436+
type_code_ != kFuncHandle &&
437+
type_code_ != kStr) << "TVMRetValue.value can only be used for POD data";
438+
return value_;
439+
}
421440
// NodeRef related extenstions: in tvm/packed_func_ext.h
422441
inline TVMRetValue& operator=(const NodeRef& other);
423442
inline TVMRetValue& operator=(const std::shared_ptr<Node>& other);
@@ -488,7 +507,7 @@ inline const char* TypeCode2Str(int type_code) {
488507
case kInt: return "int";
489508
case kFloat: return "float";
490509
case kStr: return "str";
491-
case kHandle: return "Handle";
510+
case kHandle: return "handle";
492511
case kNull: return "NULL";
493512
case kNodeHandle: return "NodeHandle";
494513
case kArrayHandle: return "ArrayHandle";
@@ -499,6 +518,21 @@ inline const char* TypeCode2Str(int type_code) {
499518
}
500519
}
501520

521+
inline std::ostream& operator<<(std::ostream& os, TVMType t) { // NOLINT(*)
522+
os << TypeCode2Str(t.code)
523+
<< static_cast<int>(t.bits);
524+
if (t.lanes != 1) {
525+
os << 'x' << static_cast<int>(t.lanes);
526+
}
527+
return os;
528+
}
529+
530+
inline std::string TVMType2String(TVMType t) {
531+
std::ostringstream os;
532+
os << t;
533+
return os.str();
534+
}
535+
502536
inline TVMType String2TVMType(std::string s) {
503537
TVMType t;
504538
t.bits = 32; t.lanes = 1;

0 commit comments

Comments
 (0)