|
3 | 3 | from subprocess import check_output
|
4 | 4 | import re
|
5 | 5 | import docker
|
| 6 | +from pynvml import * |
6 | 7 |
|
7 | 8 | class NVDockerClient:
|
8 | 9 |
|
| 10 | + nvml_initialized = False |
| 11 | + |
9 | 12 | def __init__(self):
|
10 | 13 | self.docker_client = docker.from_env(version="auto")
|
| 14 | + NVDockerClient.__check_nvml_init() |
| 15 | + |
| 16 | + """ |
| 17 | + Private method to check if nvml is loaded (and load the library if it isn't loaded) |
| 18 | + """ |
| 19 | + def __check_nvml_init(): |
| 20 | + if not NVDockerClient.nvml_initialized: |
| 21 | + nvmlInit() |
| 22 | + print("NVIDIA Driver Version:", nvmlSystemGetDriverVersion()) |
| 23 | + NVDockerClient.nvml_initialized = True |
11 | 24 |
|
12 | 25 | #TODO: Testing on MultiGPU
|
13 | 26 | def create_container(self, image, **kwargs):
|
@@ -152,22 +165,13 @@ def exec_run(self, cid, cmd):
|
152 | 165 |
|
153 | 166 | @staticmethod
|
154 | 167 | def gpu_info():
|
155 |
| - #output = check_output(["nvidia-smi", "-L"]).decode("utf-8") |
156 |
| - keys = ['memory_free', 'memory_used', 'memory_total'] |
157 |
| - query_gpu = check_output(["nvidia-smi", "--query-gpu=memory.free,memory.used,memory.total","--format=csv,noheader"]).decode("utf-8") |
158 |
| - #regex = re.compile(r"GPU (?P<id>\d+):") |
159 |
| - query_gpu = query_gpu.strip() |
| 168 | + NVDockerClient.__check_nvml_init() |
160 | 169 | gpus = {}
|
161 |
| - id = 0 |
162 |
| - for gpu in query_gpu.split("\n"): |
163 |
| - gpu_info = {} |
164 |
| - key_id = 0 |
165 |
| - for info in gpu.split(","): |
166 |
| - info = info.strip() |
167 |
| - gpu_info[keys[key_id]] = info.split(" ")[0]; |
168 |
| - key_id += 1 |
169 |
| - gpus[id] = gpu_info; |
170 |
| - id += 1 |
| 170 | + num_gpus = nvmlDeviceGetCount() |
| 171 | + for i in range(num_gpus): |
| 172 | + gpu_handle = nvmlDeviceGetHandleByIndex(i) |
| 173 | + gpu_name = nvmlDeviceGetName(gpu_handle) |
| 174 | + gpus[i] = {"gpu_handle": gpu_handle, "gpu_name": gpu_name} |
171 | 175 | return gpus
|
172 | 176 |
|
173 | 177 | @staticmethod
|
|
0 commit comments