Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Expose the number of GPUs.
Browse files Browse the repository at this point in the history
  • Loading branch information
tdomhan committed Apr 10, 2018
1 parent 73273cf commit 08270f0
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 0 deletions.
16 changes: 16 additions & 0 deletions include/mxnet/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,11 @@ struct Context {
* \return GPU Context. -1 for current GPU.
*/
inline static Context GPU(int32_t dev_id = -1);
/*!
* Get the number of GPUs available.
* \return The number of GPUs that are available.
*/
inline static int32_t GetGPUCount();
/*!
* Create a pinned CPU context.
* \param dev_id the device id for corresponding GPU.
Expand Down Expand Up @@ -316,6 +321,17 @@ inline Context Context::GPU(int32_t dev_id) {
return Create(kGPU, dev_id);
}

inline int32_t Context::GetGPUCount() {
#if MXNET_USE_CUDA
int32_t count;
cudaError_t e = cudaGetDeviceCount(&count);
CHECK_EQ(e, cudaSuccess) << " CUDA: " << cudaGetErrorString(e);
return count;
#else
LOG(FATAL) << "Please compile with CUDA support to query the number of GPUs.";
#endif
}

inline Context Context::FromString(const std::string& str) {
Context ret;
try {
Expand Down
7 changes: 7 additions & 0 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,13 @@ MXNET_DLL int MXSetNumOMPThreads(int thread_num);
*/
MXNET_DLL int MXEngineSetBulkSize(int bulk_size, int* prev_bulk_size);

/*!
* \brief Get the number of GPUs.
* \param pointer to int that will hold the number of GPUs available.
* \return 0 when success, -1 when failure happens.
*/
MXNET_DLL int MXGetGPUCount(int* out);

/*!
* \brief get the MXNet library version as an integer
* \param pointer to the integer holding the version number
Expand Down
23 changes: 23 additions & 0 deletions python/mxnet/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
"""Context management API of mxnet."""
from __future__ import absolute_import

import ctypes
from .base import _LIB
from .base import check_call

class Context(object):
"""Constructs a context.
Expand Down Expand Up @@ -212,6 +216,25 @@ def gpu(device_id=0):
return Context('gpu', device_id)


def num_gpus():
"""Query CUDA for the number of GPUs present.
Raises
------
Will raise an exception on any CUDA error or in case MXNet was not
compiled with CUDA support.
Returns
-------
count : int
The number of GPUs.
"""
count = ctypes.c_int()
check_call(_LIB.MXGetGPUCount(ctypes.byref(count)))
return count.value

def current_context():
"""Returns the current context.
Expand Down
6 changes: 6 additions & 0 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,12 @@ int MXEngineSetBulkSize(int bulk_size, int* prev_bulk_size) {
API_END();
}

int MXGetGPUCount(int* out) {
API_BEGIN();
*out = Context::GetGPUCount();
API_END();
}

int MXGetVersion(int *out) {
API_BEGIN();
*out = static_cast<int>(MXNET_VERSION);
Expand Down

0 comments on commit 08270f0

Please sign in to comment.