Expose the number of GPUs. (apache#10354)
* Expose the number of GPUs.

* Added GPU test.

* Removed trailing whitespace.

* Making the compiler happy.

* Reverted CPU only logic and added CPU test.

* Updated python docs.

* Removing break from test.

* No longer assert on 0 GPUs.
tdomhan authored and piiswrong committed May 15, 2018
1 parent d404f3f commit f3c6bbe
Showing 6 changed files with 68 additions and 0 deletions.
19 changes: 19 additions & 0 deletions include/mxnet/base.h
@@ -217,6 +217,11 @@ struct Context {
   * \return GPU Context. -1 for current GPU.
   */
  inline static Context GPU(int32_t dev_id = -1);
  /*!
   * Get the number of GPUs available.
   * \return The number of GPUs that are available.
   */
  inline static int32_t GetGPUCount();
  /*!
   * Create a pinned CPU context.
   * \param dev_id the device id for corresponding GPU.
@@ -307,6 +312,20 @@ inline Context Context::GPU(int32_t dev_id) {
  return Create(kGPU, dev_id);
}

inline int32_t Context::GetGPUCount() {
#if MXNET_USE_CUDA
  int32_t count;
  cudaError_t e = cudaGetDeviceCount(&count);
  if (e == cudaErrorNoDevice) {
    return 0;
  }
  CHECK_EQ(e, cudaSuccess) << " CUDA: " << cudaGetErrorString(e);
  return count;
#else
  return 0;
#endif
}

inline Context Context::FromString(const std::string& str) {
  Context ret;
  try {
7 changes: 7 additions & 0 deletions include/mxnet/c_api.h
@@ -383,6 +383,13 @@ MXNET_DLL int MXSetNumOMPThreads(int thread_num);
 */
MXNET_DLL int MXEngineSetBulkSize(int bulk_size, int* prev_bulk_size);

/*!
 * \brief Get the number of GPUs.
 * \param out pointer to int that will hold the number of GPUs available.
 * \return 0 when success, -1 when failure happens.
 */
MXNET_DLL int MXGetGPUCount(int* out);

/*!
 * \brief get the MXNet library version as an integer
 * \param out pointer to the integer holding the version number
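For orientation, here is a minimal sketch of calling the new C API entry point directly through ctypes; it mirrors the Python binding added in context.py below. The shared-library name and the use of MXGetLastError for error text are assumptions about the local build, not part of this diff.

import ctypes

# Load the MXNet shared library; the exact name/path depends on the local
# build (libmxnet.so is assumed here).
_lib = ctypes.CDLL("libmxnet.so")
_lib.MXGetLastError.restype = ctypes.c_char_p

def gpu_count():
    """Return the number of GPUs via the MXGetGPUCount C API call."""
    count = ctypes.c_int()
    # MXGetGPUCount returns 0 on success and -1 on failure.
    if _lib.MXGetGPUCount(ctypes.byref(count)) != 0:
        raise RuntimeError(_lib.MXGetLastError().decode("utf-8"))
    return count.value

print(gpu_count())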
21 changes: 21 additions & 0 deletions python/mxnet/context.py
@@ -20,7 +20,11 @@
from __future__ import absolute_import
import threading
import warnings
import ctypes
from .base import classproperty, with_metaclass, _MXClassPropertyMetaClass
from .base import _LIB
from .base import check_call


class Context(with_metaclass(_MXClassPropertyMetaClass, object)):
    """Constructs a context.
@@ -237,6 +241,23 @@ def gpu(device_id=0):
    return Context('gpu', device_id)


def num_gpus():
    """Query CUDA for the number of GPUs present.

    Raises
    ------
    MXNetError
        If CUDA raises an error while determining the number of GPUs.

    Returns
    -------
    count : int
        The number of GPUs.

    """
    count = ctypes.c_int()
    check_call(_LIB.MXGetGPUCount(ctypes.byref(count)))
    return count.value


def current_context():
    """Returns the current context.
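As a usage sketch, the new frontend function makes device selection straightforward; the MXNetError guard mirrors the CPU-side unit test added later in this commit (the mx alias is the usual import convention, not part of the diff):

import mxnet as mx

# On a CPU-only host CUDA may fail while counting devices, so guard the call
# the same way the unit test in this commit does.
try:
    n = mx.context.num_gpus()
except mx.MXNetError:
    n = 0

# Fall back to the CPU context when no GPU is available.
ctx = mx.gpu(0) if n > 0 else mx.cpu()
print(n, ctx)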
6 changes: 6 additions & 0 deletions src/c_api/c_api.cc
@@ -116,6 +116,12 @@ int MXEngineSetBulkSize(int bulk_size, int* prev_bulk_size) {
  API_END();
}

int MXGetGPUCount(int* out) {
  API_BEGIN();
  *out = Context::GetGPUCount();
  API_END();
}

int MXGetVersion(int *out) {
  API_BEGIN();
  *out = static_cast<int>(MXNET_VERSION);
3 changes: 3 additions & 0 deletions tests/python/gpu/test_operator_gpu.py
@@ -1853,6 +1853,9 @@ def test_softmax_activation():
    assert_almost_equal(cpu_a.grad.asnumpy(), gpu_a.grad.asnumpy(),
                        atol=1e-3, rtol=1e-3)

def test_context_num_gpus():
    # Test that num_gpus reports at least one GPU, as the test is run on a GPU host.
    assert mx.context.num_gpus() > 0

if __name__ == '__main__':
    import nose
12 changes: 12 additions & 0 deletions tests/python/unittest/test_operator.py
@@ -6007,6 +6007,18 @@ def test_activation():
                name, op[0], shape, op[3], op[4], rtol_fd, atol_fd, num_eps)


def test_context_num_gpus():
    try:
        # Note: this test runs on both GPU and CPU hosts, so we cannot assert
        # a specific number here.
        assert mx.context.num_gpus() >= 0
    except mx.MXNetError as e:
        # Note: on a CPU-only host, CUDA is sometimes unable to determine the
        # number of GPUs.
        if str(e).find("CUDA") == -1:
            raise e


if __name__ == '__main__':
    import nose
    nose.runmodule()
