Add Goodput & Badput recording and monitoring support. #783

Merged · 16 commits · Jan 31, 2025
59 changes: 56 additions & 3 deletions axlearn/cloud/gcp/measurement.py
@@ -5,24 +5,39 @@
import jax
from absl import flags, logging
from ml_goodput_measurement import goodput
from ml_goodput_measurement import monitoring as goodput_monitoring

from axlearn.cloud.common.utils import parse_kv_flags
from axlearn.common import measurement
from axlearn.common.config import maybe_set_config
from axlearn.common.config import REQUIRED, Required, config_class, maybe_set_config


@measurement.register_recorder("goodput")
class GoodputRecorder(measurement.Recorder):
"""Records overall training goodput."""

Config = measurement.Recorder.Config
@config_class
class Config(measurement.Recorder.Config):
"""Configures GoodputRecorder.

Attributes:
upload_dir: Directory to store metrics for the monitor.
upload_interval: Time interval (seconds) for monitoring uploads.
"""

upload_dir: Required[str] = REQUIRED
upload_interval: Required[int] = REQUIRED

@classmethod
def from_flags(cls, fv: flags.FlagValues) -> "GoodputRecorder":
"""Converts flags to a recorder.

`fv.recorder_spec` will be interpreted as a list of `key=value` pairs; config names
corresponding to keys will be set to the corresponding values.
corresponding to keys will be set to the corresponding values. A GoodputRecorder can
additionally take in the following TensorBoard configs in the recorder_spec:
- upload_dir: The directory to write TensorBoard data to.
- upload_interval: The time interval in seconds at which to query and upload data
to TensorBoard.
"""
cfg: measurement.Recorder.Config = cls.default_config()
cfg = maybe_set_config(cfg, **parse_kv_flags(fv.recorder_spec, delimiter="="))
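As a usage sketch (values here are hypothetical and mirror the tests below), a GoodputRecorder can be constructed from key=value pairs in recorder_spec:

from absl import flags

from axlearn.cloud.gcp.measurement import GoodputRecorder
from axlearn.common import measurement

fv = flags.FlagValues()
measurement.define_flags(flag_values=fv)
# name, upload_dir, and upload_interval are parsed out of key=value pairs.
fv.set_default(
    "recorder_spec",
    ["name=my-job", "upload_dir=/tmp/goodput", "upload_interval=30"],
)
fv.mark_as_parsed()
recorder = GoodputRecorder.from_flags(fv)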
@@ -32,6 +47,7 @@ def __init__(self, cfg):
super().__init__(cfg)
cfg: GoodputRecorder.Config = self.config
self._recorder = None
self._monitor = None

def record(self, event: measurement.Event, *args, **kwargs):
# Lazily instantiate the recorder. This avoids invoking jax before setup is complete.
@@ -49,10 +65,47 @@ def record(self, event: measurement.Event, *args, **kwargs):
self._recorder.record_job_end_time(*args, **kwargs)
elif event == measurement.Event.START_STEP:
self._recorder.record_step_start_time(*args, **kwargs)
elif event == measurement.Event.START_ACCELERATOR_INIT:
self._recorder.record_tpu_init_start_time(*args, **kwargs)
Contributor:
OOI, is there anything here specific to TPUs, or can we use the same API for GPUs on GCP?

Contributor Author:
It's not specific to TPUs, since this badput bucket is computed from recorded markers. The API name needs to be updated to reflect this, and once that is done, we will need to refactor this piece of code as well.

elif event == measurement.Event.END_ACCELERATOR_INIT:
self._recorder.record_tpu_init_end_time(*args, **kwargs)
elif event == measurement.Event.START_TRAINING_PREPARATION:
self._recorder.record_training_preparation_start_time(*args, **kwargs)
elif event == measurement.Event.END_TRAINING_PREPARATION:
self._recorder.record_training_preparation_end_time(*args, **kwargs)
elif event == measurement.Event.START_DATA_LOADING:
self._recorder.record_data_loading_start_time(*args, **kwargs)
elif event == measurement.Event.END_DATA_LOADING:
self._recorder.record_data_loading_end_time(*args, **kwargs)
else:
logging.log_first_n(
logging.WARNING,
"Ignoring unknown event %s",
1,
event,
)
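For illustration, a minimal sketch (the surrounding trainer wiring is assumed) of how these paired markers would be emitted around each phase, using the record_event helper from axlearn.common.measurement:

from axlearn.common import measurement

measurement.record_event(measurement.Event.START_ACCELERATOR_INIT)
# ... device/mesh initialization work happens here ...
measurement.record_event(measurement.Event.END_ACCELERATOR_INIT)

measurement.record_event(measurement.Event.START_DATA_LOADING)
# ... fetch the next input batch ...
measurement.record_event(measurement.Event.END_DATA_LOADING)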

def start_monitoring(self, *args, **kwargs):
"""Starts Monitoring of Goodput.

Instantiate ml-goodput-measurement's GoodputMonitor to asynchronously calculate
Goodput and Badput at the upload_interval and upload to the specified TensorBoard
directory.
Note: This function requires initialization of distributed JAX before it is called.
If there are internal GCP errors from querying and uploading data, these will be
logged without affecting the workload. GoodputMonitor logs will provide further
information if data is not being uploaded correctly.
"""
if self._monitor is None:
cfg: GoodputRecorder.Config = self.config
self._monitor = goodput_monitoring.GoodputMonitor(
job_name=cfg.name,
logger_name=f"goodput_logger_{cfg.name}",
tensorboard_dir=cfg.upload_dir,
upload_interval=int(cfg.upload_interval),
monitoring_enabled=(jax.process_index() == 0),
include_badput_breakdown=True,
)

self._monitor.start_goodput_uploader(*args, **kwargs)
logging.info("Started Goodput upload to Tensorboard in the background!")
73 changes: 70 additions & 3 deletions axlearn/cloud/gcp/measurement_test.py
@@ -16,7 +16,9 @@
class GoodputRecorderTest(parameterized.TestCase):
"""Tests GoodputRecorder."""

@parameterized.parameters(None, ["name=test-name"])
@parameterized.parameters(
(None,), (["name=test-name", "upload_dir=/test/path/to/upload", "upload_interval=15"],)
)
def test_from_flags(self, spec):
fv = flags.FlagValues()
measurement.define_flags(flag_values=fv)
@@ -34,13 +36,78 @@ def test_from_flags(self, spec):
# Recorder is not instantiated until first event.
self.assertIsNone(recorder._recorder)

def test_record(self):
def test_record_and_monitor(self):
fv = flags.FlagValues()
measurement.define_flags(flag_values=fv)
fv.set_default("recorder_spec", ["name=test-name"])
fv.set_default(
"recorder_spec",
["name=test-name", "upload_dir=/test/path/to/upload", "upload_interval=15"],
)
fv.mark_as_parsed()

recorder = GoodputRecorder.from_flags(fv)
recorder._recorder = mock.MagicMock()
recorder.record(measurement.Event.START_JOB)
self.assertTrue(recorder._recorder.record_job_start_time.called)

def test_start_monitoring(self):
fv = flags.FlagValues()
measurement.define_flags(flag_values=fv)
fv.set_default(
"recorder_spec",
["name=test-name", "upload_dir=/test/path/to/upload", "upload_interval=15"],
)
fv.mark_as_parsed()

recorder = GoodputRecorder.from_flags(fv)
self.assertIsNone(recorder._monitor) # Ensure _monitor is initially None

with mock.patch("ml_goodput_measurement.monitoring.GoodputMonitor") as mock_goodput_monitor:
mock_monitor_instance = mock_goodput_monitor.return_value
recorder.start_monitoring()

# Check that GoodputMonitor was instantiated
mock_goodput_monitor.assert_called_once_with(
job_name="test-name",
logger_name="goodput_logger_test-name",
tensorboard_dir="/test/path/to/upload",
upload_interval=15,
monitoring_enabled=True,
include_badput_breakdown=True,
)

# Ensure that start_goodput_uploader is called on the monitor instance
mock_monitor_instance.start_goodput_uploader.assert_called_once()
self.assertIsNotNone(recorder._monitor)

def test_missing_required_flags(self):
fv = flags.FlagValues()
measurement.define_flags(flag_values=fv)
# Missing 'upload_dir' and 'upload_interval' from recorder_spec
fv.set_default("recorder_spec", ["name=test-name"]) # Incomplete config
fv.mark_as_parsed()

# Expecting ValueError since 'upload_dir' and 'upload_interval' are required
with self.assertRaises(ValueError):
GoodputRecorder.from_flags(fv)

def test_monitoring_initialization_failure(self):
fv = flags.FlagValues()
measurement.define_flags(flag_values=fv)
fv.set_default(
"recorder_spec",
["name=test-name", "upload_dir=/test/path/to/upload", "upload_interval=15"],
)
fv.mark_as_parsed()

recorder = GoodputRecorder.from_flags(fv)
self.assertIsNone(recorder._monitor)

# Mock a failure in initializing the GoodputMonitor
with mock.patch(
"ml_goodput_measurement.monitoring.GoodputMonitor",
side_effect=Exception("Failed to initialize GoodputMonitor"),
):
with self.assertRaises(Exception):
recorder.start_monitoring()
self.assertIsNone(recorder._monitor)
1 change: 1 addition & 0 deletions axlearn/common/launch_trainer_main.py
@@ -13,6 +13,7 @@ def main(_):
launch.setup()
trainer_config = launch_trainer.get_trainer_config()
trainer_config.set(recorder=config_for_function(lambda: measurement.global_recorder))
measurement.start_monitoring()
launch_trainer.run_trainer(trainer_config)


29 changes: 29 additions & 0 deletions axlearn/common/measurement.py
@@ -18,11 +18,23 @@ class Event(enum.Enum):
START_JOB: Start of job.
END_JOB: End of job.
START_STEP: Start of a training step. Should be recorded with `step` as a positional arg.
START_ACCELERATOR_INIT: Start of accelerator mesh initialization.
END_ACCELERATOR_INIT: End of accelerator mesh initialization.
START_TRAINING_PREPARATION: Start of training preparation.
END_TRAINING_PREPARATION: End of training preparation.
START_DATA_LOADING: Start of data loading.
END_DATA_LOADING: End of data loading.
"""

START_JOB = "START_JOB"
END_JOB = "END_JOB"
START_STEP = "START_STEP"
START_ACCELERATOR_INIT = "START_ACCELERATOR_INIT"
END_ACCELERATOR_INIT = "END_ACCELERATOR_INIT"
START_TRAINING_PREPARATION = "START_TRAINING_PREPARATION"
END_TRAINING_PREPARATION = "END_TRAINING_PREPARATION"
START_DATA_LOADING = "START_DATA_LOADING"
END_DATA_LOADING = "END_DATA_LOADING"


class Recorder(Configurable):
@@ -47,6 +59,10 @@ def record(self, event: Event, *args, **kwargs):
"""Records an event with the given name."""
raise NotImplementedError(type(self))

def start_monitoring(self, **kwargs):
"""Starts computing and uploading metrics at some configured interval in the background."""
raise NotImplementedError(type(self))


_recorders: dict[str, type] = {}
_T = TypeVar("_T")
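For reference, a minimal sketch of what the register_recorder plumbing expects from a subclass (the "noop" name and class below are hypothetical, purely illustrative):

@register_recorder("noop")
class NoOpRecorder(Recorder):
    """Drops all events; a placeholder when no measurement backend is wanted."""

    def record(self, event: Event, *args, **kwargs):
        del event, args, kwargs  # Intentionally a no-op.

    def start_monitoring(self, **kwargs):
        del kwargs  # Nothing to upload.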
@@ -120,3 +136,16 @@ def record_event(event: Event):
logging.log_first_n(logging.INFO, "No recorder configured, ignoring events.", 1)
else:
global_recorder.record(event)


def start_monitoring():
"""Begins monitoring events as per global monitor functionality."""
if global_recorder is None:
logging.log_first_n(
Contributor:
nit -- since start_monitoring is only called once, we don't need log_first_n. (Not having it may help catch when it's called multiple times.)

Contributor Author:
Since this is a no-op, I'll leave it for now and address it in the next PR. Expecting to complete integration with package v5 in the next one as well. Here is the tracking bug.

logging.INFO, "Since recorder is not set up, monitoring cannot be started.", 1
)
else:
global_recorder.start_monitoring()
logging.info(
"Starting monitoring of events using global recorder's monitor: %s", global_recorder
)
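Putting the module-level helpers together (a sketch under the assumption that launch setup has already populated measurement.global_recorder from the recorder flags, as in launch_trainer_main.py above):

from axlearn.common import measurement

measurement.start_monitoring()  # Logs a notice and no-ops if no recorder is configured.
measurement.record_event(measurement.Event.START_JOB)
# ... training ...
measurement.record_event(measurement.Event.END_JOB)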
7 changes: 7 additions & 0 deletions axlearn/common/measurement_test.py
@@ -85,3 +85,10 @@ def test_initialize(self, recorder_type, expected):
with mock.patch.object(measurement.global_recorder, "record") as mock_record:
measurement.record_event(measurement.Event.START_JOB)
self.assertIn(measurement.Event.START_JOB, mock_record.call_args[0])

# Ensure that start_monitoring does not fail.
with mock.patch.object(
measurement.global_recorder, "start_monitoring"
) as mock_start_monitoring:
measurement.start_monitoring()
mock_start_monitoring.assert_called_once()
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -96,7 +96,7 @@ gcp = [
"google-cloud-compute==1.19.2", # Needed for region discovery for CloudBuild API access.
"google-cloud-core==2.3.3",
"google-cloud-build==3.24.1",
"ml_goodput_measurement==0.0.2",
"ml-goodput-measurement==0.0.4",
"pika==1.3.2", # used by event queue
"pyOpenSSL>=22.1.0", # compat with cryptography version.
"tpu-info==0.2.0", # For TPU monitoring from libtpu. https://github.com/AI-Hypercomputer/cloud-accelerator-diagnostics/tree/main/tpu_info