Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MFU Monitoring to LLM #56

Merged
merged 32 commits into from
Dec 30, 2022
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 23 additions & 4 deletions llm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import warnings

from composer import Trainer
from composer.callbacks import LRMonitor, MemoryMonitor, SpeedMonitor
from composer.callbacks import LRMonitor, MemoryMonitor
from composer.loggers import WandBLogger
from composer.optim import DecoupledAdamW
from composer.optim.scheduler import (ConstantWithWarmupScheduler,
Expand All @@ -16,6 +16,8 @@
from src.text_data import build_text_dataloader
from src.model_registry import COMPOSER_MODEL_REGISTRY

from src.speed_monitor_w_mfu import SpeedMonitorMFU


def build_logger(name, kwargs):
if name == 'wandb':
Expand All @@ -24,13 +26,30 @@ def build_logger(name, kwargs):
raise ValueError(f'Not sure how to build logger: {name}')


def build_callback(name, kwargs):
def get_model_fwd_flops(cfg):
    """Approximate the forward-pass FLOPs per sequence for the configured model.

    Uses the standard approximation that each parameter performs one
    multiply-accumulate (MAC, i.e. 2 FLOPs) per token, plus the attention
    matmuls, which scale quadratically with sequence length.

    Args:
        cfg: Run config with ``model.name``, ``n_params``, ``max_seq_len``,
            ``model.n_layers``, and ``model.d_model`` attributes.

    Returns:
        FLOPs per sequence for ``mosaic_gpt`` models; ``None`` for any other
        model name (which disables MFU tracking downstream).
    """
    if cfg.model.name == 'mosaic_gpt':
        # The number of parameters approximates the number of MACs per token,
        # and each MAC is 2 FLOPs, giving FLOPs / token:
        params_flops_per_token = 2 * cfg.n_params
        params_flops_per_seq = params_flops_per_token * cfg.max_seq_len
        # Attention adds two (seq_len x seq_len) matmuls per layer
        # (A = Q @ K^T and out = A @ V), each at 2 FLOPs per MAC.
        attn_flops_per_seq = cfg.model.n_layers * 2 * 2 * (cfg.model.d_model * (cfg.max_seq_len ** 2))
        return params_flops_per_seq + attn_flops_per_seq

    return None


def build_callback(name, kwargs, cfg):
    """Instantiate a single training callback by its registry name.

    Args:
        name: One of ``'lr_monitor'``, ``'memory_monitor'``, or
            ``'speed_monitor'``.
        kwargs: Per-callback options (``window_size`` for the speed monitor).
        cfg: Full run config; used to derive model FLOPs for MFU tracking.

    Raises:
        ValueError: If ``name`` is not a recognized callback.
    """
    if name == 'lr_monitor':
        return LRMonitor()
    if name == 'memory_monitor':
        return MemoryMonitor()
    if name == 'speed_monitor':
        window = kwargs.get('window_size', 1)
        return SpeedMonitorMFU(window_size=window,
                               model_flops=get_model_fwd_flops(cfg))
    raise ValueError(f'Not sure how to build callback: {name}')

Expand Down Expand Up @@ -167,7 +186,7 @@ def main(cfg):

# Callbacks
callbacks = [
build_callback(name, callback_cfg)
build_callback(name, callback_cfg, cfg)
for name, callback_cfg in cfg.get('callbacks', {}).items()
]

Expand Down
100 changes: 100 additions & 0 deletions llm/src/speed_monitor_w_mfu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Monitor throughput during training."""
from __future__ import annotations

from collections import deque
from typing import Any, Deque, Dict

from composer.core import Callback, State
from composer.loggers import Logger
from composer.utils import dist

from composer.callbacks import SpeedMonitor

GPU_AVAILABLE_FLOPS = 312_000_000_000_000
vchiley marked this conversation as resolved.
Show resolved Hide resolved

__all__ = ['SpeedMonitorMFU']


class SpeedMonitorMFU(SpeedMonitor):
    """Logs the training throughput and MFU (model FLOPs utilization).

    The training throughput in terms of number of samples per second is logged on the
    :attr:`.Event.BATCH_END` event if we have reached the ``window_size`` threshold.

    The wall clock train time is logged on every :attr:`.Event.BATCH_END` event.

    The average throughput over an epoch is logged on the :attr:`.Event.EPOCH_END` event.

    Example:
        .. doctest::

            >>> from composer import Trainer
            >>> from src.speed_monitor_w_mfu import SpeedMonitorMFU
            >>> # constructing trainer object with this callback
            >>> trainer = Trainer(
            ...     model=model,
            ...     train_dataloader=train_dataloader,
            ...     eval_dataloader=eval_dataloader,
            ...     optimizers=optimizer,
            ...     max_duration='1ep',
            ...     callbacks=[SpeedMonitorMFU(window_size=100, model_flops=<MODEL_FWD_FLOPS>)],
            ... )

    The training throughput is logged by the :class:`.Logger` to the following keys as
    described below.

    +----------------------------------+-------------------------------------------------------------+
    | Key                              | Logged data                                                 |
    +==================================+=============================================================+
    |                                  | Rolling average (over ``window_size`` most recent           |
    | ``throughput/samples_per_sec``   | batches) of the number of samples processed per second      |
    |                                  |                                                             |
    +----------------------------------+-------------------------------------------------------------+
    |                                  | Rolling average (over ``window_size`` most recent           |
    | ``throughput/mfu``               | batches) of mfu                                             |
    |                                  |                                                             |
    +----------------------------------+-------------------------------------------------------------+
    | ``wall_clock/train``             | Total elapsed training time                                 |
    +----------------------------------+-------------------------------------------------------------+
    | ``wall_clock/val``               | Total elapsed validation time                               |
    +----------------------------------+-------------------------------------------------------------+
    | ``wall_clock/total``             | Total elapsed time (wall_clock/train + wall_clock/val)      |
    +----------------------------------+-------------------------------------------------------------+

    Args:
        window_size (int, optional): Number of batches to use for a rolling average of throughput.
            Defaults to 100.
        model_flops (int, optional): Number of FLOPs the model uses in the fwd pass per sample.
            If set, the callback will track training MFU. Defaults to None (disabled).
    """

    def __init__(self, window_size: int = 100, model_flops: int | None = None):
        super().__init__(window_size=window_size)
        # FLOPs for one forward pass of one sample; None disables MFU logging.
        self.model_flops = model_flops

    def batch_end(self, state: State, logger: Logger):
        batch_num_samples = int(state.timestamp.sample) - self.batch_start_num_samples
        batch_wct = state.timestamp.total_wct.total_seconds() - self.batch_start_wct

        # Add the new element to the rolling windows.
        self.batch_wct_buffer.append(batch_wct)
        self.batch_num_samples_buffer.append(batch_num_samples)

        # Log the throughput once the rolling window is full.
        if len(self.batch_num_samples_buffer) == self.window_size:
            throughput = sum(self.batch_num_samples_buffer) / sum(self.batch_wct_buffer)
            logger.log_metrics({'throughput/samples_per_sec': throughput})
            if self.model_flops is not None:
                # 3x approximates fwd + bwd cost (bwd is roughly 2x the fwd FLOPs).
                # NOTE(review): GPU_AVAILABLE_FLOPS is hard-coded at 312 TFLOPs,
                # which looks like A100 BF16 peak — confirm it matches the hardware in use.
                mfu = 3 * self.model_flops * throughput / (dist.get_world_size() * GPU_AVAILABLE_FLOPS)
                logger.log_metrics({'throughput/mfu': mfu})

        # Log the time
        # `state.timestamp` excludes any time spent in evaluation
        logger.log_metrics({
            'wall_clock/train': state.timestamp.total_wct.total_seconds(),
            'wall_clock/val': self.total_eval_wct,
            'wall_clock/total': (state.timestamp.total_wct.total_seconds() + self.total_eval_wct),
        })