Skip to content

Commit b3a08e9

Browse files
committed
cuda 6.5 compatibility fix
1 parent 794232c commit b3a08e9

File tree

6 files changed

+26
-17
lines changed

6 files changed

+26
-17
lines changed

Makefile

+7-1
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,15 @@ LIB_DEP += $(DMLC_CORE)/libdmlc.a
114114
ALL_DEP = $(OBJ) $(EXTRA_OBJ) $(LIB_DEP)
115115
ifeq ($(USE_CUDA), 1)
116116
ALL_DEP += $(CUOBJ) $(EXTRA_CUOBJ)
117-
LDFLAGS += -lnvrtc -lcuda
117+
LDFLAGS += -lcuda
118118
endif
119119

120+
ifeq ($(USE_NVRTC), 1)
121+
LDFLAGS += -lnvrtc
122+
CFLAGS += -DMXNET_USE_NVRTC=1
123+
else
124+
CFLAGS += -DMXNET_USE_NVRTC=0
125+
endif
120126

121127

122128
build/%.o: src/%.cc

include/mxnet/mxrtc.h

+2-3
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@
77
#ifndef MXNET_MXRTC_H_
88
#define MXNET_MXRTC_H_
99
#include "./base.h"
10-
#if MXNET_USE_CUDA
11-
10+
#if ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC))
1211
#include <nvrtc.h>
1312
#include <cuda.h>
1413

@@ -88,5 +87,5 @@ class MXRtc {
8887

8988
} // namespace mxnet
9089

91-
#endif // MXNET_USE_CUDA
90+
#endif // MXNET_USE_CUDA && MXNET_USE_NVRTC
9291
#endif // MXNET_MXRTC_H_

make/config.mk

+3
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ USE_CUDA_PATH = NONE
4848
# whether use CUDNN R3 library
4949
USE_CUDNN = 0
5050

51+
# whether use cuda runtime compiling for writing kernels in native language (i.e. Python)
52+
USE_NVRTC = 0
53+
5154
# whether use opencv during compilation
5255
# you can disable it, however, you will not able to use
5356
# imbin iterator

make/osx.mk

+3
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ USE_CUDA_PATH = NONE
4848
# whether use CUDNN R3 library
4949
USE_CUDNN = 0
5050

51+
# whether use cuda runtime compiling for writing kernels in native language (i.e. Python)
52+
USE_NVRTC = 0
53+
5154
# whether use opencv during compilation
5255
# you can disable it, however, you will not able to use
5356
# imbin iterator

src/c_api/c_api.cc

+9-9
Original file line numberDiff line numberDiff line change
@@ -1154,7 +1154,7 @@ int MXRtcCreate(char* name, mx_uint num_input, mx_uint num_output,
11541154
NDArrayHandle* inputs, NDArrayHandle* outputs,
11551155
char* kernel, RtcHandle *out) {
11561156
API_BEGIN();
1157-
#if MXNET_USE_CUDA
1157+
#if ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC))
11581158
std::vector<std::pair<std::string, NDArray> > input, output;
11591159
for (mx_uint i = 0; i < num_input; ++i) {
11601160
input.push_back(std::pair<std::string, NDArray>(input_names[i],
@@ -1167,8 +1167,8 @@ int MXRtcCreate(char* name, mx_uint num_input, mx_uint num_output,
11671167
MXRtc *rtc = new MXRtc(name, input, output, kernel);
11681168
*out = reinterpret_cast<RtcHandle>(rtc);
11691169
#else
1170-
LOG(FATAL) << "Need to compile with USE_CUDA=1 for MXRtc.";
1171-
#endif // MXNET_USE_CUDA
1170+
LOG(FATAL) << "Need to compile with USE_CUDA=1 and USE_NVRTC=1 for MXRtc.";
1171+
#endif // ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC))
11721172
API_END();
11731173
}
11741174

@@ -1181,7 +1181,7 @@ int MXRtcPush(RtcHandle handle, mx_uint num_input, mx_uint num_output,
11811181
mx_uint blockDimY,
11821182
mx_uint blockDimZ) {
11831183
API_BEGIN();
1184-
#if MXNET_USE_CUDA
1184+
#if ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC))
11851185
std::vector<NDArray> input, output;
11861186
for (mx_uint i = 0; i < num_input; ++i) {
11871187
input.push_back(*reinterpret_cast<NDArray*>(inputs[i]));
@@ -1197,18 +1197,18 @@ int MXRtcPush(RtcHandle handle, mx_uint num_input, mx_uint num_output,
11971197
blockDimY,
11981198
blockDimZ);
11991199
#else
1200-
LOG(FATAL) << "Need to compile with USE_CUDA=1 for MXRtc.";
1201-
#endif // MXNET_USE_CUDA
1200+
LOG(FATAL) << "Need to compile with USE_CUDA=1 and USE_NVRTC=1 for MXRtc.";
1201+
#endif // ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC))
12021202
API_END();
12031203
}
12041204

12051205
int MXRtcFree(RtcHandle handle) {
12061206
API_BEGIN();
1207-
#if MXNET_USE_CUDA
1207+
#if ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC))
12081208
delete reinterpret_cast<MXRtc*>(handle);
12091209
#else
1210-
LOG(FATAL) << "Need to compile with USE_CUDA=1 for MXRtc.";
1211-
#endif // MXNET_USE_CUDA
1210+
LOG(FATAL) << "Need to compile with USE_CUDA=1 and USE_NVRTC=1 for MXRtc.";
1211+
#endif // ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC))
12121212
API_END();
12131213
}
12141214

src/common/mxrtc.cc

+2-4
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,8 @@
55
* \author Junyuan Xie
66
*/
77
#include <mxnet/mxrtc.h>
8-
#if MXNET_USE_CUDA
9-
8+
#if ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC))
109
namespace mxnet {
11-
1210
const std::string MXRtc::str_type = "float";
1311
std::unordered_map<std::string, char*> MXRtc::kernel_registry;
1412

@@ -139,4 +137,4 @@ char* MXRtc::compile(const std::string& name, const std::string& code) {
139137

140138
} // namespace mxnet
141139

142-
#endif // MXNET_USE_CUDA
140+
#endif // ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC))

0 commit comments

Comments
 (0)