Commit

[RUNTIME] Enable ext_dev type for quick plugin of device (#542)
* [RUNTIME] Enable ext_dev type for quick plugin of device

* [TEST] Update testcase to cover all computation
tqchen authored Oct 12, 2017
1 parent 581509a commit acd48e9
Showing 11 changed files with 71 additions and 15 deletions.
5 changes: 5 additions & 0 deletions apps/extension/src/tvm_ext.cc
@@ -60,4 +60,9 @@ TVM_REGISTER_GLOBAL("tvm_ext.sym_add")
    Var b = args[1];
    *rv = a + b;
  });

TVM_REGISTER_GLOBAL("device_api.ext_dev")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = (*tvm::runtime::Registry::Get("device_api.cpu"))();
});
} // namespace tvm_ext
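
With the registration above, loading the extension shared library is enough to make ext_dev usable from Python; every device API call for ext_dev is forwarded to the CPU implementation. A minimal sketch, assuming the tvm_ext module from this app has been built and is importable:

import tvm
import tvm_ext          # importing the extension loads the library and runs the registration above
import numpy as np

ctx = tvm.ext_dev(0)    # device_type 12, served by the CPU device API
a = tvm.nd.array(np.arange(4, dtype='float32'), ctx)
print(a.asnumpy())      # the data lives in ordinary host memory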
20 changes: 20 additions & 0 deletions apps/extension/tests/test_ext.py
@@ -1,12 +1,31 @@
import tvm_ext
import tvm
import numpy as np

def test_bind_add():
    def add(a, b):
        return a + b
    f = tvm_ext.bind_add(add, 1)
    assert f(2) == 3

def test_ext_dev():
    n = 10
    A = tvm.placeholder((n,), name='A')
    B = tvm.compute((n,), lambda *i: A(*i) + 1.0, name='B')
    s = tvm.create_schedule(B.op)
    def check_llvm():
        if not tvm.module.enabled("llvm"):
            return
        f = tvm.build(s, [A, B], "ext_dev", "llvm")
        ctx = tvm.ext_dev(0)
        # launch the kernel.
        a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
        b = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx)
        f(a, b)
        np.testing.assert_allclose(b.asnumpy(), a.asnumpy() + 1)
    check_llvm()


def test_sym_add():
    a = tvm.var('a')
    b = tvm.var('b')
@@ -26,6 +45,7 @@ def ivec_cb(v2):
    tvm.convert(ivec_cb)(ivec)

if __name__ == "__main__":
    test_ext_dev()
    test_ext_vec()
    test_bind_add()
    test_sym_add()
3 changes: 3 additions & 0 deletions include/tvm/runtime/c_runtime_api.h
@@ -55,6 +55,9 @@ typedef int64_t tvm_index_t;

/*! \brief Extension device types in TVM */
typedef enum {
  // Extension DRAM type, used to quickly test an extension device.
  // The device API used can differ depending on which driver is registered.
  kExtDev = 12
  // Add extra TVM device types that are not in DLPack here.
} TVMDeviceExtType;

2 changes: 1 addition & 1 deletion python/tvm/__init__.py
@@ -17,7 +17,7 @@
from . import target

from . import ndarray as nd
-from .ndarray import context, cpu, gpu, opencl, cl, metal, mtl, vpi, rocm
+from .ndarray import context, cpu, gpu, opencl, cl, metal, mtl, vpi, rocm, ext_dev

from ._ffi.runtime_ctypes import TypeCode
from ._ffi.function import Function
6 changes: 4 additions & 2 deletions python/tvm/_ffi/runtime_ctypes.py
@@ -96,7 +96,8 @@ class TVMContext(ctypes.Structure):
        4 : 'opencl',
        8 : 'metal',
        9 : 'vpi',
-        10: 'rocm'
+        10: 'rocm',
+        12: 'ext_dev',
    }
    STR2MASK = {
        'cpu': 1,
@@ -106,7 +107,8 @@ class TVMContext(ctypes.Structure):
        'opencl': 4,
        'metal': 8,
        'vpi': 9,
-        'rocm': 10
+        'rocm': 10,
+        'ext_dev': 12,
    }
    def __init__(self, device_type, device_id):
        super(TVMContext, self).__init__()
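
These two tables are what make the string and integer forms of a device interchangeable. A minimal sketch of the round trip, written against the Python API as of this commit:

import tvm

ctx = tvm.context("ext_dev", 0)   # STR2MASK maps 'ext_dev' -> 12
assert ctx.device_type == 12
print(ctx)                        # MASK2STR maps 12 back to 'ext_dev'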
2 changes: 1 addition & 1 deletion python/tvm/build_module.py
@@ -345,7 +345,7 @@ def build(sch,
    else:
        raise ValueError("unknown function type %d" % func.func_type)

-    if not target.startswith("llvm") and target != "stackvm" and not fdevice:
+    if not target.startswith("llvm") and target not in ("stackvm", "ext_dev") and not fdevice:
        warnings.warn(
            "Specified target %s, but cannot find device code, did you do bind?" % target)

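
With ext_dev added to the exemption above, building for the extension device with host-generated code no longer triggers the warning. A minimal sketch, following the same pattern as the new test and assuming LLVM is enabled:

import tvm

n = 16
A = tvm.placeholder((n,), name='A')
B = tvm.compute((n,), lambda *i: A(*i) + 1.0, name='B')
s = tvm.create_schedule(B.op)
f = tvm.build(s, [A, B], "ext_dev", "llvm")   # no "cannot find device code" warning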
4 changes: 4 additions & 0 deletions python/tvm/contrib/rpc.py
@@ -247,6 +247,10 @@ def metal(self, dev_id=0):
"""Construct remote Metal device."""
return self.context(8, dev_id)

def ext_dev(self, dev_id=0):
"""Construct remote extension device."""
return self.context(12, dev_id)

def upload(self, data, target=None):
"""Upload file to remote runtime temp folder
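
A sketch of the remote counterpart, assuming an RPC server is reachable at the (hypothetical) address below and that the remote runtime has device_api.ext_dev registered:

from tvm.contrib import rpc

remote = rpc.connect("127.0.0.1", 9090)   # hypothetical server address
ctx = remote.ext_dev(0)                   # remote context with device_type 12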
21 changes: 21 additions & 0 deletions python/tvm/ndarray.py
@@ -120,6 +120,27 @@ def vpi(dev_id=0):
"""
return TVMContext(9, dev_id)

def ext_dev(dev_id=0):
"""Construct a extension device
Parameters
----------
dev_id : int, optional
The integer device id
Returns
-------
ctx : TVMContext
The created context
Note
----
This API is reserved for quick testing of new
device by plugin device API as ext_dev.
"""
return TVMContext(12, dev_id)


cl = opencl
mtl = metal

1 change: 1 addition & 0 deletions src/runtime/c_runtime_api.cc
@@ -31,6 +31,7 @@ inline std::string DeviceName(int type) {
    case kMetal: return "metal";
    case kVPI: return "vpi";
    case kROCM: return "rocm";
    case kExtDev: return "ext_dev";
    default: LOG(FATAL) << "unknown type =" << type; return "Unknown";
  }
}
2 changes: 1 addition & 1 deletion src/runtime/rocm/rocm_device_api.cc
@@ -134,7 +134,7 @@ class ROCMDeviceAPI final : public DeviceAPI {
typedef dmlc::ThreadLocalStore<ROCMThreadEntry> ROCMThreadStore;

ROCMThreadEntry::ROCMThreadEntry()
-  : pool(kGPU, ROCMDeviceAPI::Global()) {
+  : pool(kROCM, ROCMDeviceAPI::Global()) {
}

ROCMThreadEntry* ROCMThreadEntry::ThreadLocal() {
20 changes: 10 additions & 10 deletions tests/python/unittest/test_codegen_device.py
@@ -7,7 +7,7 @@ def test_add_pipeline():
    A = tvm.placeholder((n,), name='A')
    B = tvm.placeholder((n,), name='B')
    C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
-    D = tvm.compute(A.shape, lambda *i: C(*i) + 1, name='C')
+    D = tvm.compute(A.shape, lambda *i: C(*i) + 1, name='D')
    s = tvm.create_schedule(D.op)

    # GPU schedules have to split by gridIdx and threadIdx
@@ -26,11 +26,11 @@ def test_add_pipeline():
    stmt = tvm.schedule.ScheduleOps(s, bounds)
    Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
    Bb = tvm.decl_buffer(B.shape, B.dtype, name='B')
-    Cb = tvm.decl_buffer(C.shape, C.dtype, name='C')
+    Db = tvm.decl_buffer(D.shape, D.dtype, name='D')
    stmt = tvm.ir_pass.LoopPartition(stmt)
-    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb}, 64)
+    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, D:Db}, 64)
    stmt = tvm.ir_pass.Simplify(stmt)
-    fapi = tvm.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Cb], 0, True)
+    fapi = tvm.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Db], 0, True)
    fsplits = [x for x in tvm.ir_pass.SplitHostDevice(fapi)]
    fsplits[0] = tvm.ir_pass.LowerTVMBuiltin(fsplits[0])
@@ -49,10 +49,10 @@ def check_target(device, host="stackvm"):
        n = 1027
        a = tvm.nd.array(np.random.uniform(size=n).astype(Ab.dtype), ctx)
        b = tvm.nd.array(np.random.uniform(size=n).astype(Bb.dtype), ctx)
-        c = tvm.nd.array(np.zeros(n, dtype=Cb.dtype), ctx)
-        f(a, b, c)
+        d = tvm.nd.array(np.zeros(n, dtype=Db.dtype), ctx)
+        f(a, b, d)
        np.testing.assert_allclose(
-            c.asnumpy(), a.asnumpy() + b.asnumpy())
+            d.asnumpy(), a.asnumpy() + b.asnumpy() + 1)

    def check_module_save(device, host="stackvm"):
        if not tvm.module.enabled(host):
@@ -73,10 +73,10 @@ def check_module_save(device, host="stackvm"):
        n = 1027
        a = tvm.nd.array(np.random.uniform(size=n).astype(Ab.dtype), ctx)
        b = tvm.nd.array(np.random.uniform(size=n).astype(Bb.dtype), ctx)
-        c = tvm.nd.array(np.zeros(n, dtype=Cb.dtype), ctx)
-        f(a, b, c)
+        d = tvm.nd.array(np.zeros(n, dtype=Db.dtype), ctx)
+        f(a, b, d)
        np.testing.assert_allclose(
-            c.asnumpy(), a.asnumpy() + b.asnumpy())
+            d.asnumpy(), a.asnumpy() + b.asnumpy() + 1)

check_target("cuda", host="stackvm")
check_target("cuda", host="llvm")
