Skip to content

Commit 10626d4

Browse files
committed
Add GPU monitor
1 parent cf41112 commit 10626d4

13 files changed

+440
-25
lines changed

axlearn/cloud/gcp/monitoring/tpu_client.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737

3838

3939
# Interface names for libtpu metrics.
40+
# Reference:
41+
# https://github.com/AI-Hypercomputer/cloud-accelerator-diagnostics/blob/7d2b2921fc9393a3dec7be5440e25132c217549b/tpu_info/tpu_info/metrics.py#L29
4042
class MetricName(enum.Enum):
4143
"""Metric names defined in libtpu."""
4244

@@ -131,7 +133,7 @@ def get_chip_metrics(
131133
def sorted_metric_response(
132134
metric_name: str,
133135
) -> list[tpu_metrics.Metric]:
134-
# Manually annotate type until GRPC supports annotations
136+
# Manually annotate type until GRPC supports annotations.
135137
# See https://github.com/grpc/grpc/issues/29041
136138
resp: tpu_metrics.MetricResponse = client.GetRuntimeMetric(
137139
tpu_metrics.MetricRequest(metric_name=metric_name)
@@ -153,7 +155,7 @@ def sorted_metric_response(
153155
metric_results[i].hbm_memory_usage_bytes = metric.gauge.as_int
154156
elif metric_name == MetricName.TENSORCORE_DUTY_CYCLE_PERCENT:
155157
for i, metric in enumerate(metric_result):
156-
metric_results[i].tensorcore_duty_cycle_percent = metric.gauge.as_double
158+
metric_results[i].device_duty_cycle_percent = metric.gauge.as_double
157159

158160
return metric_results
159161

@@ -227,9 +229,9 @@ def get_chip_metrics_v2(
227229
elif family.name == MetricV2Name.HBM_MEMORY_USAGE_BYTES.value:
228230
metric_results[i].hbm_memory_usage_bytes = metric[2]
229231
elif family.name == MetricV2Name.TENSORCORE_DUTY_CYCLE_PERCENT.value:
230-
metric_results[i].tensorcore_duty_cycle_percent = metric[2]
232+
metric_results[i].device_duty_cycle_percent = metric[2]
231233
elif family.name == MetricV2Name.TENSORCORE_UTILIZATION.value:
232-
metric_results[i].tensorcore_utilization = metric[2]
234+
metric_results[i].device_utilization = metric[2]
233235
elif family.name == MetricV2Name.HBM_MEMORY_BANDWIDTH_UTILIZATION.value:
234236
metric_results[i].hbm_memory_bandwidth_utilization = metric[2]
235237

axlearn/cloud/gcp/monitoring/tpu_client_test.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,9 @@ def _(self, val: float):
8080
"""Create Gauge from float."""
8181
return tpu_metrics.Gauge(as_double=val)
8282

83-
def GetRuntimeMetric(self, request: tpu_metrics.MetricRequest, context):
83+
def GetRuntimeMetric(
84+
self, request: tpu_metrics.MetricRequest, context
85+
): # pylint: disable=unused-argument
8486
"""Get the metric from the fake libtpu server."""
8587
metric_name = tpu_client.MetricName(request.metric_name)
8688
resp = self._responses[metric_name]
@@ -100,7 +102,9 @@ def GetRuntimeMetric(self, request: tpu_metrics.MetricRequest, context):
100102
)
101103
)
102104

103-
def ListSupportedMetrics(self, request: tpu_metrics.ListSupportedMetricsRequest, context):
105+
def ListSupportedMetrics(
106+
self, request: tpu_metrics.ListSupportedMetricsRequest, context
107+
): # pylint: disable=unused-argument
104108
"""List the supported metrics from the fake libtpu server."""
105109
# The test supported metrics are based on V5P libtpu.
106110
supported_metrics = [
@@ -210,7 +214,7 @@ def test_metrics(self, chip_type: device.TpuChip, responses):
210214
expected_usage = [
211215
tpu_client.Usage(
212216
device_id=i,
213-
tensorcore_duty_cycle_percent=d,
217+
device_duty_cycle_percent=d,
214218
hbm_memory_usage_bytes=m,
215219
hbm_memory_total_bytes=t,
216220
)
@@ -313,8 +317,8 @@ def test_all(self):
313317
expected_usage = [
314318
tpu_client.Usage(
315319
device_id=i,
316-
tensorcore_duty_cycle_percent=100.0,
317-
tensorcore_utilization=1.0 * (1 + i),
320+
device_duty_cycle_percent=100.0,
321+
device_utilization=1.0 * (1 + i),
318322
hbm_memory_total_bytes=int(1.02803439616e11),
319323
hbm_memory_usage_bytes=int(6.5e10),
320324
hbm_memory_bandwidth_utilization=30.0,

axlearn/cloud/gcp/monitoring/tpu_device_monitor.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,7 @@ def collect_metrics(self) -> list[Usage]:
6868
def is_host_idle(self, usages: list[Usage]) -> bool:
6969
"""Check if the TPU device on the host are idle."""
7070
for usage in usages:
71-
if (
72-
usage.hbm_memory_bandwidth_utilization <= 0.1
73-
and usage.tensorcore_utilization <= 0.1
74-
):
71+
if usage.hbm_memory_bandwidth_utilization <= 0.1 and usage.device_utilization <= 0.1:
7572
logging.info("TPU device %d is idle.", usage.device_id)
7673
return True
7774
return False

axlearn/cloud/gcp/monitoring/tpu_device_monitor_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ def test_tpu_client(self):
2222
device_id=i,
2323
hbm_memory_total_bytes=int(1.02803439616e11),
2424
hbm_memory_usage_bytes=int(6.5e10),
25-
tensorcore_duty_cycle_percent=100.0,
26-
tensorcore_utilization=1.0 * (1 + i),
25+
device_duty_cycle_percent=100.0,
26+
device_utilization=1.0 * (1 + i),
2727
hbm_memory_bandwidth_utilization=30.0,
2828
)
2929
for i in range(4)

axlearn/common/launch_trainer.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757
flags.DEFINE_enum(
5858
"device_monitor",
5959
"none",
60-
["none", "tpu"],
60+
["none", "tpu", "gpu"],
6161
"Whether to enable the device monitor. "
6262
"The device monitor collects the system metrics and logs them periodically. "
6363
"The device monitor also logs the idle status of the devices on the host, "
@@ -116,6 +116,11 @@ def get_trainer_config(
116116
from axlearn.cloud.gcp.monitoring.tpu_device_monitor import create_tpu_monitor
117117

118118
trainer_config.device_monitor = create_tpu_monitor()
119+
elif flag_values.device_monitor == "gpu":
120+
# pylint: disable-next=wrong-import-position,import-outside-toplevel
121+
from axlearn.common.monitoring.gpu_device_monitor import create_gpu_monitor
122+
123+
trainer_config.device_monitor = create_gpu_monitor()
119124
if hasattr(trainer_config.checkpointer, "trainer_dir"):
120125
# Set trainer_dir if not already set.
121126
if not isinstance(trainer_config.checkpointer.trainer_dir, str):

axlearn/common/monitoring/device_monitor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def _start_monitoring(self):
9494
if self.config.check_interval_in_sec > 0:
9595
self._monitor_stopping = threading.Event()
9696
self._monitor_thread = threading.Thread(
97-
name="tpu_device_monitor",
97+
name="device_monitor",
9898
target=self._monitor_loop,
9999
)
100100
self._monitor_thread.start()
@@ -115,4 +115,4 @@ def _monitor_loop(self):
115115
self._idle = self._check_host_and_log_metrics()
116116
if self._monitor_stopping.wait(timeout=self.config.check_interval_in_sec):
117117
break
118-
logging.info("mointor loop exit.")
118+
logging.info("monitor loop exit.")

axlearn/common/monitoring/device_monitor_test.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def is_host_idle(self, usages: list[Usage]) -> bool:
3434
# Make sure the usages are empty.
3535
return (
3636
usages[0].hbm_memory_bandwidth_utilization <= 0.1
37-
and usages[0].tensorcore_utilization <= 0.1
37+
and usages[0].device_utilization <= 0.1
3838
)
3939

4040

@@ -46,8 +46,8 @@ def test_client(self):
4646
fake_usage = [
4747
Usage(
4848
device_id=0,
49-
tensorcore_duty_cycle_percent=100.0,
50-
tensorcore_utilization=1.0,
49+
device_duty_cycle_percent=100.0,
50+
device_utilization=1.0,
5151
hbm_memory_total_bytes=100,
5252
hbm_memory_usage_bytes=50,
5353
hbm_memory_bandwidth_utilization=30.0,
@@ -71,8 +71,8 @@ def test_client_idle(self):
7171
fake_usage = [
7272
Usage(
7373
device_id=0,
74-
tensorcore_duty_cycle_percent=0.0,
75-
tensorcore_utilization=0.0,
74+
device_duty_cycle_percent=0.0,
75+
device_utilization=0.0,
7676
hbm_memory_total_bytes=100,
7777
hbm_memory_usage_bytes=50,
7878
hbm_memory_bandwidth_utilization=0.0,
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
# Copyright © 2025 Apple Inc.
2+
3+
"""Client for fetching GPU metrics via NVML."""
4+
import atexit
5+
6+
from absl import logging
7+
8+
9+
class NVMLMetrics:
    """NVMLMetrics provides interfaces to fetch GPU utilization/memory metrics via NVML.

    Calling `pynvml.nvmlInit` multiple times will lead to potential issues and it should only
    be called once; `init_nvml` guards this with the `nvml_initialized` class flag.

    When the operations are completed, `pynvml.nvmlShutdown` should be called. Currently it is
    registered via `atexit` so NVML is torn down when the process exits.

    NOTE(review): initialization is not thread-safe — callers must not invoke `init_nvml`
    concurrently from multiple threads.
    """

    # Whether `nvmlInit` has completed successfully in this process.
    nvml_initialized = False
    # Lazily-imported `pynvml` module handle; None until `init_nvml` runs.
    nvml = None

    @classmethod
    def init_nvml(cls):
        """Lazily imports pynvml and initializes NVML exactly once.

        It is not thread-safe. Please see the docstring of the class for more details.
        Subsequent calls after a successful initialization are no-ops.

        Raises:
            Exception: if NVML initialization fails (re-raised after logging).
        """
        if cls.nvml_initialized:
            return
        # pylint: disable-next=import-error,import-outside-toplevel
        import pynvml as nvml  # pytype: disable=import-error

        cls.nvml = nvml
        try:
            nvml.nvmlInit()
        except Exception:  # Narrowed from a bare `except:`; still logs and re-raises.
            logging.exception("Failed to initialize NVML Library for GPU metrics monitoring.")
            raise
        cls.nvml_initialized = True
        # Ensure NVML is shut down cleanly when the process exits.
        atexit.register(nvml.nvmlShutdown)

    @classmethod
    def _average_samples(cls, samples) -> float:
        """Averages the `uiVal` fields of an NVML sample buffer; returns 0 if empty.

        Guards against an empty buffer so callers never divide by zero.
        """
        values = [sample.sampleValue.uiVal for sample in samples[1]]
        if not values:
            logging.warning("No samples returned from pynvml.")
            return 0
        return sum(values) / len(values)

    @classmethod
    def get_gpu_device_count(cls):
        """Returns the number of GPU devices visible to NVML.

        Raises:
            Exception: if the NVML query fails (re-raised after logging).
        """
        cls.init_nvml()
        try:
            return cls.nvml.nvmlDeviceGetCount()
        except Exception:  # Narrowed from a bare `except:`; still logs and re-raises.
            logging.exception("Failed to get GPU device count.")
            raise

    @classmethod
    def get_gpu_device_utilization(cls, device_id: int) -> float:
        """Returns the average GPU utilization (percent) for `device_id`.

        Averages all the utilization samples in the device buffer.
        Typically this covers about 10-13 seconds of data.
        Reference: https://docs.nvidia.com/deploy/nvml-api/group__nvmlDeviceQueries.html
        Search for nvmlDeviceGetSamples.

        Raises:
            pynvml.NVMLError: if the NVML query fails (re-raised after logging).
        """
        cls.init_nvml()
        try:
            device_handle = cls.nvml.nvmlDeviceGetHandleByIndex(device_id)
            samples = cls.nvml.nvmlDeviceGetSamples(
                device_handle, cls.nvml.NVML_GPU_UTILIZATION_SAMPLES, 0
            )
            return cls._average_samples(samples)
        except cls.nvml.NVMLError:
            # `logging.exception` already records the traceback; no separate log of `e` needed.
            logging.exception("Failed to get GPU utilization metrics for device %d.", device_id)
            raise

    @classmethod
    def get_gpu_device_memory(cls, device_id: int) -> tuple[float, float]:
        """Returns (usage, total) device memory in bytes for `device_id`.

        Raises:
            pynvml.NVMLError: if the NVML query fails (re-raised after logging).
        """
        cls.init_nvml()
        try:
            device_handle = cls.nvml.nvmlDeviceGetHandleByIndex(device_id)
            mem_info = cls.nvml.nvmlDeviceGetMemoryInfo(device_handle)
            # Return tuple for memory usage, and total (in Bytes).
            return mem_info.used, mem_info.total
        except cls.nvml.NVMLError:
            logging.exception("Failed to get GPU memory info for device %d.", device_id)
            raise

    @classmethod
    def get_gpu_device_memory_utilization(cls, device_id: int) -> float:
        """Returns the average memory-bandwidth utilization (percent) for `device_id`.

        Averages all the utilization samples in the device buffer.
        Typically this covers about 10-13 seconds of data.
        Unlike the original, an empty sample buffer returns 0 instead of
        raising ZeroDivisionError (consistent with `get_gpu_device_utilization`).

        Raises:
            pynvml.NVMLError: if the NVML query fails (re-raised after logging).
        """
        cls.init_nvml()
        try:
            device_handle = cls.nvml.nvmlDeviceGetHandleByIndex(device_id)
            samples = cls.nvml.nvmlDeviceGetSamples(
                device_handle, cls.nvml.NVML_MEMORY_UTILIZATION_SAMPLES, 0
            )
            return cls._average_samples(samples)
        except cls.nvml.NVMLError:
            logging.exception("Failed to get GPU utilization metrics for device %d.", device_id)
            raise

0 commit comments

Comments
 (0)