Skip to content

Commit 7a40f91

Browse files
Add Goodput & Badput recording and monitoring support. (apple#783)
* Code clean up * Add more testing * Fix docstrings * Remove recorder calls from trainer for now * Code cleanup gcp/measurement.py Co-authored-by: Ruoming Pang <ruoming@gmail.com> * Code cleanup common/measurement.py Co-authored-by: Ruoming Pang <ruoming@gmail.com> * Fix pre commit errors * Adding more tests * Further clean up * Fix a test error --------- Co-authored-by: Ruoming Pang <ruoming@gmail.com>
1 parent 031a7f3 commit 7a40f91

File tree

6 files changed

+164
-7
lines changed

6 files changed

+164
-7
lines changed

axlearn/cloud/gcp/measurement.py

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,24 +5,39 @@
55
import jax
66
from absl import flags, logging
77
from ml_goodput_measurement import goodput
8+
from ml_goodput_measurement import monitoring as goodput_monitoring
89

910
from axlearn.cloud.common.utils import parse_kv_flags
1011
from axlearn.common import measurement
11-
from axlearn.common.config import maybe_set_config
12+
from axlearn.common.config import REQUIRED, Required, config_class, maybe_set_config
1213

1314

1415
@measurement.register_recorder("goodput")
1516
class GoodputRecorder(measurement.Recorder):
1617
"""Records overall training goodput."""
1718

18-
Config = measurement.Recorder.Config
19+
@config_class
20+
class Config(measurement.Recorder.Config):
21+
"""Configures GoodputRecorder.
22+
23+
Attributes:
24+
upload_dir: Directory to store metrics for the monitor.
25+
upload_interval: Time interval (seconds) for monitoring uploads.
26+
"""
27+
28+
upload_dir: Required[str] = REQUIRED
29+
upload_interval: Required[int] = REQUIRED
1930

2031
@classmethod
2132
def from_flags(cls, fv: flags.FlagValues) -> "GoodputRecorder":
2233
"""Converts flags to a recorder.
2334
2435
`fv.recorder_spec` will be interpreted as a list of `key=value` pairs; config names
25-
corresponding to keys will be set to the corresponding values.
36+
corresponding to keys will be set to the corresponding values. A GoodputRecorder can
37+
additionally take in following Tensorboard configs in the recorder_spec:
38+
- upload_dir: The directory to write Tensorboard data to.
39+
- upload_interval: The time interval in seconds at which to query and upload data
40+
to Tensorboard.
2641
"""
2742
cfg: measurement.Recorder.Config = cls.default_config()
2843
cfg = maybe_set_config(cfg, **parse_kv_flags(fv.recorder_spec, delimiter="="))
@@ -32,6 +47,7 @@ def __init__(self, cfg):
3247
super().__init__(cfg)
3348
cfg: GoodputRecorder.Config = self.config
3449
self._recorder = None
50+
self._monitor = None
3551

3652
def record(self, event: measurement.Event, *args, **kwargs):
3753
# Lazily instantiate the recorder. This avoids invoking jax before setup is complete.
@@ -49,10 +65,47 @@ def record(self, event: measurement.Event, *args, **kwargs):
4965
self._recorder.record_job_end_time(*args, **kwargs)
5066
elif event == measurement.Event.START_STEP:
5167
self._recorder.record_step_start_time(*args, **kwargs)
68+
elif event == measurement.Event.START_ACCELERATOR_INIT:
69+
self._recorder.record_tpu_init_start_time(*args, **kwargs)
70+
elif event == measurement.Event.END_ACCELERATOR_INIT:
71+
self._recorder.record_tpu_init_end_time(*args, **kwargs)
72+
elif event == measurement.Event.START_TRAINING_PREPARATION:
73+
self._recorder.record_training_preparation_start_time(*args, **kwargs)
74+
elif event == measurement.Event.END_TRAINING_PREPARATION:
75+
self._recorder.record_training_preparation_end_time(*args, **kwargs)
76+
elif event == measurement.Event.START_DATA_LOADING:
77+
self._recorder.record_data_loading_start_time(*args, **kwargs)
78+
elif event == measurement.Event.END_DATA_LOADING:
79+
self._recorder.record_data_loading_end_time(*args, **kwargs)
5280
else:
5381
logging.log_first_n(
5482
logging.WARNING,
5583
"Ignoring unknown event %s",
5684
1,
5785
event,
5886
)
87+
88+
def start_monitoring(self, *args, **kwargs):
89+
"""Starts Monitoring of Goodput.
90+
91+
Instantiate ml-goodput-measurement's GoodputMonitor to asynchronously calculate
92+
Goodput and Badput at the upload_interval and upload to the specified TensorBoard
93+
directory.
94+
Note: This function requires initialization of distributed JAX before it is called.
95+
If there are internal GCP errors from querying and uploading data, these will be
96+
logged without affecting the workload. GoodputMonitor logs will provide further
97+
information if data is not being uploaded correctly.
98+
"""
99+
if self._monitor is None:
100+
cfg: GoodputRecorder.Config = self.config
101+
self._monitor = goodput_monitoring.GoodputMonitor(
102+
job_name=cfg.name,
103+
logger_name=f"goodput_logger_{cfg.name}",
104+
tensorboard_dir=cfg.upload_dir,
105+
upload_interval=int(cfg.upload_interval),
106+
monitoring_enabled=(jax.process_index() == 0),
107+
include_badput_breakdown=True,
108+
)
109+
110+
self._monitor.start_goodput_uploader(*args, **kwargs)
111+
logging.info("Started Goodput upload to Tensorboard in the background!")

axlearn/cloud/gcp/measurement_test.py

Lines changed: 70 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
class GoodputRecorderTest(parameterized.TestCase):
1717
"""Tests GoodputRecorder."""
1818

19-
@parameterized.parameters(None, ["name=test-name"])
19+
@parameterized.parameters(
20+
(None,), (["name=test-name", "upload_dir=/test/path/to/upload", "upload_interval=15"],)
21+
)
2022
def test_from_flags(self, spec):
2123
fv = flags.FlagValues()
2224
measurement.define_flags(flag_values=fv)
@@ -34,13 +36,78 @@ def test_from_flags(self, spec):
3436
# Recorder is not instantiated until first event.
3537
self.assertIsNone(recorder._recorder)
3638

37-
def test_record(self):
39+
def test_record_and_monitor(self):
3840
fv = flags.FlagValues()
3941
measurement.define_flags(flag_values=fv)
40-
fv.set_default("recorder_spec", ["name=test-name"])
42+
fv.set_default(
43+
"recorder_spec",
44+
["name=test-name", "upload_dir=/test/path/to/upload", "upload_interval=15"],
45+
)
4146
fv.mark_as_parsed()
4247

4348
recorder = GoodputRecorder.from_flags(fv)
4449
recorder._recorder = mock.MagicMock()
4550
recorder.record(measurement.Event.START_JOB)
4651
self.assertTrue(recorder._recorder.record_job_start_time.called)
52+
53+
def test_start_monitoring(self):
54+
fv = flags.FlagValues()
55+
measurement.define_flags(flag_values=fv)
56+
fv.set_default(
57+
"recorder_spec",
58+
["name=test-name", "upload_dir=/test/path/to/upload", "upload_interval=15"],
59+
)
60+
fv.mark_as_parsed()
61+
62+
recorder = GoodputRecorder.from_flags(fv)
63+
self.assertIsNone(recorder._monitor) # Ensure _monitor is initially None
64+
65+
with mock.patch("ml_goodput_measurement.monitoring.GoodputMonitor") as mock_goodput_monitor:
66+
mock_monitor_instance = mock_goodput_monitor.return_value
67+
recorder.start_monitoring()
68+
69+
# Check that GoodputMonitor was instantiated
70+
mock_goodput_monitor.assert_called_once_with(
71+
job_name="test-name",
72+
logger_name="goodput_logger_test-name",
73+
tensorboard_dir="/test/path/to/upload",
74+
upload_interval=15,
75+
monitoring_enabled=True,
76+
include_badput_breakdown=True,
77+
)
78+
79+
# Ensure that start_goodput_uploader is called on the monitor instance
80+
mock_monitor_instance.start_goodput_uploader.assert_called_once()
81+
self.assertIsNotNone(recorder._monitor)
82+
83+
def test_missing_required_flags(self):
84+
fv = flags.FlagValues()
85+
measurement.define_flags(flag_values=fv)
86+
# Missing 'upload_dir' and 'upload_interval' from recorder_spec
87+
fv.set_default("recorder_spec", ["name=test-name"]) # Incomplete config
88+
fv.mark_as_parsed()
89+
90+
# Expecting ValueError since 'upload_dir' and 'upload_interval' are required
91+
with self.assertRaises(ValueError):
92+
GoodputRecorder.from_flags(fv)
93+
94+
def test_monitoring_initialization_failure(self):
95+
fv = flags.FlagValues()
96+
measurement.define_flags(flag_values=fv)
97+
fv.set_default(
98+
"recorder_spec",
99+
["name=test-name", "upload_dir=/test/path/to/upload", "upload_interval=15"],
100+
)
101+
fv.mark_as_parsed()
102+
103+
recorder = GoodputRecorder.from_flags(fv)
104+
self.assertIsNone(recorder._monitor)
105+
106+
# Mock a failure in initializing the GoodputMonitor
107+
with mock.patch(
108+
"ml_goodput_measurement.monitoring.GoodputMonitor",
109+
side_effect=Exception("Failed to initialize GoodputMonitor"),
110+
):
111+
with self.assertRaises(Exception):
112+
recorder.start_monitoring()
113+
self.assertIsNone(recorder._monitor)

axlearn/common/launch_trainer_main.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ def main(_):
1313
launch.setup()
1414
trainer_config = launch_trainer.get_trainer_config()
1515
trainer_config.set(recorder=config_for_function(lambda: measurement.global_recorder))
16+
measurement.start_monitoring()
1617
launch_trainer.run_trainer(trainer_config)
1718

1819

axlearn/common/measurement.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,23 @@ class Event(enum.Enum):
1818
START_JOB: Start of job.
1919
END_JOB: End of job.
2020
START_STEP: Start of a training step. Should be recorded with `step` as a positional arg.
21+
START_ACCELERATOR_INIT: Start of accelerator mesh initialization.
22+
END_ACCELERATOR_INIT: End of accelerator mesh initialization.
23+
START_TRAINING_PREPARATION: Start of training preparation.
24+
END_TRAINING_PREPARATION: End of training preparation.
25+
START_DATA_LOADING: Start of data loading.
26+
END_DATA_LOADING: End of data loading.
2127
"""
2228

2329
START_JOB = "START_JOB"
2430
END_JOB = "END_JOB"
2531
START_STEP = "START_STEP"
32+
START_ACCELERATOR_INIT = "START_ACCELERATOR_INIT"
33+
END_ACCELERATOR_INIT = "END_ACCELERATOR_INIT"
34+
START_TRAINING_PREPARATION = "START_TRAINING_PREPARATION"
35+
END_TRAINING_PREPARATION = "END_TRAINING_PREPARATION"
36+
START_DATA_LOADING = "START_DATA_LOADING"
37+
END_DATA_LOADING = "END_DATA_LOADING"
2638

2739

2840
class Recorder(Configurable):
@@ -47,6 +59,10 @@ def record(self, event: Event, *args, **kwargs):
4759
"""Records an event with the given name."""
4860
raise NotImplementedError(type(self))
4961

62+
def start_monitoring(self, **kwargs):
63+
"""Starts computing and uploading metrics at some configured interval in the background."""
64+
raise NotImplementedError(type(self))
65+
5066

5167
_recorders: dict[str, type] = {}
5268
_T = TypeVar("_T")
@@ -120,3 +136,16 @@ def record_event(event: Event):
120136
logging.log_first_n(logging.INFO, "No recorder configured, ignoring events.", 1)
121137
else:
122138
global_recorder.record(event)
139+
140+
141+
def start_monitoring():
142+
"""Begins monitoring events as per global monitor functionality."""
143+
if global_recorder is None:
144+
logging.log_first_n(
145+
logging.INFO, "Since recorder is not set up, monitoring cannot be started.", 1
146+
)
147+
else:
148+
global_recorder.start_monitoring()
149+
logging.info(
150+
"Starting monitoring of events using global recorder's monitor: %s", global_recorder
151+
)

axlearn/common/measurement_test.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,3 +85,10 @@ def test_initialize(self, recorder_type, expected):
8585
with mock.patch.object(measurement.global_recorder, "record") as mock_record:
8686
measurement.record_event(measurement.Event.START_JOB)
8787
self.assertIn(measurement.Event.START_JOB, mock_record.call_args[0])
88+
89+
# Ensure that start_monitoring does not fail.
90+
with mock.patch.object(
91+
measurement.global_recorder, "start_monitoring"
92+
) as mock_start_monitoring:
93+
measurement.start_monitoring()
94+
mock_start_monitoring.assert_called_once()

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ gcp = [
9696
"google-cloud-compute==1.19.2", # Needed for region discovery for CloudBuild API access.
9797
"google-cloud-core==2.3.3",
9898
"google-cloud-build==3.24.1",
99-
"ml_goodput_measurement==0.0.2",
99+
"ml-goodput-measurement==0.0.4",
100100
"pika==1.3.2", # used by event queue
101101
"pyOpenSSL>=22.1.0", # compat with cryptography version.
102102
"tpu-info==0.2.0", # For TPU monitoring from libtpu. https://github.com/AI-Hypercomputer/cloud-accelerator-diagnostics/tree/main/tpu_info

0 commit comments

Comments
 (0)