diff --git a/include/mxnet/base.h b/include/mxnet/base.h index 7cabfe5027ba..bff2ab45eca3 100644 --- a/include/mxnet/base.h +++ b/include/mxnet/base.h @@ -217,6 +217,11 @@ struct Context { * \return GPU Context. -1 for current GPU. */ inline static Context GPU(int32_t dev_id = -1); + /*! + * Get the number of GPUs available. + * \return The number of GPUs that are available. + */ + inline static int32_t GetGPUCount(); /*! * Create a pinned CPU context. * \param dev_id the device id for corresponding GPU. @@ -307,6 +312,20 @@ inline Context Context::GPU(int32_t dev_id) { return Create(kGPU, dev_id); } +inline int32_t Context::GetGPUCount() { +#if MXNET_USE_CUDA + int32_t count; + cudaError_t e = cudaGetDeviceCount(&count); + if (e == cudaErrorNoDevice) { + return 0; + } + CHECK_EQ(e, cudaSuccess) << " CUDA: " << cudaGetErrorString(e); + return count; +#else + return 0; +#endif +} + inline Context Context::FromString(const std::string& str) { Context ret; try { diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 9ac90d68c677..06e39bfeb38b 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -383,6 +383,13 @@ MXNET_DLL int MXSetNumOMPThreads(int thread_num); */ MXNET_DLL int MXEngineSetBulkSize(int bulk_size, int* prev_bulk_size); +/*! + * \brief Get the number of GPUs. + * \param pointer to int that will hold the number of GPUs available. + * \return 0 when success, -1 when failure happens. + */ +MXNET_DLL int MXGetGPUCount(int* out); + /*! * \brief get the MXNet library version as an integer * \param pointer to the integer holding the version number diff --git a/python/mxnet/context.py b/python/mxnet/context.py index 5861890f40c1..61b70532dd74 100644 --- a/python/mxnet/context.py +++ b/python/mxnet/context.py @@ -20,7 +20,11 @@ from __future__ import absolute_import import threading import warnings +import ctypes from .base import classproperty, with_metaclass, _MXClassPropertyMetaClass +from .base import _LIB +from .base import check_call + class Context(with_metaclass(_MXClassPropertyMetaClass, object)): """Constructs a context. @@ -237,6 +241,23 @@ def gpu(device_id=0): return Context('gpu', device_id) +def num_gpus(): + """Query CUDA for the number of GPUs present. + + Raises + ------ + Will raise an exception on any CUDA error. + + Returns + ------- + count : int + The number of GPUs. + + """ + count = ctypes.c_int() + check_call(_LIB.MXGetGPUCount(ctypes.byref(count))) + return count.value + def current_context(): """Returns the current context. diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index b3dcd6a65d9d..467118b9921e 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -116,6 +116,12 @@ int MXEngineSetBulkSize(int bulk_size, int* prev_bulk_size) { API_END(); } +int MXGetGPUCount(int* out) { + API_BEGIN(); + *out = Context::GetGPUCount(); + API_END(); +} + int MXGetVersion(int *out) { API_BEGIN(); *out = static_cast(MXNET_VERSION); diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index 090773c77871..b9f2b6791d00 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -1853,6 +1853,9 @@ def test_softmax_activation(): assert_almost_equal(cpu_a.grad.asnumpy(), gpu_a.grad.asnumpy(), atol = 1e-3, rtol = 1e-3) +def test_context_num_gpus(): + # Test that num_gpus reports at least one GPU, as the test is run on a GPU host. + assert mx.context.num_gpus() > 0 if __name__ == '__main__': import nose diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 0a6de8e7a1b8..e7976e01f9d8 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -6007,6 +6007,18 @@ def test_activation(): name, op[0], shape, op[3], op[4], rtol_fd, atol_fd, num_eps) +def test_context_num_gpus(): + try: + # Note: the test is run both on GPU and CPU hosts, so that we can not assert + # on a specific number here. + assert mx.context.num_gpus() >= 0 + except mx.MXNetError as e: + # Note: On a CPU only host CUDA sometimes is not able to determine the number + # of GPUs + if str(e).find("CUDA") == -1: + raise e + + if __name__ == '__main__': import nose nose.runmodule()