Skip to content

Commit ba28c58

Browse files
committed
[ADDON] Allow piggy back nvcc compiler and code
1 parent 8837798 commit ba28c58

File tree

10 files changed

+161
-19
lines changed

10 files changed

+161
-19
lines changed

include/tvm/runtime/c_runtime_api.h

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,9 @@ typedef enum {
5151
kArrayHandle = 5U,
5252
kTVMType = 6U,
5353
kNodeHandle = 7U,
54-
kStr = 8U,
55-
kFuncHandle = 9U
54+
kFuncHandle = 8U,
55+
kStr = 9U,
56+
kBytes = 10U
5657
} TVMTypeCode;
5758

5859
/*!
@@ -86,6 +87,15 @@ typedef union {
8687
TVMType v_type;
8788
} TVMValue;
8889

90+
/*!
91+
* \brief Byte array type used to pass in byte array
92+
* when kBytes is used as the type code.
93+
*/
94+
typedef struct {
95+
const char* data;
96+
size_t size;
97+
} TVMByteArray;
98+
8999
/*!
90100
* \brief The device type
91101
*/

include/tvm/runtime/packed_func.h

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,12 @@ class PackedFunc {
111111
* \return reference to the registered function.
112112
*/
113113
static const PackedFunc& GetGlobal(const std::string& name);
114+
/*!
115+
* \brief Whether the global function exists
116+
* \param name The name of the function.
117+
* \return Whether the global function exists.
118+
*/
119+
static bool GlobalExist(const std::string& name);
114120
/*!
115121
* \brief Get the names of currently registered global function.
116122
*/
@@ -267,9 +273,13 @@ class TVMArgValue : public TVMPODValue_ {
267273
operator std::string() const {
268274
if (type_code_ == kTVMType) {
269275
return TVMType2String(operator TVMType());
276+
} else if (type_code_ == kBytes) {
277+
TVMByteArray* arr = static_cast<TVMByteArray*>(value_.v_handle);
278+
return std::string(arr->data, arr->size);
279+
} else {
280+
TVM_CHECK_TYPE_CODE(type_code_, kStr);
281+
return std::string(value_.v_str);
270282
}
271-
TVM_CHECK_TYPE_CODE(type_code_, kStr);
272-
return std::string(value_.v_str);
273283
}
274284
operator TVMType() const {
275285
if (type_code_ == kStr) {
@@ -452,7 +462,8 @@ class TVMRetValue : public TVMPODValue_ {
452462
template<typename T>
453463
void Assign(const T& other) {
454464
switch (other.type_code()) {
455-
case kStr: {
465+
case kStr:
466+
case kBytes: {
456467
SwitchToClass<std::string>(kStr, other);
457468
break;
458469
}

python/tvm/_ctypes/_function.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from .._base import _LIB, check_call
1111
from .._base import c_str, py_str, string_types
12-
from ._types import TVMValue, TypeCode, TVMType
12+
from ._types import TVMValue, TypeCode, TVMType, TVMByteArray
1313
from ._types import TVMPackedCFunc, TVMCFuncFinalizer
1414
from ._types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH
1515
from ._node import NodeBase, SliceBase, convert_to_node
@@ -92,6 +92,15 @@ def _make_tvm_args(args, temp_args):
9292
elif isinstance(arg, TVMType):
9393
values[i].v_str = c_str(str(arg))
9494
type_codes[i] = TypeCode.STR
95+
elif isinstance(arg, bytearray):
96+
arr = TVMByteArray()
97+
arr.data = ctypes.cast(
98+
(ctypes.c_byte * len(arg)).from_buffer(arg),
99+
ctypes.POINTER(ctypes.c_byte))
100+
arr.size = len(arg)
101+
values[i].v_handle = ctypes.c_void_p(ctypes.addressof(arr))
102+
temp_args.append(arr)
103+
type_codes[i] = TypeCode.BYTES
95104
elif isinstance(arg, string_types):
96105
values[i].v_str = c_str(arg)
97106
type_codes[i] = TypeCode.STR

python/tvm/_ctypes/_types.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@ class TypeCode(object):
1818
ARRAY_HANDLE = 5
1919
TVM_TYPE = 6
2020
NODE_HANDLE = 7
21-
STR = 8
22-
FUNC_HANDLE = 9
21+
FUNC_HANDLE = 8
22+
STR = 9
23+
BYTES = 10
2324

2425
def _api_type(code):
2526
"""create a type accepted by API"""
@@ -88,6 +89,11 @@ class TVMValue(ctypes.Union):
8889
("v_handle", ctypes.c_void_p),
8990
("v_str", ctypes.c_char_p)]
9091

92+
class TVMByteArray(ctypes.Structure):
93+
"""Byte array structure: a data pointer plus its size"""
94+
_fields_ = [("data", ctypes.POINTER(ctypes.c_byte)),
95+
("size", ctypes.c_size_t)]
96+
9197

9298
TVMPackedCFunc = ctypes.CFUNCTYPE(
9399
None,
@@ -110,20 +116,34 @@ def _return_handle(x):
110116
handle = ctypes.c_void_p(handle)
111117
return handle
112118

119+
def _return_bytes(x):
120+
"""return a bytearray copied from the TVMByteArray the handle points to"""
121+
handle = x.v_handle
122+
if not isinstance(handle, ctypes.c_void_p):
123+
handle = ctypes.c_void_p(handle)
124+
arr = ctypes.cast(handle, ctypes.POINTER(TVMByteArray))[0]
125+
size = arr.size
126+
res = bytearray(size)
127+
rptr = (ctypes.c_byte * size).from_buffer(res)
128+
if not ctypes.memmove(rptr, arr.data, size):
129+
raise RuntimeError('memmove failed')
130+
return res
131+
113132

114133
RETURN_SWITCH = {
115134
TypeCode.INT: lambda x: x.v_int64,
116135
TypeCode.FLOAT: lambda x: x.v_float64,
117136
TypeCode.HANDLE: _return_handle,
118137
TypeCode.NULL: lambda x: None,
119-
TypeCode.STR: lambda x: py_str(x.v_str)
138+
TypeCode.STR: lambda x: py_str(x.v_str),
139+
TypeCode.BYTES: _return_bytes
120140
}
121141

122-
123142
C_TO_PY_ARG_SWITCH = {
124143
TypeCode.INT: lambda x: x.v_int64,
125144
TypeCode.FLOAT: lambda x: x.v_float64,
126145
TypeCode.HANDLE: _return_handle,
127146
TypeCode.NULL: lambda x: None,
128-
TypeCode.STR: lambda x: py_str(x.v_str)
147+
TypeCode.STR: lambda x: py_str(x.v_str),
148+
TypeCode.BYTES: _return_bytes
129149
}

python/tvm/addon/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Addon utilities to python"""

python/tvm/addon/nvcc_compiler.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
"""Util to compile with NVCC"""
2+
import os
3+
import sys
4+
import tempfile
5+
import subprocess
6+
7+
def compile_source(code, target="cubin"):
    """Compile cuda code with NVCC from env.

    Parameters
    ----------
    code : str
        The cuda code.

    target : str
        The target format, one of "cubin", "ptx" or "fatbin".

    Return
    ------
    cubin : bytearray
        The content of the compiled output file, or None when nvcc
        could not be started or exited with a non-zero status (the
        diagnostic output is written to stderr in that case).

    Raises
    ------
    ValueError
        If target is not one of the supported formats.
    """
    # Validate before touching the filesystem so that a bad target
    # does not leave an empty temporary directory behind.
    if target not in ["cubin", "ptx", "fatbin"]:
        raise ValueError("target must be in cubin, ptx, fatbin")

    temp_dir = tempfile.mkdtemp()
    path_code = os.path.join(temp_dir, "my_kernel.cu")
    path_target = os.path.join(temp_dir, "my_kernel.%s" % target)
    try:
        with open(path_code, "w") as out_file:
            out_file.write(code)

        # Pass an argument list and skip the shell: temporary paths that
        # contain spaces or shell metacharacters reach nvcc unchanged.
        cmd = ["nvcc", "--%s" % target, "-O3", "-o", path_target, path_code]
        try:
            proc = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
            (out, _) = proc.communicate()
            failed = proc.returncode != 0
        except OSError as err:
            # nvcc missing or not executable: report it the same way as
            # a failed compilation so callers still receive None.
            out = "%s\n" % err
            failed = True

        if failed:
            # communicate() returns bytes on Python 3, but sys.stderr is
            # a text stream there.
            if not isinstance(out, str):
                out = out.decode("utf-8", "replace")
            sys.stderr.write("Compilation error:\n")
            sys.stderr.write(out)
            sys.stderr.flush()
            return None

        with open(path_target, "rb") as in_file:
            return bytearray(in_file.read())
    finally:
        # Always remove the temporary files, including on exceptions.
        for path in (path_code, path_target):
            if os.path.exists(path):
                os.remove(path)
        os.rmdir(temp_dir)

src/codegen/codegen_cuda.cc

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,19 @@ MakeNVRTC(Array<LoweredFunc> funcs) {
5050
os << CodeGenCUDA().Compile(f, output_ssa);
5151
os << '\n';
5252
}
53-
std::string ptx = runtime::NVRTCCompile(os.str());
53+
std::string code = os.str();
54+
55+
if (PackedFunc::GlobalExist("tvm_callback_cuda_postproc")) {
56+
const auto& f = PackedFunc::GetGlobal("tvm_callback_cuda_postproc");
57+
code = f(code).operator std::string();
58+
}
59+
std::string ptx;
60+
if (PackedFunc::GlobalExist("tvm_callback_cuda_compile")) {
61+
const auto& f = PackedFunc::GetGlobal("tvm_callback_cuda_compile");
62+
ptx = f(code).operator std::string();
63+
} else {
64+
// Compile the post-processed source so that tvm_callback_cuda_postproc also takes effect on the NVRTC path.
ptx = runtime::NVRTCCompile(code);
65+
}
5466
std::unordered_map<LoweredFunc, PackedFunc> ret;
5567

5668
runtime::CUDAModule m = runtime::CUDAModule::Create(ptx);

src/runtime/packed_func_registry.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,12 @@ const PackedFunc& PackedFunc::GetGlobal(const std::string& name) {
4646
return *(it->second);
4747
}
4848

49+
bool PackedFunc::GlobalExist(const std::string& name) {
50+
PackedFuncRegistry* r = PackedFuncRegistry::Global();
51+
auto it = r->fmap.find(name);
52+
return it != r->fmap.end();
53+
}
54+
4955
std::vector<std::string> PackedFunc::ListGlobalNames() {
5056
PackedFuncRegistry* r = PackedFuncRegistry::Global();
5157
std::vector<std::string> keys;

tests/python/integration/test_gemm.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,18 @@
11
import tvm
2+
from tvm.addon import nvcc_compiler
23
import numpy as np
34

5+
@tvm.register_func
6+
def tvm_callback_cuda_compile(code):
7+
ptx = nvcc_compiler.compile_source(code, target="ptx")
8+
print(ptx.decode("utf-8"))
9+
return ptx
10+
11+
@tvm.register_func
12+
def tvm_callback_cuda_postproc(code):
13+
print(code)
14+
return code
15+
416
def test_gemm():
517
# graph
618
nn = 1024
@@ -23,7 +35,6 @@ def test_gemm():
2335
s = tvm.Schedule(C.op)
2436
xtile, ytile = 32, 32
2537
s[AA].set_scope("shared")
26-
#s[CC].set_scope("global")
2738
s[BB].set_scope("shared")
2839

2940
scale = 8
@@ -60,8 +71,6 @@ def check_device(target):
6071
codes = []
6172
f = tvm.build(s, [A, B, C], target, record_codes=codes,
6273
max_auto_unroll_step=max_auto_unroll_step)
63-
for c in codes[1:]:
64-
print(c)
6574
if target == "cuda":
6675
ctx = tvm.gpu(0)
6776
else:
@@ -77,13 +86,14 @@ def check_device(target):
7786
a = tvm.nd.array(a_np, ctx)
7887
b = tvm.nd.array(b_np, ctx)
7988
c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), ctx)
80-
f(a, b, c)
89+
for i in range(4):
90+
f(a, b, c)
8191
np.testing.assert_allclose(
8292
c.asnumpy(), np.dot(a_np, b_np.T), rtol=1e-5)
8393

84-
tvm.init_opencl()
8594
check_device("cuda")
86-
check_device("opencl")
95+
#tvm.init_opencl()
96+
#check_device("opencl")
8797

8898
if __name__ == "__main__":
8999
test_gemm()

tests/python/unittest/test_runtime_packed_func.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,17 @@ def myfunc(*args):
3535
assert isinstance(f, tvm.nd.Function)
3636
f(*targs)
3737

38+
def test_byte_array():
39+
s = "hello"
40+
a = bytearray(s, encoding="ascii")
41+
42+
def myfunc(ss):
43+
assert ss == a
44+
f = tvm.convert(myfunc)
45+
f(a)
3846

3947
if __name__ == "__main__":
40-
test_function()
4148
test_convert()
4249
test_get_global()
4350
test_return_func()
51+
test_byte_array()

0 commit comments

Comments
 (0)