Rename gpu to cuda, and bump dlpack to v0.5 (apache#8032)
YuchenJin authored May 13, 2021
1 parent ed283b8 commit 43c2ea7
Showing 68 changed files with 217 additions and 197 deletions.
2 changes: 1 addition & 1 deletion 3rdparty/dlpack
Submodule dlpack updated 1 file
+5 −10 include/dlpack/dlpack.h
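dlpack v0.5 renames the device-type enums that this commit tracks throughout: kDLGPU becomes kDLCUDA and kDLCPUPinned becomes kDLCUDAHost, while the integer values stay put. A minimal sketch (not part of the diff) checking that from the Python side:

import tvm

# Only the names moved; the device-type integers are unchanged.
assert tvm.cpu(0).device_type == 1   # kDLCPU
assert tvm.cuda(0).device_type == 2  # kDLCUDA, formerly kDLGPU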
10 changes: 5 additions & 5 deletions apps/topi_recipe/broadcast/test_broadcast_map.py
@@ -65,8 +65,8 @@ def test_broadcast_to(in_shape, out_shape):
data_npy = np.random.uniform(size=in_shape).astype(A.dtype)
out_npy = np.broadcast_to(data_npy, out_shape)

- data_nd = tvm.nd.array(data_npy, tvm.gpu())
- out_nd = tvm.nd.array(np.empty(out_shape).astype(B.dtype), tvm.gpu())
+ data_nd = tvm.nd.array(data_npy, tvm.cuda())
+ out_nd = tvm.nd.array(np.empty(out_shape).astype(B.dtype), tvm.cuda())
for _ in range(2):
fcuda(data_nd, out_nd)
tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy)
@@ -116,9 +116,9 @@ def test_broadcast_binary_op(lhs_shape, rhs_shape, typ="add"):
out_npy = np.maximum(lhs_npy, rhs_npy)
elif typ == "minimum":
out_npy = np.minimum(lhs_npy, rhs_npy)
- lhs_nd = tvm.nd.array(lhs_npy, tvm.gpu())
- rhs_nd = tvm.nd.array(rhs_npy, tvm.gpu())
- out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(B.dtype), tvm.gpu())
+ lhs_nd = tvm.nd.array(lhs_npy, tvm.cuda())
+ rhs_nd = tvm.nd.array(rhs_npy, tvm.cuda())
+ out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(B.dtype), tvm.cuda())
for _ in range(2):
fcuda(lhs_nd, rhs_nd, out_nd)
tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy)
4 changes: 2 additions & 2 deletions apps/topi_recipe/reduce/test_reduce_map.py
@@ -78,8 +78,8 @@ def test_reduce_map(in_shape, axis, keepdims, type="sum", test_id=0):
else:
raise NotImplementedError

- data_tvm = tvm.nd.array(in_npy, device=tvm.gpu())
- out_tvm = tvm.nd.empty(shape=out_npy.shape, device=tvm.gpu())
+ data_tvm = tvm.nd.array(in_npy, device=tvm.cuda())
+ out_tvm = tvm.nd.empty(shape=out_npy.shape, device=tvm.cuda())

for _ in range(2):
fcuda(data_tvm, out_tvm)
2 changes: 1 addition & 1 deletion apps/topi_recipe/rnn/lstm.py
@@ -171,7 +171,7 @@ def lstm():
def check_device(target):
num_step = n_num_step
flstm = tvm.build(s, [Xi2h, Wh2h, scan_h, scan_c], target)
- dev = tvm.gpu(0) if target == "cuda" else tvm.cl(0)
+ dev = tvm.cuda(0) if target == "cuda" else tvm.cl(0)
# launch the kernel.
scan_h_np = np.zeros((num_step, batch_size, num_hidden)).astype("float32")
scan_c_np = np.zeros((num_step, batch_size, num_hidden)).astype("float32")
10 changes: 5 additions & 5 deletions apps/topi_recipe/rnn/matexp.py
@@ -140,7 +140,7 @@ def check_device(target):
}
):
f = tvm.build(s, [s_scan, Whh], target)
- dev = tvm.gpu(0) if target == "cuda" else tvm.cl(0)
+ dev = tvm.cuda(0) if target == "cuda" else tvm.cl(0)
# launch the kernel.
res_np = np.zeros((n_num_step, n_batch_size, n_num_hidden)).astype("float32")
Whh_np = np.zeros((n_num_hidden, n_num_hidden)).astype("float32")
@@ -160,16 +160,16 @@ def check_device(target):
print("Time cost=%g" % tgap)
# correctness
if not SKIP_CHECK:
- res_gpu = res_a.asnumpy()
+ res_cuda = res_a.asnumpy()
res_cmp = np.ones_like(res_np).astype("float64")
Whh_np = Whh_np.astype("float64")
for t in range(1, n_num_step):
res_cmp[t][:] = np.dot(res_cmp[t - 1], Whh_np)
for i in range(n_num_step):
for j in range(n_num_hidden):
- if abs(res_cmp[i, 0, j] - res_gpu[i, 0, j]) > 1e-5:
- print("%d, %d: %g vs %g" % (i, j, res_cmp[i, 0, j], res_gpu[i, 0, j]))
- tvm.testing.assert_allclose(res_gpu, res_cmp, rtol=1e-3)
+ if abs(res_cmp[i, 0, j] - res_cuda[i, 0, j]) > 1e-5:
+ print("%d, %d: %g vs %g" % (i, j, res_cmp[i, 0, j], res_cuda[i, 0, j]))
+ tvm.testing.assert_allclose(res_cuda, res_cmp, rtol=1e-3)

check_device("cuda")

2 changes: 1 addition & 1 deletion docs/deploy/tensorrt.rst
@@ -124,7 +124,7 @@ have to be built.

.. code:: python
- dev = tvm.gpu(0)
+ dev = tvm.cuda(0)
loaded_lib = tvm.runtime.load_module('compiled.so')
gen_module = tvm.contrib.graph_executor.GraphModule(loaded_lib['default'](dev))
input_data = np.random.uniform(0, 1, input_shape).astype(dtype)
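The hunk is cut off here by the diff view. As a hedged sketch of the usual next steps with the standard GraphModule API (assuming the model's input is named `data`; these lines are not from the file):

gen_module.run(data=input_data)
out = gen_module.get_output(0).asnumpy()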
6 changes: 3 additions & 3 deletions docs/dev/index.rst
@@ -144,7 +144,7 @@ The main goal of TVM's runtime is to provide a minimal API for loading and execu
import tvm
# Example runtime execution program in python, with type annotated
mod: tvm.runtime.Module = tvm.runtime.load_module("compiled_artifact.so")
- arr: tvm.runtime.NDArray = tvm.nd.array([1, 2, 3], device=tvm.gpu(0))
+ arr: tvm.runtime.NDArray = tvm.nd.array([1, 2, 3], device=tvm.cuda(0))
fun: tvm.runtime.PackedFunc = mod["addone"]
fun(arr)
print(arr.asnumpy())
@@ -164,8 +164,8 @@ The above example only deals with a simple `addone` function. The code snippet b
import tvm
# Example runtime execution program in python, with types annotated
factory: tvm.runtime.Module = tvm.runtime.load_module("resnet18.so")
- # Create a stateful graph execution module for resnet18 on gpu(0)
- gmod: tvm.runtime.Module = factory["resnet18"](tvm.gpu(0))
+ # Create a stateful graph execution module for resnet18 on cuda(0)
+ gmod: tvm.runtime.Module = factory["resnet18"](tvm.cuda(0))
data: tvm.runtime.NDArray = get_input_data()
# set input
gmod["set_input"](0, data)
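The snippet is truncated by the diff view; a hedged sketch of the rest of the cycle in the same packed-func style, assuming the graph executor's standard `run` and `get_output` functions:

# execute the model, then read back the first output
gmod["run"]()
out: tvm.runtime.NDArray = gmod["get_output"](0)
print(out.asnumpy())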
20 changes: 10 additions & 10 deletions golang/src/device.go
@@ -29,10 +29,10 @@ import "C"

// KDLCPU is golang enum correspond to TVM device type kDLCPU.
var KDLCPU = int32(C.kDLCPU)
- // KDLGPU is golang enum correspond to TVM device type kDLGPU.
- var KDLGPU = int32(C.kDLGPU)
- // KDLCPUPinned is golang enum correspond to TVM device type kDLCPUPinned.
- var KDLCPUPinned = int32(C.kDLCPUPinned)
+ // kDLCUDA is golang enum correspond to TVM device type kDLCUDA.
+ var kDLCUDA = int32(C.kDLCUDA)
+ // kDLCUDAHost is golang enum correspond to TVM device type kDLCUDAHost.
+ var kDLCUDAHost = int32(C.kDLCUDAHost)
// KDLOpenCL is golang enum correspond to TVM device type kDLOpenCL.
var KDLOpenCL = int32(C.kDLOpenCL)
// KDLMetal is golang enum correspond to TVM device type kDLMetal.
@@ -61,14 +61,14 @@ func CPU(index int32) Device {
return Device{KDLCPU, index}
}

- // GPU returns the Device object for GPU target on given index
- func GPU(index int32) Device {
- return Device{KDLGPU, index}
+ // CUDA returns the Device object for CUDA target on given index
+ func CUDA(index int32) Device {
+ return Device{kDLCUDA, index}
}

- // CPUPinned returns the Device object for CPUPinned target on given index
- func CPUPinned(index int32) Device {
- return Device{KDLCPUPinned, index}
+ // CUDAHost returns the Device object for CUDAHost target on given index
+ func CUDAHost(index int32) Device {
+ return Device{kDLCUDAHost, index}
}

// OpenCL returns the Device object for OpenCL target on given index
8 changes: 4 additions & 4 deletions include/tvm/runtime/device_api.h
@@ -231,10 +231,10 @@ inline const char* DeviceName(int type) {
switch (type) {
case kDLCPU:
return "cpu";
- case kDLGPU:
- return "gpu";
- case kDLCPUPinned:
- return "cpu_pinned";
+ case kDLCUDA:
+ return "cuda";
+ case kDLCUDAHost:
+ return "cuda_host";
case kDLOpenCL:
return "opencl";
case kDLSDAccel:
2 changes: 1 addition & 1 deletion python/tvm/__init__.py
@@ -30,7 +30,7 @@
# top-level alias
# tvm.runtime
from .runtime.object import Object
- from .runtime.ndarray import device, cpu, gpu, opencl, cl, vulkan, metal, mtl
+ from .runtime.ndarray import device, cpu, cuda, gpu, opencl, cl, vulkan, metal, mtl
from .runtime.ndarray import vpi, rocm, ext_dev, micro_dev, hexagon
from .runtime import ndarray as nd

3 changes: 1 addition & 2 deletions python/tvm/_ffi/runtime_ctypes.py
@@ -164,7 +164,7 @@ class Device(ctypes.Structure):
_fields_ = [("device_type", ctypes.c_int), ("device_id", ctypes.c_int)]
MASK2STR = {
1: "cpu",
2: "gpu",
2: "cuda",
4: "opencl",
5: "aocl",
6: "sdaccel",
@@ -182,7 +182,6 @@ class Device(ctypes.Structure):
"stackvm": 1,
"cpu": 1,
"c": 1,
"gpu": 2,
"cuda": 2,
"nvptx": 2,
"cl": 4,
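With the bare "gpu" alias dropped from STR2MASK, string lookups should spell the device "cuda" (or a mapped alias such as "nvptx"). A small sanity-check sketch; the printed form is assumed from MASK2STR:

import tvm

dev = tvm.device("cuda", 0)   # resolved through STR2MASK
assert dev.device_type == 2
print(dev)                    # now renders as "cuda(0)", not "gpu(0)"
# tvm.device("gpu", 0) should now fail, since the alias was removed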
8 changes: 4 additions & 4 deletions python/tvm/contrib/nvcc.py
@@ -249,8 +249,8 @@ def get_target_compute_version(target=None):
return major + "." + minor

# 3. GPU
- if tvm.gpu(0).exist:
- return tvm.gpu(0).compute_version
+ if tvm.cuda(0).exist:
+ return tvm.cuda(0).compute_version

warnings.warn(
"No CUDA architecture was specified or GPU detected."
@@ -331,8 +331,8 @@ def have_tensorcore(compute_version=None, target=None):
isn't specified.
"""
if compute_version is None:
- if tvm.gpu(0).exist:
- compute_version = tvm.gpu(0).compute_version
+ if tvm.cuda(0).exist:
+ compute_version = tvm.cuda(0).compute_version
else:
if target is None or "arch" not in target.attrs:
warnings.warn(
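The probing pattern above is the same one tvm.testing uses below; a hedged usage sketch:

import tvm
from tvm.contrib import nvcc

if tvm.cuda(0).exist:
    cv = tvm.cuda(0).compute_version  # e.g. "7.5"
    print("tensor cores available:", nvcc.have_tensorcore(cv))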
2 changes: 1 addition & 1 deletion python/tvm/runtime/__init__.py
@@ -26,7 +26,7 @@

# function exposures
from .object_generic import convert_to_object, convert, const
- from .ndarray import device, cpu, gpu, opencl, cl, vulkan, metal, mtl
+ from .ndarray import device, cpu, cuda, gpu, opencl, cl, vulkan, metal, mtl
from .ndarray import vpi, rocm, ext_dev, micro_dev
from .module import load_module, enabled, system_lib
from .container import String
27 changes: 24 additions & 3 deletions python/tvm/runtime/ndarray.py
@@ -17,6 +17,7 @@
# pylint: disable=invalid-name, unused-import, redefined-outer-name
"""Runtime NDArray API"""
import ctypes
+ import warnings
import numpy as np
import tvm._ffi

@@ -254,8 +255,7 @@ def device(dev_type, dev_id=0):
.. code-block:: python
assert tvm.device("cpu", 1) == tvm.cpu(1)
assert tvm.device("gpu", 0) == tvm.gpu(0)
assert tvm.device("cuda", 0) == tvm.gpu(0)
assert tvm.device("cuda", 0) == tvm.cuda(0)
"""
if isinstance(dev_type, string_types):
if "-device=micro_dev" in dev_type:
@@ -362,9 +362,27 @@ def cpu(dev_id=0):
return Device(1, dev_id)


+ def cuda(dev_id=0):
+     """Construct a CUDA GPU device
+
+     Parameters
+     ----------
+     dev_id : int, optional
+         The integer device id
+
+     Returns
+     -------
+     dev : Device
+         The created device
+     """
+     return Device(2, dev_id)


def gpu(dev_id=0):
"""Construct a GPU device
"""Construct a CUDA GPU device
deprecated:: 0.9.0
Use :py:func:`tvm.cuda` instead.
Parameters
----------
dev_id : int, optional
@@ -375,6 +393,9 @@ def gpu(dev_id=0):
dev : Device
The created device
"""
+ warnings.warn(
+     "Please use tvm.cuda() instead of tvm.gpu(). tvm.gpu() is going to be deprecated in 0.9.0",
+ )
return Device(2, dev_id)


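Both constructors return the same underlying device, so existing callers keep working while they migrate; a short sketch grounded in the docstring above:

import tvm

dev = tvm.cuda(0)    # preferred spelling from this commit on
legacy = tvm.gpu(0)  # same Device, but now emits the warning above
assert dev == legacy and dev.device_type == 2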
10 changes: 5 additions & 5 deletions python/tvm/testing.py
@@ -464,9 +464,9 @@ def _compose(args, decs):


def uses_gpu(*args):
"""Mark to differentiate tests that use the GPU is some capacity.
"""Mark to differentiate tests that use the GPU in some capacity.
These tests will be run on CPU-only test nodes and on test nodes with GPUS.
These tests will be run on CPU-only test nodes and on test nodes with GPUs.
To mark a test that must have a GPU present to run, use
:py:func:`tvm.testing.requires_gpu`.
@@ -490,7 +490,7 @@ def requires_gpu(*args):
Function to mark
"""
_requires_gpu = [
- pytest.mark.skipif(not tvm.gpu().exist, reason="No GPU present"),
+ pytest.mark.skipif(not tvm.cuda().exist, reason="No GPU present"),
*uses_gpu(),
]
return _compose(args, _requires_gpu)
@@ -499,7 +499,7 @@ def requires_gpu(*args):
def requires_cuda(*args):
"""Mark a test as requiring the CUDA runtime.
- This also marks the test as requiring a gpu.
+ This also marks the test as requiring a cuda gpu.
Parameters
----------
@@ -618,7 +618,7 @@ def requires_tensorcore(*args):
_requires_tensorcore = [
pytest.mark.tensorcore,
pytest.mark.skipif(
- not tvm.gpu().exist or not nvcc.have_tensorcore(tvm.gpu(0).compute_version),
+ not tvm.cuda().exist or not nvcc.have_tensorcore(tvm.cuda(0).compute_version),
reason="No tensorcore present",
),
*requires_gpu(),
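As a hedged usage sketch, the updated decorators are applied per test function (test names here are hypothetical):

import tvm.testing

@tvm.testing.requires_gpu   # skipped unless tvm.cuda().exist
def test_add_on_gpu():
    ...

@tvm.testing.requires_cuda  # additionally requires the CUDA runtime/target
def test_add_with_cuda():
    ...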
2 changes: 1 addition & 1 deletion python/tvm/topi/cuda/conv2d_alter_op.py
@@ -225,7 +225,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):

if topi_tmpl == "conv2d_HWNCnc_tensorcore.cuda":
assert data_layout == "HWNC" and kernel_layout == "HWOI"
- assert float(tvm.gpu(0).compute_version) >= 7.5
+ assert float(tvm.cuda(0).compute_version) >= 7.5
H, W, N, CI = get_const_tuple(data.shape)
KH, KW, CO, _ = get_const_tuple(kernel.shape)

2 changes: 1 addition & 1 deletion python/tvm/topi/cuda/nms.py
@@ -878,7 +878,7 @@ def non_max_suppression(
np_valid_count = np.array([4])
s = topi.generic.schedule_nms(out)
f = tvm.build(s, [data, valid_count, out], "cuda")
- dev = tvm.gpu(0)
+ dev = tvm.cuda(0)
tvm_data = tvm.nd.array(np_data, dev)
tvm_valid_count = tvm.nd.array(np_valid_count, dev)
tvm_out = tvm.nd.array(np.zeros(dshape, dtype=data.dtype), dev)
8 changes: 4 additions & 4 deletions rust/tvm-sys/src/device.rs
@@ -66,7 +66,7 @@ use thiserror::Error;
pub enum DeviceType {
CPU = 1,
GPU,
- CPUPinned,
+ CUDAHost,
OpenCL,
Vulkan,
Metal,
@@ -101,8 +101,8 @@ impl Display for DeviceType {
"{}",
match self {
DeviceType::CPU => "cpu",
DeviceType::GPU => "gpu",
DeviceType::CPUPinned => "cpu_pinned",
DeviceType::GPU => "cuda",
DeviceType::CUDAHost => "cuda_host",
DeviceType::OpenCL => "opencl",
DeviceType::Vulkan => "vulkan",
DeviceType::Metal => "metal",
@@ -210,7 +210,7 @@ macro_rules! impl_tvm_device {

impl_tvm_device!(
DLDeviceType_kDLCPU: [cpu, llvm, stackvm],
- DLDeviceType_kDLGPU: [gpu, cuda, nvptx],
+ DLDeviceType_kDLCUDA: [gpu, cuda, nvptx],
DLDeviceType_kDLOpenCL: [cl],
DLDeviceType_kDLMetal: [metal],
DLDeviceType_kDLVPI: [vpi],
2 changes: 1 addition & 1 deletion rust/tvm-sys/src/value.rs
@@ -86,7 +86,7 @@ macro_rules! impl_tvm_device {

impl_tvm_device!(
DLDeviceType_kDLCPU: [cpu, llvm, stackvm],
- DLDeviceType_kDLGPU: [gpu, cuda, nvptx],
+ DLDeviceType_kDLCUDA: [gpu, cuda, nvptx],
DLDeviceType_kDLOpenCL: [cl],
DLDeviceType_kDLMetal: [metal],
DLDeviceType_kDLVPI: [vpi],
4 changes: 2 additions & 2 deletions src/auto_scheduler/search_policy/utils.h
@@ -53,7 +53,7 @@ inline bool IsCPUTask(const SearchTask& task) {

/*! \brief Return whether the search task is targeting a GPU. */
inline bool IsGPUTask(const SearchTask& task) {
return (task)->target->kind->device_type == kDLGPU ||
return (task)->target->kind->device_type == kDLCUDA ||
(task)->target->kind->device_type == kDLOpenCL ||
(task)->target->kind->device_type == kDLVulkan ||
(task)->target->kind->device_type == kDLMetal ||
@@ -63,7 +63,7 @@ inline bool IsGPUTask(const SearchTask& task) {

/*! \brief Return whether the search task is targeting a CUDA GPU. */
inline bool IsCUDATask(const SearchTask& task) {
return (task)->target->kind->device_type == kDLGPU;
return (task)->target->kind->device_type == kDLCUDA;
}

/*! \brief Return whether the search task is targeting a OpenCL GPU. */
6 changes: 3 additions & 3 deletions src/auto_scheduler/search_task.cc
@@ -57,11 +57,11 @@ HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& target
const auto device_type = target->kind->device_type;
if (device_type == kDLCPU) {
return HardwareParams(tvm::runtime::threading::MaxConcurrency(), 64, 64, 0, 0, 0, 0, 0);
- } else if (device_type == kDLGPU || device_type == kDLROCM) {
+ } else if (device_type == kDLCUDA || device_type == kDLROCM) {
auto dev = Device{static_cast<DLDeviceType>(device_type), 0};
- auto device_name = device_type == kDLGPU ? "device_api.gpu" : "device_api.rocm";
+ auto device_name = device_type == kDLCUDA ? "device_api.cuda" : "device_api.rocm";
auto func = tvm::runtime::Registry::Get(device_name);
ICHECK(func != nullptr) << "Cannot find GPU device_api in registry";
ICHECK(func != nullptr) << "Cannot find CUDA device_api in registry";
auto device_api = static_cast<tvm::runtime::DeviceAPI*>(((*func)()).operator void*());

tvm::runtime::TVMRetValue ret;