This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

[FEATURE] Add backend MXGetMaxSupportedArch() and frontend get_rtc_compile_opts() for CUDA enhanced compatibility (#20443)

* Add backend MXGetMaxSupportedArch() and frontend get_rtc_compile_opts()

* Fix rtc options vector handling

* Fix get_cuda_compute_capability(ctx) on Windows
DickJC123 authored Jul 13, 2021
1 parent 5bd9756 commit 8fd17ce
Showing 8 changed files with 90 additions and 28 deletions.
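In short, the commit exposes the nvrtc compiler's maximum supported SM architecture through the backend C API and adds two Python helpers, get_max_supported_compute_capability() and get_rtc_compile_opts(ctx), so runtime-compiled kernels can still target GPUs newer than the installed toolkit supports natively. A minimal usage sketch, assuming a CUDA-enabled build of this revision and a visible GPU (the kernel source here is illustrative only):

import mxnet as mx
from mxnet.util import get_max_supported_compute_capability, get_rtc_compile_opts

ctx = mx.gpu(0)
print(get_max_supported_compute_capability())   # e.g. 86 with a CUDA 11.1+ toolkit
opts = get_rtc_compile_opts(ctx)                # e.g. ['--gpu-architecture=sm_70']

source = r'''
extern "C" __global__ void scale(float *y, float alpha) {
    y[threadIdx.x] *= alpha;
}
'''
module = mx.rtc.CudaModule(source, options=opts)
scale = module.get_kernel("scale", "float *y, float alpha")
y = mx.nd.ones((8,), ctx=ctx)
scale.launch([y, 2.0], ctx, (1, 1, 1), (8, 1, 1))
assert (y.asnumpy() == 2).all()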
7 changes: 7 additions & 0 deletions include/mxnet/c_api_test.h
@@ -91,6 +91,13 @@ MXNET_DLL int MXGetEnv(const char* name,
MXNET_DLL int MXSetEnv(const char* name,
const char* value);

/*!
* \brief Get the maximum SM architecture supported by the nvrtc compiler
* \param max_arch The maximum supported architecture (e.g. 80 for Ampere)
* \return 0 when success, -1 when failure happens.
*/
MXNET_DLL int MXGetMaxSupportedArch(uint32_t *max_arch);

#ifdef __cplusplus
}
#endif // __cplusplus
1 change: 1 addition & 0 deletions python/mxnet/test_utils.py
@@ -50,6 +50,7 @@
from .symbol import Symbol
from .symbol.numpy import _Symbol as np_symbol
from .util import use_np, use_np_default_dtype, getenv, setenv # pylint: disable=unused-import
from .util import get_max_supported_compute_capability, get_rtc_compile_opts # pylint: disable=unused-import
from .runtime import Features
from .numpy_extension import get_cuda_compute_capability

26 changes: 25 additions & 1 deletion python/mxnet/util.py
@@ -875,7 +875,7 @@ def get_cuda_compute_capability(ctx):
raise ValueError('Expecting a gpu context to get cuda compute capability, '
'while received ctx {}'.format(str(ctx)))

libnames = ('libcuda.so', 'libcuda.dylib', 'cuda.dll')
libnames = ('libcuda.so', 'libcuda.dylib', 'nvcuda.dll', 'cuda.dll')
for libname in libnames:
try:
cuda = ctypes.CDLL(libname)
@@ -1176,3 +1176,27 @@ def setenv(name, value):
"""
passed_value = None if value is None else c_str(value)
check_call(_LIB.MXSetEnv(c_str(name), passed_value))


def get_max_supported_compute_capability():
"""Get the maximum compute capability (SM arch) supported by the nvrtc compiler
"""
max_supported_cc = ctypes.c_int()
check_call(_LIB.MXGetMaxSupportedArch(ctypes.byref(max_supported_cc)))
return max_supported_cc.value


def get_rtc_compile_opts(ctx):
"""Get the compile ops suitable for the context, given the toolkit/driver config
"""
device_cc = get_cuda_compute_capability(ctx)
max_supported_cc = get_max_supported_compute_capability()

# CUDA toolkits starting with 11.1 (first to support arch 86) can compile directly to SASS
can_compile_to_SASS = max_supported_cc >= 86
should_compile_to_SASS = can_compile_to_SASS and \
device_cc <= max_supported_cc
device_cc_as_used = min(device_cc, max_supported_cc)
arch_opt = "--gpu-architecture={}_{}".format("sm" if should_compile_to_SASS else "compute",
device_cc_as_used)
return [arch_opt]
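For illustration, the selection logic above can be exercised in isolation. The helper below is a hypothetical, self-contained mirror of get_rtc_compile_opts() that takes the two compute capabilities as plain integers instead of querying the device and nvrtc:

def rtc_compile_opts_sketch(device_cc, max_supported_cc):
    """Standalone sketch of the option selection above (illustration only)."""
    # Toolkits whose nvrtc supports arch 86 (CUDA 11.1+) can emit SASS directly.
    can_compile_to_sass = max_supported_cc >= 86
    should_compile_to_sass = can_compile_to_sass and device_cc <= max_supported_cc
    cc_as_used = min(device_cc, max_supported_cc)
    prefix = "sm" if should_compile_to_sass else "compute"
    return ["--gpu-architecture={}_{}".format(prefix, cc_as_used)]

# An sm_86 GPU with a CUDA 11.0 toolkit (nvrtc tops out at arch 80): PTX, JIT-compiled by the driver.
print(rtc_compile_opts_sketch(86, 80))   # ['--gpu-architecture=compute_80']
# An sm_80 GPU with a CUDA 11.1+ toolkit (nvrtc supports arch 86): native SASS.
print(rtc_compile_opts_sketch(80, 86))   # ['--gpu-architecture=sm_80']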
11 changes: 11 additions & 0 deletions src/c_api/c_api_test.cc
@@ -26,6 +26,7 @@
#include <nnvm/pass.h>
#include "./c_api_common.h"
#include "../operator/subgraph/subgraph_property.h"
#include "../common/cuda/rtc.h"

int MXBuildSubgraphByOpNames(SymbolHandle sym_handle,
const char* prop_name,
@@ -128,3 +129,13 @@ int MXSetEnv(const char* name,
#endif
API_END();
}

int MXGetMaxSupportedArch(uint32_t *max_arch) {
API_BEGIN();
#if MXNET_USE_CUDA
*max_arch = static_cast<uint32_t>(mxnet::common::cuda::rtc::GetMaxSupportedArch());
#else
LOG(FATAL) << "Compile with USE_CUDA=1 to have CUDA runtime compilation.";
#endif
API_END();
}
50 changes: 32 additions & 18 deletions src/common/cuda/rtc.cc
@@ -67,6 +67,33 @@ std::string to_string(OpReqType req) {

} // namespace util

int GetMaxSupportedArch() {
#if CUDA_VERSION < 10000
constexpr int max_supported_sm_arch = 72;
#elif CUDA_VERSION < 11000
constexpr int max_supported_sm_arch = 75;
#elif CUDA_VERSION < 11010
constexpr int max_supported_sm_arch = 80;
#elif CUDA_VERSION < 11020
constexpr int max_supported_sm_arch = 86;
#else
// starting with cuda 11.2, nvrtc can report the max supported arch,
// removing the need to update this routine with each new cuda version.
static int max_supported_sm_arch = []() {
int num_archs = 0;
NVRTC_CALL(nvrtcGetNumSupportedArchs(&num_archs));
std::vector<int> archs(num_archs);
if (num_archs > 0) {
NVRTC_CALL(nvrtcGetSupportedArchs(archs.data()));
} else {
LOG(FATAL) << "Could not determine supported cuda archs.";
}
return archs[num_archs - 1];
}();
#endif
return max_supported_sm_arch;
}

namespace {

// Obtain compilation log from the program.
@@ -97,27 +124,14 @@ std::string GetCompiledCode(nvrtcProgram program, bool use_cubin) {
}

std::tuple<bool, std::string> GetArchString(const int sm_arch) {
#if CUDA_VERSION < 10000
constexpr int max_supported_sm_arch = 72;
#elif CUDA_VERSION < 11000
constexpr int max_supported_sm_arch = 75;
#elif CUDA_VERSION < 11010
constexpr int max_supported_sm_arch = 80;
#else
constexpr int max_supported_sm_arch = 86;
#endif

#if CUDA_VERSION <= 11000
const int sm_arch_as_used = std::min(sm_arch, GetMaxSupportedArch());
// Always use PTX for CUDA <= 11.0
const bool known_arch = false;
#else
const bool known_arch = sm_arch <= max_supported_sm_arch;
#endif
const int actual_sm_arch = std::min(sm_arch, max_supported_sm_arch);
const bool known_arch = (CUDA_VERSION > 11000) &&
(sm_arch == sm_arch_as_used);
if (known_arch) {
return {known_arch, "sm_" + std::to_string(actual_sm_arch)};
return {known_arch, "sm_" + std::to_string(sm_arch_as_used)};
} else {
return {known_arch, "compute_" + std::to_string(actual_sm_arch)};
return {known_arch, "compute_" + std::to_string(sm_arch_as_used)};
}
}

2 changes: 2 additions & 0 deletions src/common/cuda/rtc.h
@@ -53,6 +53,8 @@ std::string to_string(OpReqType req);

} // namespace util

int GetMaxSupportedArch();

extern std::mutex lock;

/*! \brief Compile and get the GPU kernel. Uses cache in order to
4 changes: 2 additions & 2 deletions src/common/rtc.cc
@@ -44,8 +44,8 @@ CudaModule::Chunk::Chunk(
<< "For lower version of CUDA, please prepend your kernel defintiions "
<< "with extern \"C\" instead.";
#endif
std::vector<const char*> c_options(options.size());
for (const auto& i : options) c_options.emplace_back(i.c_str());
std::vector<const char*> c_options;
for (const auto& i : options) c_options.push_back(i.c_str());
nvrtcResult compile_res = nvrtcCompileProgram(prog_, c_options.size(), c_options.data());
if (compile_res != NVRTC_SUCCESS) {
size_t err_size;
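Note on the src/common/rtc.cc fix above (the "Fix rtc options vector handling" bullet): the replaced code pre-sized c_options with options.size() value-initialized (null) entries and then appended the real pointers, so nvrtc was handed twice the intended option count with leading nulls; building the vector by push_back alone passes exactly the supplied options.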
17 changes: 10 additions & 7 deletions tests/python/gpu/test_operator_gpu.py
@@ -27,7 +27,7 @@
import mxnet.ndarray.sparse as mxsps
from mxnet.test_utils import check_consistency, set_default_context, assert_almost_equal, assert_allclose
from mxnet.test_utils import check_symbolic_forward, check_symbolic_backward, discard_stderr
from mxnet.test_utils import default_context, rand_shape_2d, rand_ndarray, same, environment
from mxnet.test_utils import default_context, rand_shape_2d, rand_ndarray, same, environment, get_rtc_compile_opts
from mxnet.base import MXNetError
from mxnet import autograd

@@ -1796,6 +1796,7 @@ def test_autograd_save_memory():

@pytest.mark.serial
def test_cuda_rtc():
ctx = mx.gpu(0)
source = r'''
extern "C" __global__ void axpy(const float *x, float *y, float alpha) {
int i = threadIdx.x + blockIdx.x * blockDim.x;
@@ -1809,18 +1810,20 @@ def test_cuda_rtc():
y[i] += alpha * smem[threadIdx.x];
}
'''
module = mx.rtc.CudaModule(source)

compile_opts = get_rtc_compile_opts(ctx)
module = mx.rtc.CudaModule(source, options=compile_opts)
axpy = module.get_kernel("axpy", "const float *x, float *y, float alpha")
x = mx.nd.ones((10,), ctx=mx.gpu(0))
y = mx.nd.zeros((10,), ctx=mx.gpu(0))
axpy.launch([x, y, 3.0], mx.gpu(0), (1, 1, 1), (10, 1, 1))
x = mx.nd.ones((10,), ctx=ctx)
y = mx.nd.zeros((10,), ctx=ctx)
axpy.launch([x, y, 3.0], ctx, (1, 1, 1), (10, 1, 1))
assert (y.asnumpy() == 3).all()

saxpy = module.get_kernel("saxpy", "const float *x, float *y, float alpha")
saxpy.launch([x, y, 4.0], mx.gpu(0), (1, 1, 1), (10, 1, 1), 10)
saxpy.launch([x, y, 4.0], ctx, (1, 1, 1), (10, 1, 1), 10)
assert (y.asnumpy() == 7).all()

saxpy.launch([x, y, 5.0], mx.gpu(0), (2, 1, 1), (5, 1, 1), 5)
saxpy.launch([x, y, 5.0], ctx, (2, 1, 1), (5, 1, 1), 5)
assert (y.asnumpy() == 12).all()
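With the compile options now derived from the device context, test_cuda_rtc can also pass on GPUs whose architecture is newer than what the installed toolkit's nvrtc supports natively (the options fall back to a PTX target that the driver JIT-compiles), which appears to be the "enhanced compatibility" scenario named in the commit title.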


