diff --git a/CMakeLists.txt b/CMakeLists.txt
index 132b0e1c..c20759cc 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -363,7 +363,6 @@ if(USE_CUDA)
endif()
list(APPEND SOURCE ${cuda_objs} ${CUDA})
add_definitions(-DMXNET_USE_CUDA=1)
- add_definitions(-DMXNET_USE_NVRTC=1)
if(CUDA_LIBRARY_PATH)
if(IS_CONTAINER_BUILD)
# In case of building on a production-like build container which may not have Cuda installed
diff --git a/Jenkinsfile b/Jenkinsfile
index d8cd2f5f..eb82e00f 100644
--- a/Jenkinsfile
+++ b/Jenkinsfile
@@ -370,7 +370,11 @@ try {
init_git()
unpack_lib('gpu')
timeout(time: max_time, unit: 'MINUTES') {
+ try {
sh "${docker_run} gpu ./perl-package/test.sh"
+ } catch (exc) {
+ error "Perl GPU test failed."
+ }
}
}
}
diff --git a/Makefile b/Makefile
index 54df33f1..56267279 100644
--- a/Makefile
+++ b/Makefile
@@ -272,7 +272,7 @@ ALL_DEP = $(OBJ) $(EXTRA_OBJ) $(PLUGIN_OBJ) $(LIB_DEP)
ifeq ($(USE_CUDA), 1)
CFLAGS += -I$(ROOTDIR)/cub
ALL_DEP += $(CUOBJ) $(EXTRA_CUOBJ) $(PLUGIN_CUOBJ)
- LDFLAGS += -lcuda -lcufft
+ LDFLAGS += -lcuda -lcufft -lnvrtc
SCALA_PKG_PROFILE := $(SCALA_PKG_PROFILE)-gpu
else
SCALA_PKG_PROFILE := $(SCALA_PKG_PROFILE)-cpu
@@ -281,13 +281,6 @@ endif
# For quick compile test, used smaller subset
ALLX_DEP= $(ALL_DEP)
-ifeq ($(USE_NVRTC), 1)
- LDFLAGS += -lnvrtc
- CFLAGS += -DMXNET_USE_NVRTC=1
-else
- CFLAGS += -DMXNET_USE_NVRTC=0
-endif
-
build/src/%.o: src/%.cc
@mkdir -p $(@D)
$(CXX) -std=c++11 -c $(CFLAGS) -MMD -c $< -o $@
diff --git a/docs/api/python/index.md b/docs/api/python/index.md
index 75aed075..e7f8d45a 100644
--- a/docs/api/python/index.md
+++ b/docs/api/python/index.md
@@ -134,3 +134,12 @@ imported by running:
metric/metric.md
```
+
+## Run-Time Compilation API
+
+```eval_rst
+.. toctree::
+ :maxdepth: 1
+
+ rtc/rtc.md
+```
diff --git a/docs/api/python/rtc/rtc.md b/docs/api/python/rtc/rtc.md
new file mode 100644
index 00000000..bb1c3140
--- /dev/null
+++ b/docs/api/python/rtc/rtc.md
@@ -0,0 +1,29 @@
+# Run-Time Compilation API
+
+```eval_rst
+.. currentmodule:: mxnet.rtc
+```
+
+## Overview
+
+The RTC package contains tools for compiling and running CUDA code from the
+Python frontend. The compiled kernels can be used stand-alone or combined with
+`autograd.Function` or `operator.CustomOpProp` to support differentiation.
+
+```eval_rst
+.. autosummary::
+ :nosignatures:
+
+ mxnet.rtc
+```
+
+## API Reference
+
+
+
+```eval_rst
+.. automodule:: mxnet.rtc
+ :members:
+```
+
+
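The overview in the new rtc.md mentions combining compiled kernels with `autograd.Function` for differentiation, but the patch only shows stand-alone launches. Below is a minimal sketch (not part of this patch) of that pattern using the `mx.rtc.CudaModule` API added here; the kernel, class name, and shapes are illustrative only.

```python
import mxnet as mx
from mxnet import autograd

source = r'''
extern "C" __global__ void scale(const float *x, float *y, float alpha) {
    int i = threadIdx.x + blockIdx.x * blockDim.x;
    y[i] = alpha * x[i];
}
'''
module = mx.rtc.CudaModule(source)
kernel = module.get_kernel("scale", "const float *x, float *y, float alpha")

class Scale(autograd.Function):
    """y = alpha * x, with both passes running the RTC kernel (illustrative)."""
    def __init__(self, alpha):
        super(Scale, self).__init__()
        self.alpha = alpha

    def forward(self, x):
        y = mx.nd.empty(x.shape, ctx=x.context)
        kernel.launch([x, y, self.alpha], x.context, (1, 1, 1), (x.size, 1, 1))
        return y

    def backward(self, dy):
        # d(alpha * x)/dx = alpha, so the backward pass is another scale.
        dx = mx.nd.empty(dy.shape, ctx=dy.context)
        kernel.launch([dy, dx, self.alpha], dy.context, (1, 1, 1), (dy.size, 1, 1))
        return dx

x = mx.nd.ones((10,), ctx=mx.gpu(0))
x.attach_grad()
with autograd.record():
    y = Scale(3.0)(x)
y.backward()
print(x.grad)  # expected: all elements equal to 3.0
```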
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 1a2b82a3..4f4afa32 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -84,6 +84,10 @@ typedef void *KVStoreHandle;
typedef void *RecordIOHandle;
/*! \brief handle to MXRtc*/
typedef void *RtcHandle;
+/*! \brief handle to rtc cuda module*/
+typedef void *CudaModuleHandle;
+/*! \brief handle to rtc cuda kernel*/
+typedef void *CudaKernelHandle;
typedef void (*ExecutorMonitorCallback)(const char*,
NDArrayHandle,
@@ -1922,6 +1926,59 @@ MXNET_DLL int MXCustomOpRegister(const char* op_type, CustomOpPropCreator creato
MXNET_DLL int MXCustomFunctionRecord(int num_inputs, NDArrayHandle *inputs,
int num_outputs, NDArrayHandle *outputs,
struct MXCallbackList *callbacks);
+/*!
+ * \brief create cuda rtc module
+ * \param source cuda source code
+ * \param num_options number of compiler flags
+ * \param options compiler flags
+ * \param num_exports number of exported function names
+ * \param exports exported function names
+ * \param out handle to created module
+ */
+MXNET_DLL int MXRtcCudaModuleCreate(const char* source, int num_options,
+ const char** options, int num_exports,
+ const char** exports, CudaModuleHandle *out);
+/*!
+ * \brief delete cuda rtc module
+ * \param handle handle to cuda module
+ */
+MXNET_DLL int MXRtcCudaModuleFree(CudaModuleHandle handle);
+/*!
+ * \brief get kernel from module
+ * \param handle handle to cuda module
+ * \param name name of kernel function
+ * \param num_args number of arguments
+ * \param is_ndarray whether argument is ndarray
+ * \param is_const whether argument is constant
+ * \param arg_types data type of arguments
+ * \param out created kernel
+ */
+MXNET_DLL int MXRtcCudaKernelCreate(CudaModuleHandle handle, const char* name,
+ int num_args, int* is_ndarray, int* is_const,
+ int* arg_types, CudaKernelHandle *out);
+/*!
+ * \brief delete kernel
+ * \param handle handle to previously created kernel
+ */
+MXNET_DLL int MXRtcCudaKernelFree(CudaKernelHandle handle);
+/*!
+ * \brief launch cuda kernel
+ * \param handle handle to kernel
+ * \param dev_id (GPU) device id
+ * \param args pointer to arguments
+ * \param grid_dim_x grid dimension x
+ * \param grid_dim_y grid dimension y
+ * \param grid_dim_z grid dimension z
+ * \param block_dim_x block dimension x
+ * \param block_dim_y block dimension y
+ * \param block_dim_z block dimension z
+ * \param shared_mem size of dynamically allocated shared memory
+ */
+MXNET_DLL int MXRtcCudaKernelCall(CudaKernelHandle handle, int dev_id, void** args,
+ mx_uint grid_dim_x, mx_uint grid_dim_y,
+ mx_uint grid_dim_z, mx_uint block_dim_x,
+ mx_uint block_dim_y, mx_uint block_dim_z,
+ mx_uint shared_mem);
#ifdef __cplusplus
}
diff --git a/include/mxnet/mxrtc.h b/include/mxnet/mxrtc.h
deleted file mode 100644
index 8d7facc5..00000000
--- a/include/mxnet/mxrtc.h
+++ /dev/null
@@ -1,107 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file mxrtc.h
- * \brief Wrapper for NVRTC
- * \author Junyuan Xie
- */
-#ifndef MXNET_MXRTC_H_
-#define MXNET_MXRTC_H_
-#include "./base.h"
-#if ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC))
-#include <nvrtc.h>
-#include <cuda.h>
-
-#include <string>
-#include <vector>
-#include <utility>
-#include <memory>
-#include <unordered_map>
-#include "./ndarray.h"
-
-namespace mxnet {
-
-/*!
- * \brief Runtime compile of cuda kernel code with NVRTC
- */
-class MXRtc {
- public:
- /*!
- * \brief Build a new kernel.
- *
- * If the same kernel has been compiled before it will be load from
- * cache instead of compile again.
- * \param name name of the kernel function.
- * \param input list of input ndarrays and their name.
- * \param output list of output ndarrays and their name.
- * \param kernel cuda code.
- */
- MXRtc(const std::string& name,
- std::vector<std::pair<std::string, NDArray> > const& input,
- std::vector<std::pair<std::string, NDArray> > const& output,
- const std::string& kernel);
- /*!
- * \brief launch a kernel with the engine.
- * \param input list of input ndarray.
- * \param output list of output ndarray.
- * \param grid_dim_X kernel grid dimensions.
- * \param grid_dim_Y kernel grid dimensions.
- * \param grid_dim_Z kernel grid dimensions.
- * \param block_dim_X kernel block dimensions.
- * \param block_dim_Y kernel block dimensions.
- * \param block_dim_Z kernel block dimensions.
- */
- void push(std::vector<NDArray> const& input,
- std::vector<NDArray> const& output,
- unsigned int grid_dim_X,
- unsigned int grid_dim_Y,
- unsigned int grid_dim_Z,
- unsigned int block_dim_X,
- unsigned int block_dim_Y,
- unsigned int block_dim_Z);
-
- private:
- static const char str_type[];
- static std::unordered_map<std::string, char*> kernel_registry;
-
- std::string name_;
- index_t num_input_, num_output_;
- std::string code_;
- char* ptx_;
- std::unordered_map<int, CUmodule> module_;
- std::unordered_map<int, CUfunction> func_;
-
- /*!
- * \brief add supporting code to kernel.
- */
- std::string decorate(const std::string& name,
- std::vector<std::pair<std::string, NDArray> > const& input,
- std::vector<std::pair<std::string, NDArray> > const& output,
- const std::string kernel);
- /*!
- * \brief compile the kernel with nvrtc.
- */
- char* compile(const std::string& name, const std::string& code);
-};
-
-} // namespace mxnet
-
-#endif // MXNET_USE_CUDA && MXNET_USE_NVRTC
-#endif // MXNET_MXRTC_H_
diff --git a/include/mxnet/rtc.h b/include/mxnet/rtc.h
new file mode 100644
index 00000000..747c0b5c
--- /dev/null
+++ b/include/mxnet/rtc.h
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef MXNET_RTC_H_
+#define MXNET_RTC_H_
+#include "./base.h"
+#if MXNET_USE_CUDA
+#include <nvrtc.h>
+#include <cuda.h>
+
+#include <memory>
+#include <string>
+#include <vector>
+#include <utility>
+#include <unordered_map>
+#include <unordered_set>
+#include "./ndarray.h"
+
+namespace mxnet {
+namespace rtc {
+
+/*! \brief Cuda runtime compile module. */
+class CudaModule {
+ private:
+ /*! \brief Structure for holding internal info. */
+ struct Chunk {
+ /*!
+ * \brief Constructs cuda module.
+ * \param source cuda source code.
+ * \param exports export symbols before mangling.
+ */
+ Chunk(const char* source,
+ const std::vector<std::string>& options,
+ const std::vector<std::string>& exports);
+ /*! \brief destructor */
+ ~Chunk();
+ /*!
+ * \brief Get handle to cuda kernel from loaded module
+ * \param mangled_name mangled kernel name
+ * \param ctx context to run kernel on
+ * \return loaded function handle
+ */
+ CUfunction GetFunction(const std::string& mangled_name, const Context& ctx);
+ /*! \brief nvrtc program handle. */
+ nvrtcProgram prog_;
+ /*! \brief compiled cuda PTX */
+ char* ptx_;
+ /*! \brief lazily loaded cuda module */
+ std::unordered_map<int, CUmodule> mod_;
+ /*! \brief exported names */
+ std::unordered_set<std::string> exports_;
+ };
+ /*! \brief pointer to Chunk */
+ std::shared_ptr<Chunk> ptr_;
+
+ public:
+ /*! \brief cuda kernel argument descriptor */
+ struct ArgType {
+ /*! \brief whether argument is NDArray */
+ bool is_ndarray;
+ /*! \brief whether argument is constant (input) */
+ bool is_const;
+ /*! \brief data type of argument */
+ mshadow::TypeFlag dtype;
+ };
+ /*! \brief Cuda kernel */
+ class Kernel {
+ public:
+ /*! \brief Launch the kernel */
+ void Launch(const Context& ctx, const std::vector<dmlc::any>& args,
+ uint32_t grid_dim_x, uint32_t grid_dim_y, uint32_t grid_dim_z,
+ uint32_t block_dim_x, uint32_t block_dim_y, uint32_t block_dim_z,
+ uint32_t shared_mem);
+ /*! \brief kernel interface signature */
+ const std::vector<ArgType>& signature() { return signature_; }
+
+ private:
+ friend class CudaModule;
+ /*!
+ * \brief constructor
+ * \param mod module of this kernel
+ * \param mangled_name mangled kernel name
+ * \param signature kernel argument signature
+ */
+ Kernel(const std::shared_ptr<Chunk>& mod,
+ const std::string& mangled_name,
+ const std::vector<ArgType>& signature);
+ /*! \brief mangled kernel name */
+ std::string mangled_name_;
+ /*! \brief kernel argument signature */
+ std::vector<ArgType> signature_;
+ /*! \brief module of this kernel */
+ std::shared_ptr<Chunk> mod_;
+ /*! \brief cached kernel function on each device */
+ std::unordered_map<int, CUfunction> func_;
+ };
+ /*!
+ * \brief CudaModule constructor
+ * \param source cuda source code.
+ * \param exports export symbols before mangling.
+ */
+ CudaModule(const char* source,
+ const std::vector<std::string>& options,
+ const std::vector<std::string>& exports)
+ : ptr_(std::make_shared<Chunk>(source, options, exports)) {}
+ /*!
+ * \brief Get cuda kernel from module by name
+ * \param name kernel name
+ * \param signature kernel signature
+ * \return shared pointer to cuda kernel
+ */
+ std::shared_ptr<Kernel> GetKernel(const std::string& name,
+ const std::vector<ArgType>& signature);
+};
+
+} // namespace rtc
+} // namespace mxnet
+
+#endif // MXNET_USE_CUDA
+#endif // MXNET_RTC_H_
diff --git a/make/config.mk b/make/config.mk
index d44898bc..c5de8989 100644
--- a/make/config.mk
+++ b/make/config.mk
@@ -57,9 +57,6 @@ USE_CUDA_PATH = NONE
# whether use CuDNN R3 library
USE_CUDNN = 0
-# whether use cuda runtime compiling for writing kernels in native language (i.e. Python)
-USE_NVRTC = 0
-
# whether use opencv during compilation
# you can disable it, however, you will not able to use
# imbin iterator
diff --git a/make/osx.mk b/make/osx.mk
index 650e284b..d9ce6f2d 100644
--- a/make/osx.mk
+++ b/make/osx.mk
@@ -51,9 +51,6 @@ USE_CUDA_PATH = NONE
# whether use CUDNN R3 library
USE_CUDNN = 0
-# whether use cuda runtime compiling for writing kernels in native language (i.e. Python)
-USE_NVRTC = 0
-
# whether use opencv during compilation
# you can disable it, however, you will not able to use
# imbin iterator
diff --git a/python/mxnet/__init__.py b/python/mxnet/__init__.py
index 72dc2b2f..cf0ba37a 100644
--- a/python/mxnet/__init__.py
+++ b/python/mxnet/__init__.py
@@ -54,7 +54,7 @@
from . import kvstore as kv
from . import kvstore_server
# Runtime compile module
-from .rtc import Rtc as rtc
+from . import rtc
# Attribute scope to add attributes to symbolic graphs
from .attribute import AttrScope
diff --git a/python/mxnet/base.py b/python/mxnet/base.py
index e422dade..fc07853b 100644
--- a/python/mxnet/base.py
+++ b/python/mxnet/base.py
@@ -125,6 +125,8 @@ def _load_lib():
KVStoreHandle = ctypes.c_void_p
RecordIOHandle = ctypes.c_void_p
RtcHandle = ctypes.c_void_p
+CudaModuleHandle = ctypes.c_void_p
+CudaKernelHandle = ctypes.c_void_p
#----------------------------
# helper function definition
#----------------------------
diff --git a/python/mxnet/rtc.py b/python/mxnet/rtc.py
index 9da38c6a..aff4588b 100644
--- a/python/mxnet/rtc.py
+++ b/python/mxnet/rtc.py
@@ -18,91 +18,212 @@
"""Interface to runtime cuda kernel compile module."""
from __future__ import absolute_import
+import re
import ctypes
-from .base import _LIB, NDArrayHandle, RtcHandle, mx_uint, c_array, check_call
+import numpy as np
+
+from .base import _LIB, mx_uint, c_array, check_call
+from .base import c_str, CudaModuleHandle, CudaKernelHandle, numeric_types, string_types
+from .ndarray import _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP, NDArray
+
+_DTYPE_CPP_TO_NP = {
+ 'float': np.float32,
+ 'double': np.float64,
+ '__half': np.float16,
+ 'uint8_t': np.uint8,
+ 'int': np.int32,
+ 'int32_t': np.int32,
+ 'int8_t': np.int8,
+ 'char': np.int8,
+ 'int64_t': np.int64,
+}
+
+class CudaModule(object):
+ r"""Compile and run CUDA code from Python.
+
+ In CUDA 7.5, you need to prepend your kernel definitions
+ with 'extern "C"' to avoid name mangling::
+
+ source = r'''
+ extern "C" __global__ void axpy(const float *x, float *y, float alpha) {
+ int i = threadIdx.x + blockIdx.x * blockDim.x;
+ y[i] += alpha * x[i];
+ }
+ '''
+ module = mx.rtc.CudaModule(source)
+ func = module.get_kernel("axpy", "const float *x, float *y, float alpha")
+ x = mx.nd.ones((10,), ctx=mx.gpu(0))
+ y = mx.nd.zeros((10,), ctx=mx.gpu(0))
+ func.launch([x, y, 3.0], mx.gpu(0), (1, 1, 1), (10, 1, 1))
+ print(y)
+
+ Starting from CUDA 8.0, you can instead export functions by name.
+ This also allows you to use templates::
+
+ source = r'''
+ template<typename DType>
+ __global__ void axpy(const DType *x, DType *y, DType alpha) {
+ int i = threadIdx.x + blockIdx.x * blockDim.x;
+ y[i] += alpha * x[i];
+ }
+ '''
+ module = mx.rtc.CudaModule(source, exports=['axpy<float>', 'axpy<double>'])
+ func32 = module.get_kernel("axpy<float>", "const float *x, float *y, float alpha")
+ x = mx.nd.ones((10,), dtype='float32', ctx=mx.gpu(0))
+ y = mx.nd.zeros((10,), dtype='float32', ctx=mx.gpu(0))
+ func32.launch([x, y, 3.0], mx.gpu(0), (1, 1, 1), (10, 1, 1))
+ print(y)
+
+ func64 = module.get_kernel("axpy<double>", "const double *x, double *y, double alpha")
+ x = mx.nd.ones((10,), dtype='float64', ctx=mx.gpu(0))
+ y = mx.nd.zeros((10,), dtype='float64', ctx=mx.gpu(0))
+ func64.launch([x, y, 3.0], mx.gpu(0), (1, 1, 1), (10, 1, 1))
+ print(y)
-class Rtc(object):
- """MXRtc object in mxnet.
- This class allow you to write CUDA kernels in Python
- and call them with NDArray.
Parameters
----------
- name : str
- Name of the kernel.
- inputs : tuple of (str, mxnet.ndarray)
- List of input names and ndarray.
- outputs : tuple of (str, mxnet.ndarray)
- List of output names and ndarray.
- kernel : str
- The actual kernel code.
- Note that this is only the body of the kernel, i.e.
- after { and before }. Rtc will decorate the kernel.
- For example, if ``name = "mykernel"`` and
- inputs = [('x', mx.nd.zeros((10,)))]
- outputs = [('y', mx.nd.zeros((10,)))]
- kernel = "y[threadIdx.x] = x[threadIdx.x];",
- then the compiled kernel will be:
- extern "C" __global__ mykernel(float *x, float *y) {
- const int x_ndim = 1;
- const int x_dims = { 10 };
- const int y_ndim = 1;
- const int y_dims = { 10 };
-
- y[threadIdx.x] = x[threadIdx.x];
- }
+ source : str
+ Complete source code.
+ options : tuple of str
+ Compiler flags. For example, use "-I/usr/local/cuda/include" to
+ add the CUDA headers to the include path.
+ exports : tuple of str
+ Export kernel names.
"""
- def __init__(self, name, inputs, outputs, kernel):
- self.handle = RtcHandle()
- input_names = ctypes.cast(c_array(ctypes.c_char_p, [i[0] for i in inputs]),
- ctypes.POINTER(ctypes.c_char_p))
- output_names = ctypes.cast(c_array(ctypes.c_char_p, [i[0] for i in outputs]),
- ctypes.POINTER(ctypes.c_char_p))
- input_nds = ctypes.cast(c_array(NDArrayHandle, [i[1].handle for i in inputs]),
- ctypes.POINTER(NDArrayHandle))
- output_nds = ctypes.cast(c_array(NDArrayHandle, [i[1].handle for i in outputs]),
- ctypes.POINTER(NDArrayHandle))
- check_call(_LIB.MXRtcCreate(ctypes.c_char_p(name),
- mx_uint(len(inputs)),
- mx_uint(len(outputs)),
- input_names,
- output_names,
- input_nds,
- output_nds,
- ctypes.c_char_p(kernel),
- ctypes.byref(self.handle)))
+ def __init__(self, source, options=(), exports=()):
+ if isinstance(options, string_types):
+ options = (options,)
+ if isinstance(exports, string_types):
+ exports = (exports,)
+ self.handle = CudaModuleHandle()
+ check_call(_LIB.MXRtcCudaModuleCreate(
+ c_str(source),
+ len(options),
+ c_array(ctypes.c_char_p, [c_str(opt) for opt in options]),
+ len(exports),
+ c_array(ctypes.c_char_p, [c_str(name) for name in exports]),
+ ctypes.byref(self.handle)))
def __del__(self):
- check_call(_LIB.MXRtcFree(self.handle))
+ check_call(_LIB.MXRtcCudaModuleFree(self.handle))
- def push(self, inputs, outputs, grid_dims, block_dims):
- """Run the kernel.
+ def get_kernel(self, name, signature):
+ r"""Get CUDA kernel from compiled module.
Parameters
----------
- inputs : list of NDArray
- List of inputs. Can contain different NDArrays than those used for the constructor,
- but its elements must have the same shapes and appear in the same order.
- outputs : list of NDArray
- List of outputs. Can contain different ndarrays than used for the constructor,
- but must have the same shapes and appear in the same order.
- grid_dims : tuple of 3 uint
- Grid dimension for kernel launch.
- block_dims : tuple of 3 uint
- Block dimension for kernel launch.
+ name : str
+ String name of the kernel.
+ signature : str
+ Function signature for the kernel. For example, if a kernel is
+ declared as::
+
+ extern "C" __global__ void axpy(const float *x, double *y, int alpha)
+
+ Then its signature should be::
+
+ const float *x, double *y, int alpha
+
+ or::
+
+ const float *, double *, int
+
+ Note that `*` in signature marks an argument as array and
+ `const` marks an argument as constant (input) array.
+
+ Returns
+ -------
+ CudaKernel
+ CUDA kernel that can be launched on GPUs.
+ """
+ hdl = CudaKernelHandle()
+ is_ndarray = []
+ is_const = []
+ dtypes = []
+ pattern = re.compile(r"""^\s*(const)?\s*([\w_]+)\s*(\*)?\s*([\w_]+)?\s*$""")
+ args = re.sub(r"\s+", " ", signature).split(",")
+ for arg in args:
+ match = pattern.match(arg)
+ if not match or match.groups()[1] == 'const':
+ raise ValueError(
+ 'Invalid function prototype "%s". Must be in the '
+ 'form of "(const) type (*) (name)"'%arg)
+ is_const.append(bool(match.groups()[0]))
+ dtype = match.groups()[1]
+ is_ndarray.append(bool(match.groups()[2]))
+ if dtype not in _DTYPE_CPP_TO_NP:
+ raise TypeError(
+ "Unsupported kernel argument type %s. Supported types are: %s."%(
+ arg, ','.join(_DTYPE_CPP_TO_NP.keys())))
+ dtypes.append(_DTYPE_NP_TO_MX[_DTYPE_CPP_TO_NP[dtype]])
+
+ check_call(_LIB.MXRtcCudaKernelCreate(
+ self.handle,
+ c_str(name),
+ len(dtypes),
+ c_array(ctypes.c_int, [ctypes.c_int(i) for i in is_ndarray]),
+ c_array(ctypes.c_int, [ctypes.c_int(i) for i in is_const]),
+ c_array(ctypes.c_int, [ctypes.c_int(i) for i in dtypes]),
+ ctypes.byref(hdl)))
+
+ return CudaKernel(hdl, name, is_ndarray, dtypes)
+
+class CudaKernel(object):
+ """Constructs CUDA kernel. Should be created by `CudaModule.get_kernel`,
+ not intended to be used by users."""
+ def __init__(self, handle, name, is_ndarray, dtypes):
+ self.handle = handle
+ self._name = name
+ self._is_ndarray = is_ndarray
+ self._dtypes = [_DTYPE_MX_TO_NP[i] for i in dtypes]
+
+ def __del__(self):
+ check_call(_LIB.MXRtcCudaKernelFree(self.handle))
+
+ def launch(self, args, ctx, grid_dims, block_dims, shared_mem=0):
+ """Launch cuda kernel.
+
+ Parameters
+ ----------
+ args : tuple of NDArray or numbers
+ List of arguments for kernel. NDArrays are expected for pointer
+ types (e.g. `float*`, `double*`) while numbers are expected for
+ non-pointer types (e.g. `int`, `float`).
+ ctx : Context
+ The context to launch kernel on. Must be GPU context.
+ grid_dims : tuple of 3 integers
+ Grid dimensions for CUDA kernel.
+ block_dims : tuple of 3 integers
+ Block dimensions for CUDA kernel.
+ shared_mem : integer, optional
+ Size of dynamically allocated shared memory. Defaults to 0.
"""
- input_nds = ctypes.cast(c_array(NDArrayHandle, [i.handle for i in inputs]),
- ctypes.POINTER(NDArrayHandle))
- output_nds = ctypes.cast(c_array(NDArrayHandle, [i.handle for i in outputs]),
- ctypes.POINTER(NDArrayHandle))
- check_call(_LIB.MXRtcPush(self.handle,
- mx_uint(len(inputs)),
- mx_uint(len(outputs)),
- input_nds,
- output_nds,
- mx_uint(grid_dims[0]),
- mx_uint(grid_dims[1]),
- mx_uint(grid_dims[2]),
- mx_uint(block_dims[0]),
- mx_uint(block_dims[1]),
- mx_uint(block_dims[2])))
+ assert ctx.device_type == 'gpu', "Cuda kernel can only be launched on GPU"
+ assert len(grid_dims) == 3, "grid_dims must be a tuple of 3 integers"
+ assert len(block_dims) == 3, "block_dims must be a tuple of 3 integers"
+ assert len(args) == len(self._dtypes), \
+ "CudaKernel(%s) expects %d arguments but got %d"%(
+ self._name, len(self._dtypes), len(args))
+ void_args = []
+ ref_holder = []
+ for i, (arg, is_nd, dtype) in enumerate(zip(args, self._is_ndarray, self._dtypes)):
+ if is_nd:
+ assert isinstance(arg, NDArray), \
+ "The %d-th argument is expected to be a NDArray but got %s"%(
+ i, type(arg))
+ void_args.append(arg.handle)
+ else:
+ assert isinstance(arg, numeric_types), \
+ "The %d-th argument is expected to be a number, but got %s"%(
+ i, type(arg))
+ ref_holder.append(np.array(arg, dtype=dtype))
+ void_args.append(ref_holder[-1].ctypes.data_as(ctypes.c_void_p))
+
+ check_call(_LIB.MXRtcCudaKernelCall(
+ self.handle,
+ ctx.device_id,
+ c_array(ctypes.c_void_p, void_args),
+ mx_uint(grid_dims[0]), mx_uint(grid_dims[1]), mx_uint(grid_dims[2]),
+ mx_uint(block_dims[0]), mx_uint(block_dims[1]), mx_uint(block_dims[2]),
+ mx_uint(shared_mem)))
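For reference, here is a standalone sketch (a hypothetical helper, not part of this patch) that mirrors the signature parsing `get_kernel` performs above; it can be useful for checking a signature string before it reaches the C API.

```python
import re

def parse_signature(signature):
    """Parse a kernel signature string the same way CudaModule.get_kernel does."""
    pattern = re.compile(r"""^\s*(const)?\s*([\w_]+)\s*(\*)?\s*([\w_]+)?\s*$""")
    parsed = []
    for arg in re.sub(r"\s+", " ", signature).split(","):
        match = pattern.match(arg)
        if not match or match.groups()[1] == 'const':
            raise ValueError('Invalid function prototype "%s"' % arg)
        is_const, dtype, is_ndarray, name = match.groups()
        parsed.append((bool(is_const), dtype, bool(is_ndarray), name))
    return parsed

print(parse_signature("const float *x, double *y, int alpha"))
# [(True, 'float', True, 'x'), (False, 'double', True, 'y'), (False, 'int', False, 'alpha')]
```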
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 29df7165..8ab7f1f2 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -33,7 +33,7 @@
#include
#include
#include
-#include <mxnet/mxrtc.h>
+#include <mxnet/rtc.h>
#include
#include
#include
@@ -1102,21 +1102,7 @@ int MXRtcCreate(char* name, mx_uint num_input, mx_uint num_output,
NDArrayHandle* inputs, NDArrayHandle* outputs,
char* kernel, RtcHandle *out) {
API_BEGIN();
-#if ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC))
- std::vector<std::pair<std::string, NDArray> > input, output;
- for (mx_uint i = 0; i < num_input; ++i) {
- input.push_back(std::pair<std::string, NDArray>(input_names[i],
- *reinterpret_cast<NDArray*>(inputs[i])));
- }
- for (mx_uint i = 0; i < num_output; ++i) {
- output.push_back(std::pair<std::string, NDArray>(output_names[i],
- *reinterpret_cast<NDArray*>(outputs[i])));
- }
- MXRtc *rtc = new MXRtc(name, input, output, kernel);
- *out = reinterpret_cast<RtcHandle>(rtc);
-#else
- LOG(FATAL) << "Need to compile with USE_CUDA=1 and USE_NVRTC=1 for MXRtc.";
-#endif // ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC))
+ LOG(FATAL) << "Old rtc API is deprecated. Please use CudaModule";
API_END();
}
@@ -1129,34 +1115,13 @@ int MXRtcPush(RtcHandle handle, mx_uint num_input, mx_uint num_output,
mx_uint blockDimY,
mx_uint blockDimZ) {
API_BEGIN();
-#if ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC))
- std::vector<NDArray> input, output;
- for (mx_uint i = 0; i < num_input; ++i) {
- input.push_back(*reinterpret_cast<NDArray*>(inputs[i]));
- }
- for (mx_uint i = 0; i < num_output; ++i) {
- output.push_back(*reinterpret_cast<NDArray*>(outputs[i]));
- }
- reinterpret_cast<MXRtc*>(handle)->push(input, output,
- gridDimX,
- gridDimY,
- gridDimZ,
- blockDimX,
- blockDimY,
- blockDimZ);
-#else
- LOG(FATAL) << "Need to compile with USE_CUDA=1 and USE_NVRTC=1 for MXRtc.";
-#endif // ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC))
+ LOG(FATAL) << "Old rtc API is deprecated. Please use CudaModule";
API_END();
}
int MXRtcFree(RtcHandle handle) {
API_BEGIN();
-#if ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC))
- delete reinterpret_cast<MXRtc*>(handle);
-#else
- LOG(FATAL) << "Need to compile with USE_CUDA=1 and USE_NVRTC=1 for MXRtc.";
-#endif // ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC))
+ LOG(FATAL) << "Old rtc API is deprecated. Please use CudaModule";
API_END();
}
@@ -1165,3 +1130,87 @@ int MXCustomOpRegister(const char* op_type, CustomOpPropCreator creator) {
mxnet::op::custom::Registry::Get()->Register(op_type, creator);
API_END();
}
+
+
+int MXRtcCudaModuleCreate(const char* source, int num_options,
+ const char** options, int num_exports,
+ const char** exports, CudaModuleHandle *out) {
+ API_BEGIN();
+#if MXNET_USE_CUDA
+ std::vector<std::string> str_opts;
+ for (int i = 0; i < num_options; ++i) str_opts.emplace_back(options[i]);
+ std::vector<std::string> str_exports;
+ for (int i = 0; i < num_exports; ++i) str_exports.emplace_back(exports[i]);
+ *out = new rtc::CudaModule(source, str_opts, str_exports);
+#else
+ LOG(FATAL) << "Compile with USE_CUDA=1 to use GPU.";
+#endif
+ API_END();
+}
+
+int MXRtcCudaModuleFree(CudaModuleHandle handle) {
+ API_BEGIN();
+#if MXNET_USE_CUDA
+ delete reinterpret_cast<rtc::CudaModule*>(handle);
+#else
+ LOG(FATAL) << "Compile with USE_CUDA=1 to use GPU.";
+#endif
+ API_END();
+}
+
+int MXRtcCudaKernelCreate(CudaModuleHandle handle, const char* name, int num_args,
+ int* is_ndarray, int* is_const, int* arg_types,
+ CudaKernelHandle *out) {
+ API_BEGIN();
+#if MXNET_USE_CUDA
+ auto module = reinterpret_cast<rtc::CudaModule*>(handle);
+ std::vector<rtc::CudaModule::ArgType> signature;
+ for (int i = 0; i < num_args; ++i) {
+ signature.push_back(rtc::CudaModule::ArgType{
+ static_cast<bool>(is_ndarray[i]), static_cast<bool>(is_const[i]),
+ static_cast<mshadow::TypeFlag>(arg_types[i])});
+ }
+ auto kernel = module->GetKernel(name, signature);
+ *out = new std::shared_ptr<rtc::CudaModule::Kernel>(kernel);
+#else
+ LOG(FATAL) << "Compile with USE_CUDA=1 to use GPU.";
+#endif
+ API_END();
+}
+
+int MXRtcCudaKernelFree(CudaKernelHandle handle) {
+ API_BEGIN();
+#if MXNET_USE_CUDA
+ delete reinterpret_cast<std::shared_ptr<rtc::CudaModule::Kernel>*>(handle);
+#else
+ LOG(FATAL) << "Compile with USE_CUDA=1 to use GPU.";
+#endif
+ API_END();
+}
+
+int MXRtcCudaKernelCall(CudaKernelHandle handle, int dev_id, void** args,
+ mx_uint grid_dim_x, mx_uint grid_dim_y,
+ mx_uint grid_dim_z, mx_uint block_dim_x,
+ mx_uint block_dim_y, mx_uint block_dim_z,
+ mx_uint shared_mem) {
+ API_BEGIN();
+#if MXNET_USE_CUDA
+ auto kernel = reinterpret_cast<std::shared_ptr<rtc::CudaModule::Kernel>*>(handle);
+ const auto& signature = (*kernel)->signature();
+ std::vector<dmlc::any> any_args;
+ for (size_t i = 0; i < signature.size(); ++i) {
+ if (signature[i].is_ndarray) {
+ any_args.emplace_back(*static_cast<NDArray*>(args[i]));
+ } else {
+ MSHADOW_TYPE_SWITCH(signature[i].dtype, DType, {
+ any_args.emplace_back(*static_cast<DType*>(args[i]));
+ });
+ }
+ }
+ (*kernel)->Launch(Context::GPU(dev_id), any_args, grid_dim_x, grid_dim_y,
+ grid_dim_z, block_dim_x, block_dim_y, block_dim_z, shared_mem);
+#else
+ LOG(FATAL) << "Compile with USE_CUDA=1 to use GPU.";
+#endif
+ API_END();
+}
diff --git a/src/common/cuda_utils.h b/src/common/cuda_utils.h
index 0f63895d..c135ff8a 100644
--- a/src/common/cuda_utils.h
+++ b/src/common/cuda_utils.h
@@ -229,6 +229,40 @@ inline DType __device__ CudaMin(DType a, DType b) {
<< "cuRAND: " << mxnet::common::cuda::CurandGetErrorString(e); \
}
+/*!
+ * \brief Protected NVRTC call.
+ * \param x Expression to call.
+ *
+ * It checks for NVRTC errors after invocation of the expression.
+ */
+#define NVRTC_CALL(x) \
+ { \
+ nvrtcResult result = x; \
+ CHECK_EQ(result, NVRTC_SUCCESS) \
+ << #x " failed with error " \
+ << nvrtcGetErrorString(result); \
+ }
+
+/*!
+ * \brief Protected CUDA driver call.
+ * \param func Expression to call.
+ *
+ * It checks for CUDA driver errors after invocation of the expression.
+ */
+#define CUDA_DRIVER_CALL(func) \
+ { \
+ CUresult e = (func); \
+ if (e != CUDA_SUCCESS) { \
+ char const * err_msg = nullptr; \
+ if (cuGetErrorString(e, &err_msg) == CUDA_ERROR_INVALID_VALUE) { \
+ LOG(FATAL) << "CUDA Driver: Unknown error " << e; \
+ } else { \
+ LOG(FATAL) << "CUDA Driver: " << err_msg; \
+ } \
+ } \
+ }
+
+
#if !defined(_MSC_VER)
#define CUDA_UNROLL _Pragma("unroll")
#define CUDA_NOUNROLL _Pragma("nounroll")
diff --git a/src/common/mxrtc.cc b/src/common/mxrtc.cc
deleted file mode 100644
index e72ac0ba..00000000
--- a/src/common/mxrtc.cc
+++ /dev/null
@@ -1,159 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file mxrtc.cc
- * \brief Wrapper for NVRTC
- * \author Junyuan Xie
- */
-#include <mxnet/mxrtc.h>
-#if ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC))
-namespace mxnet {
-const char MXRtc::str_type[] = "float";
-std::unordered_map<std::string, char*> MXRtc::kernel_registry;
-
-MXRtc::MXRtc(const std::string& name,
- std::vector<std::pair<std::string, NDArray> > const& input,
- std::vector<std::pair<std::string, NDArray> > const& output,
- const std::string& kernel) {
- name_ = name;
- num_input_ = input.size();
- num_output_ = output.size();
- code_ = decorate(name, input, output, kernel);
- if (MXRtc::kernel_registry.find(code_) != MXRtc::kernel_registry.end()) {
- ptx_ = MXRtc::kernel_registry[code_];
- } else {
- ptx_ = compile(name, code_);
- }
-}
-
-void MXRtc::push(std::vector<NDArray> const& input,
- std::vector<NDArray> const& output,
- unsigned int grid_dim_X,
- unsigned int grid_dim_Y,
- unsigned int grid_dim_Z,
- unsigned int block_dim_X,
- unsigned int block_dim_Y,
- unsigned int block_dim_Z) {
- CHECK_EQ(num_input_, input.size());
- CHECK_EQ(num_output_, output.size());
- CHECK(output.size());
- cudaError_enum err;
- CUfunction func;
- int dev_id = output[0].ctx().dev_id;
- if (func_.find(dev_id) != func_.end()) {
- func = func_[dev_id];
- } else {
- CUmodule module;
- CHECK_EQ(err = cuModuleLoadDataEx(&module, ptx_, 0, 0, 0), CUDA_SUCCESS)
- << "CudaError: " << err;
- CHECK_EQ(err = cuModuleGetFunction(&func, module, name_.c_str()), CUDA_SUCCESS)
- << "CudaError: " << err;
- module_[dev_id] = module;
- func_[dev_id] = func;
- }
- auto op = [this, func, input, output,
- grid_dim_X, grid_dim_Y, grid_dim_Z,
- block_dim_X, block_dim_Y, block_dim_Z](RunContext rctx) {
- std::vector<float*> float_args;
- for (auto& i : input) float_args.push_back(static_cast<float*>(i.data().dptr_));
- for (auto& i : output) float_args.push_back(static_cast<float*>(i.data().dptr_));
- std::vector<void*> args;
- for (auto& i : float_args) args.push_back(&i);
- cudaError_enum err;
- cudaError_t cuerr;
- CHECK_EQ(err = cuLaunchKernel(func,
- grid_dim_X, grid_dim_Y, grid_dim_Z,
- block_dim_X, block_dim_Y, block_dim_Z,
- 0, rctx.get_stream<gpu>()->stream_,
- args.data(), 0), CUDA_SUCCESS) << "CudaError: " << err;
- CHECK_EQ(cuerr = cudaStreamSynchronize(rctx.get_stream<gpu>()->stream_),
- cudaSuccess) << "CudaError: " << cuerr;
- };
- std::vector<Engine::VarHandle> var_in, var_out;
- for (auto& i : input) var_in.push_back(i.var());
- for (auto& i : output) var_out.push_back(i.var());
- Engine::Get()->PushSync(op, output[0].ctx(), var_in, var_out,
- FnProperty::kNormal, 0, PROFILER_MESSAGE("MXRtc"));
-}
-
-std::string MXRtc::decorate(const std::string& name,
- std::vector<std::pair<std::string, NDArray> > const& input,
- std::vector<std::pair<std::string, NDArray> > const& output,
- const std::string kernel) {
- std::string source;
- source = source + "\nextern \"C\" __global__ void " + name + "(";
- for (auto &i : input) {
- source = source + "const " + str_type + "* " + i.first + ",";
- }
- for (auto &i : output) {
- source = source + str_type + "* " + i.first + ",";
- }
- source.pop_back();
- source = source + ") {\n";
- for (auto &i : input) {
- source = source + "const int " + i.first + "_ndim = " +
- std::to_string(i.second.shape().ndim()) + ";\n";
- source = source + "const int " + i.first + "_dims[] = {";
- for (index_t j = 0; j < i.second.shape().ndim(); ++j) {
- source = source + std::to_string(i.second.shape()[j]) + ",";
- }
- source.pop_back();
- source = source + "};\n";
- }
- for (auto &i : output) {
- source = source + "const int " + i.first + "_ndim = " +
- std::to_string(i.second.shape().ndim()) + ";\n";
- source = source + "const int " + i.first + "_dims[] = {";
- for (index_t j = 0; j < i.second.shape().ndim(); ++j) {
- source = source + std::to_string(i.second.shape()[j]) + ",";
- }
- source.pop_back();
- source = source + "};\n";
- }
- source = source + kernel + "\n}\n";
- return source;
-}
-
-char* MXRtc::compile(const std::string& name, const std::string& code) {
- nvrtcProgram prog;
- CHECK_EQ(nvrtcCreateProgram(&prog,
- code.c_str(),
- (name+".cu").c_str(),
- 0,
- NULL,
- NULL), NVRTC_SUCCESS);
- nvrtcResult compile_res = nvrtcCompileProgram(prog, 0, NULL);
- size_t log_size;
- CHECK_EQ(nvrtcGetProgramLogSize(prog, &log_size), NVRTC_SUCCESS);
- char *log = new char[log_size];
- CHECK_EQ(nvrtcGetProgramLog(prog, log), NVRTC_SUCCESS);
- CHECK_EQ(compile_res, NVRTC_SUCCESS) << log;
-
- size_t ptx_size;
- CHECK_EQ(nvrtcGetPTXSize(prog, &ptx_size), NVRTC_SUCCESS);
- char *ptx = new char[ptx_size];
- CHECK_EQ(nvrtcGetPTX(prog, ptx), NVRTC_SUCCESS);
- CHECK_EQ(nvrtcDestroyProgram(&prog), NVRTC_SUCCESS);
- return ptx;
-}
-
-} // namespace mxnet
-
-#endif // ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC))
diff --git a/src/common/rtc.cc b/src/common/rtc.cc
new file mode 100644
index 00000000..cd26f0e0
--- /dev/null
+++ b/src/common/rtc.cc
@@ -0,0 +1,188 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <mxnet/rtc.h>
+#include <mxnet/ndarray.h>
+
+#include "../common/cuda_utils.h"
+#include "../operator/operator_common.h"
+
+#if MXNET_USE_CUDA
+
+namespace mxnet {
+namespace rtc {
+
+CudaModule::Chunk::Chunk(
+ const char* source,
+ const std::vector& options,
+ const std::vector& exports) {
+ NVRTC_CALL(nvrtcCreateProgram(&prog_, source, "source.cu", 0, NULL, NULL));
+ for (const auto& i : exports) exports_.insert(i);
+#if CUDA_VERSION >= 8000
+ for (const auto& func : exports) {
+ NVRTC_CALL(nvrtcAddNameExpression(prog_, func.c_str()));
+ }
+#else
+ CHECK_EQ(exports.size(), 0)
+ << "Exporting is only supported with CUDA 8.0 and above. "
+ << "For lower version of CUDA, please prepend your kernel defintiions "
+ << "with extern \"C\" instead.";
+#endif
+ std::vector<const char*> c_options;
+ for (const auto& i : options) c_options.push_back(i.c_str());
+ nvrtcResult compile_res = nvrtcCompileProgram(prog_, c_options.size(), c_options.data());
+ if (compile_res != NVRTC_SUCCESS) {
+ size_t err_size;
+ NVRTC_CALL(nvrtcGetProgramLogSize(prog_, &err_size));
+ std::vector<char> err(err_size);
+ NVRTC_CALL(nvrtcGetProgramLog(prog_, err.data()));
+ LOG(FATAL) << err.data();
+ }
+
+ size_t ptx_size;
+ NVRTC_CALL(nvrtcGetPTXSize(prog_, &ptx_size));
+ ptx_ = new char[ptx_size];
+ NVRTC_CALL(nvrtcGetPTX(prog_, ptx_));
+}
+
+
+CudaModule::Chunk::~Chunk() {
+ for (const auto& kv : mod_) {
+ CUDA_DRIVER_CALL(cuModuleUnload(kv.second));
+ }
+ NVRTC_CALL(nvrtcDestroyProgram(&prog_));
+ delete[] ptx_;
+}
+
+
+CUfunction CudaModule::Chunk::GetFunction(
+ const std::string& mangled_name,
+ const Context& ctx) {
+ CHECK_EQ(ctx.dev_mask(), gpu::kDevMask)
+ << "CUDA Runtime compilation only supports Nvidia GPU.";
+ auto iter = mod_.find(ctx.dev_id);
+ CUmodule module;
+ if (iter != mod_.end()) {
+ module = iter->second;
+ } else {
+ CUDA_CALL(cudaSetDevice(ctx.dev_id));
+ CUDA_DRIVER_CALL(cuModuleLoadDataEx(&module, ptx_, 0, 0, 0));
+ mod_[ctx.dev_id] = module;
+ }
+ CUfunction function;
+ auto err = cuModuleGetFunction(&function, module, mangled_name.c_str());
+ if (err == CUDA_ERROR_NOT_FOUND) {
+ LOG(FATAL) << "Cannot find cuda kernel with name '" << mangled_name
+ << "'. Please either prepend kernel definition "
+ << "with 'extern \"C\"' or add its name to exports "
+ << "when creating CudaModule.";
+ }
+ CUDA_DRIVER_CALL(err);
+ return function;
+}
+
+
+std::shared_ptr<CudaModule::Kernel> CudaModule::GetKernel(
+ const std::string& name, const std::vector<ArgType>& signature) {
+ std::string mangled_name = name;
+#if CUDA_VERSION >= 8000
+ if (ptr_->exports_.count(name)) {
+ const char * c_mangled_name;
+ NVRTC_CALL(nvrtcGetLoweredName(ptr_->prog_, name.c_str(), &c_mangled_name));
+ mangled_name = c_mangled_name;
+ }
+#endif
+ return std::shared_ptr<Kernel>(new Kernel(ptr_, mangled_name, signature));
+}
+
+
+CudaModule::Kernel::Kernel(
+ const std::shared_ptr<Chunk>& mod,
+ const std::string& mangled_name,
+ const std::vector<ArgType>& signature)
+ : mangled_name_(mangled_name), signature_(signature), mod_(mod) {
+}
+
+void CudaModule::Kernel::Launch(
+ const Context& ctx, const std::vector<dmlc::any>& args,
+ uint32_t grid_dim_x, uint32_t grid_dim_y, uint32_t grid_dim_z,
+ uint32_t block_dim_x, uint32_t block_dim_y, uint32_t block_dim_z,
+ uint32_t shared_mem) {
+ CHECK_EQ(ctx.dev_mask(), gpu::kDevMask)
+ << "CUDA Runtime compilation only supports Nvidia GPU.";
+
+ auto mod = mod_;
+ auto arg_types = signature();
+
+ CUfunction function;
+ auto iter = func_.find(ctx.dev_id);
+ if (iter != func_.end()) {
+ function = iter->second;
+ } else {
+ function = mod_->GetFunction(mangled_name_, ctx);
+ func_[ctx.dev_id] = function;
+ }
+
+ std::vector<Engine::VarHandle> read_vars, write_vars;
+ for (size_t i = 0; i < arg_types.size(); ++i) {
+ if (!arg_types[i].is_ndarray) continue;
+ const auto& array = dmlc::get<NDArray>(args[i]);
+ CHECK_EQ(array.dtype(), arg_types[i].dtype)
+ << "The i-th argument is expected to be an NDArray of "
+ << op::type_string(arg_types[i].dtype) << " type, but got "
+ << op::type_string(array.dtype()) << " instead.";
+ if (arg_types[i].is_const) {
+ read_vars.emplace_back(array.var());
+ } else {
+ write_vars.emplace_back(array.var());
+ }
+ }
+
+ Engine::Get()->PushSync(
+ [function, mod, args, arg_types, grid_dim_x, grid_dim_y, grid_dim_z,
+ block_dim_x, block_dim_y, block_dim_z, shared_mem](RunContext rctx) {
+ std::vector<void*> p_args;
+ for (size_t i = 0; i < arg_types.size(); ++i) {
+ if (arg_types[i].is_ndarray) {
+ const auto& array = dmlc::get<NDArray>(args[i]);
+ p_args.push_back(reinterpret_cast<void*>(const_cast<void**>(&array.data().dptr_)));
+ } else {
+ MSHADOW_TYPE_SWITCH(arg_types[i].dtype, DType, {
+ const auto& number = dmlc::get<DType>(args[i]);
+ p_args.push_back(const_cast<DType*>(&number));
+ });
+ }
+ }
+
+ mshadow::Stream<gpu> *s = rctx.get_stream<gpu>();
+ CUDA_DRIVER_CALL(cuLaunchKernel(
+ function, grid_dim_x, grid_dim_y, grid_dim_z,
+ block_dim_x, block_dim_y, block_dim_z,
+ shared_mem, s->stream_,
+ p_args.data(), 0));
+ CUDA_CALL(cudaStreamSynchronize(s->stream_));
+ }, ctx, read_vars, write_vars, FnProperty::kNormal, 0,
+ PROFILER_MESSAGE(mangled_name_.c_str()));
+}
+
+
+} // namespace rtc
+} // namespace mxnet
+
+#endif // MXNET_USE_CUDA
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index 745974fb..ec658448 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -1390,6 +1390,35 @@ def test_gluon_ctc_consistency():
assert_almost_equal(cpu_data.grad.asnumpy(), gpu_data.grad.asnumpy(), atol=1e-3, rtol=1e-3)
+def test_cuda_rtc():
+ source = r'''
+ extern "C" __global__ void axpy(const float *x, float *y, float alpha) {
+ int i = threadIdx.x + blockIdx.x * blockDim.x;
+ y[i] += alpha * x[i];
+ }
+
+ extern "C" __global__ void saxpy(const float *x, float *y, float alpha) {
+ extern __shared__ float smem[];
+ int i = threadIdx.x + blockIdx.x * blockDim.x;
+ smem[threadIdx.x] = x[i];
+ y[i] += alpha * smem[threadIdx.x];
+ }
+ '''
+ module = mx.rtc.CudaModule(source)
+ axpy = module.get_kernel("axpy", "const float *x, float *y, float alpha")
+ x = mx.nd.ones((10,), ctx=mx.gpu(0))
+ y = mx.nd.zeros((10,), ctx=mx.gpu(0))
+ axpy.launch([x, y, 3.0], mx.gpu(0), (1, 1, 1), (10, 1, 1))
+ assert (y.asnumpy() == 3).all()
+
+ saxpy = module.get_kernel("saxpy", "const float *x, float *y, float alpha")
+ saxpy.launch([x, y, 4.0], mx.gpu(0), (1, 1, 1), (10, 1, 1), 10)
+ assert (y.asnumpy() == 7).all()
+
+ saxpy.launch([x, y, 5.0], mx.gpu(0), (2, 1, 1), (5, 1, 1), 5)
+ assert (y.asnumpy() == 12).all()
+
+
if __name__ == '__main__':
import nose
nose.runmodule()
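The test above launches with grid and block dimensions that exactly cover the array length. As a usage note (not part of this patch), kernels launched over sizes that are not a multiple of the block size should carry an explicit bound check, with the grid rounded up from the element count; a hedged sketch:

```python
import mxnet as mx

source = r'''
extern "C" __global__ void axpy_n(const float *x, float *y, float alpha, int n) {
    int i = threadIdx.x + blockIdx.x * blockDim.x;
    if (i < n) y[i] += alpha * x[i];
}
'''
module = mx.rtc.CudaModule(source)
kernel = module.get_kernel("axpy_n", "const float *x, float *y, float alpha, int n")

def launch_1d(kernel, args, ctx, n, block=256):
    """Launch a 1-D kernel over n elements, rounding the grid up to full blocks."""
    grid = (n + block - 1) // block
    kernel.launch(args, ctx, (grid, 1, 1), (block, 1, 1))

n = 1000  # deliberately not a multiple of the block size
x = mx.nd.ones((n,), ctx=mx.gpu(0))
y = mx.nd.zeros((n,), ctx=mx.gpu(0))
launch_1d(kernel, [x, y, 2.0, n], mx.gpu(0), n)
assert (y.asnumpy() == 2).all()
```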