Skip to content

Commit 05e872c

Browse files
committed
Avoid using a heavy API to query a single attribute
1 parent 95a323a commit 05e872c

File tree

2 files changed

+9
-7
lines changed

2 files changed

+9
-7
lines changed

src/codegen/opt/build_cuda_on.cc

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,12 +84,13 @@ std::string NVRTCCompile(const std::string& code, bool include_path = false) {
8484
std::vector<std::string> compile_params;
8585
std::vector<const char*> param_cstrings{};
8686
nvrtcProgram prog;
87-
cudaDeviceProp device_prop;
8887
std::string cc = "30";
89-
cudaError_t e = cudaGetDeviceProperties(&device_prop, 0);
88+
int major, minor;
89+
cudaError_t e1 = cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, 0);
90+
cudaError_t e2 = cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, 0);
9091

91-
if (e == cudaSuccess) {
92-
cc = std::to_string(device_prop.major) + std::to_string(device_prop.minor);
92+
if (e1 == cudaSuccess && e2 == cudaSuccess) {
93+
cc = std::to_string(major) + std::to_string(minor);
9394
} else {
9495
LOG(WARNING) << "cannot detect compute capability from your device, "
9596
<< "fall back to compute_30.";

src/runtime/cuda/cuda_device_api.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
#include <dmlc/thread_local.h>
2828
#include <tvm/runtime/registry.h>
29+
#include <cuda.h>
2930
#include <cuda_runtime.h>
3031
#include "cuda_common.h"
3132

@@ -73,9 +74,9 @@ class CUDADeviceAPI final : public DeviceAPI {
7374
return;
7475
}
7576
case kDeviceName: {
76-
cudaDeviceProp props;
77-
CUDA_CALL(cudaGetDeviceProperties(&props, ctx.device_id));
78-
*rv = std::string(props.name);
77+
std::string name(sizeof(cudaDeviceProp::name), 0);
78+
CUDA_DRIVER_CALL(cuDeviceGetName(&name[0], name.size(), ctx.device_id));
79+
*rv = std::string(name.c_str());
7980
return;
8081
}
8182
case kMaxClockRate: {

0 commit comments

Comments
 (0)