@@ -1154,7 +1154,7 @@ int MXRtcCreate(char* name, mx_uint num_input, mx_uint num_output,
1154
1154
NDArrayHandle* inputs, NDArrayHandle* outputs,
1155
1155
char * kernel, RtcHandle *out) {
1156
1156
API_BEGIN ();
1157
- #if MXNET_USE_CUDA
1157
+ #if (( MXNET_USE_CUDA) && (MXNET_USE_NVRTC))
1158
1158
std::vector<std::pair<std::string, NDArray> > input, output;
1159
1159
for (mx_uint i = 0 ; i < num_input; ++i) {
1160
1160
input.push_back (std::pair<std::string, NDArray>(input_names[i],
@@ -1167,8 +1167,8 @@ int MXRtcCreate(char* name, mx_uint num_input, mx_uint num_output,
1167
1167
MXRtc *rtc = new MXRtc (name, input, output, kernel);
1168
1168
*out = reinterpret_cast <RtcHandle>(rtc);
1169
1169
#else
1170
- LOG (FATAL) << " Need to compile with USE_CUDA=1 for MXRtc." ;
1171
- #endif // MXNET_USE_CUDA
1170
+ LOG (FATAL) << " Need to compile with USE_CUDA=1 and USE_NVRTC=1 for MXRtc." ;
1171
+ #endif // (( MXNET_USE_CUDA) && (MXNET_USE_NVRTC))
1172
1172
API_END ();
1173
1173
}
1174
1174
@@ -1181,7 +1181,7 @@ int MXRtcPush(RtcHandle handle, mx_uint num_input, mx_uint num_output,
1181
1181
mx_uint blockDimY,
1182
1182
mx_uint blockDimZ) {
1183
1183
API_BEGIN ();
1184
- #if MXNET_USE_CUDA
1184
+ #if (( MXNET_USE_CUDA) && (MXNET_USE_NVRTC))
1185
1185
std::vector<NDArray> input, output;
1186
1186
for (mx_uint i = 0 ; i < num_input; ++i) {
1187
1187
input.push_back (*reinterpret_cast <NDArray*>(inputs[i]));
@@ -1197,18 +1197,18 @@ int MXRtcPush(RtcHandle handle, mx_uint num_input, mx_uint num_output,
1197
1197
blockDimY,
1198
1198
blockDimZ);
1199
1199
#else
1200
- LOG (FATAL) << " Need to compile with USE_CUDA=1 for MXRtc." ;
1201
- #endif // MXNET_USE_CUDA
1200
+ LOG (FATAL) << " Need to compile with USE_CUDA=1 and USE_NVRTC=1 for MXRtc." ;
1201
+ #endif // (( MXNET_USE_CUDA) && (MXNET_USE_NVRTC))
1202
1202
API_END ();
1203
1203
}
1204
1204
1205
1205
int MXRtcFree (RtcHandle handle) {
1206
1206
API_BEGIN ();
1207
- #if MXNET_USE_CUDA
1207
+ #if (( MXNET_USE_CUDA) && (MXNET_USE_NVRTC))
1208
1208
delete reinterpret_cast <MXRtc*>(handle);
1209
1209
#else
1210
- LOG (FATAL) << " Need to compile with USE_CUDA=1 for MXRtc." ;
1211
- #endif // MXNET_USE_CUDA
1210
+ LOG (FATAL) << " Need to compile with USE_CUDA=1 and USE_NVRTC=1 for MXRtc." ;
1211
+ #endif // (( MXNET_USE_CUDA) && (MXNET_USE_NVRTC))
1212
1212
API_END ();
1213
1213
}
1214
1214
0 commit comments