Use time.tokens for speedmonitor instead of dataset length (#2762)
* change token math

* tokens

* add test

* fix tests
mvpatel2000 authored Dec 7, 2023
1 parent e87c06d commit 7f55b7a
Showing 3 changed files with 66 additions and 21 deletions.
24 changes: 11 additions & 13 deletions composer/callbacks/speed_monitor.py
@@ -174,8 +174,7 @@ class SpeedMonitor(Callback):
+-------------------------------------+-----------------------------------------------------------+
| | Rolling average (over `window_size` most recent |
| `throughput/tokens_per_sec` | batches) of the number of tokens processed per second. |
| | Only logged when dataloader.dataset has `max_seq_len`. |
| | This may include padding depending on dataset |
| | Only logged if dataspec returns tokens per batch |
+-------------------------------------+-----------------------------------------------------------+
| | Estimates flops by `flops_per_batch * batches_per_sec` |
| `throughput/flops_per_sec` | if model has attribute `flops_per_batch` |
@@ -186,8 +185,8 @@ class SpeedMonitor(Callback):
| `throughput/device/samples_per_sec` | `throughput/samples_per_sec` divided by world size |
+-------------------------------------+-----------------------------------------------------------+
| | `throughput/tokens_per_sec` divided by world size. Only |
| `throughput/device/tokens_per_sec` | logged when dataloader.dataset has `max_seq_len`. This |
| | may include pad tokens depending on dataset |
| `throughput/device/tokens_per_sec` | logged if dataspec returns tokens per batch |
| | |
+-------------------------------------+-----------------------------------------------------------+
| | `throughput/flops_per_sec` divided by world size. Only |
| `throughput/device/flops_per_sec` | logged when model has attribute `flops_per_batch` |
@@ -222,6 +221,7 @@ def __init__(
):
# Track the batch num samples and wct to compute throughput over a window of batches
self.history_samples: Deque[int] = deque(maxlen=window_size + 1)
self.history_tokens: Deque[int] = deque(maxlen=window_size + 1)
self.history_wct: Deque[float] = deque(maxlen=window_size + 1)
self.history_flops: Deque[float] = deque(maxlen=window_size + 1)
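
For illustration, a minimal sketch of how token counting can be supplied so that the token metrics documented above actually appear: the train dataloader is wrapped in a DataSpec that reports tokens per batch, which advances state.timestamp.token. This is not part of this commit; `train_dataloader`, `model`, and the `'input_ids'` batch key are placeholders/assumptions.

from composer import Trainer
from composer.callbacks import SpeedMonitor
from composer.core import DataSpec

def num_tokens_in_batch(batch):
    # Count every position, padding included; swap in an attention-mask-based
    # count if pad tokens should be excluded.
    return int(batch['input_ids'].numel())

# `train_dataloader` is any dataloader yielding tokenized batches and `model`
# is a ComposerModel; both are placeholders here.
train_dataspec = DataSpec(dataloader=train_dataloader, get_num_tokens_in_batch=num_tokens_in_batch)
trainer = Trainer(
    model=model,
    train_dataloader=train_dataspec,
    callbacks=[SpeedMonitor(window_size=100)],
    max_duration='1ep',
)
trainer.fit()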

@@ -259,13 +259,15 @@ def init(self, state: State, logger: Logger) -> None:
def batch_end(self, state: State, logger: Logger):
# Add the new element
self.history_samples.append(state.timestamp.sample.value)
self.history_tokens.append(state.timestamp.token.value)
self.history_wct.append(state.timestamp.total_wct.total_seconds())

# Log the throughput
if len(self.history_wct) == self.history_wct.maxlen:
world_size = dist.get_world_size()
elapsed_batches = len(self.history_samples) - 1
elapsed_samples = int(self.history_samples[-1]) - int(self.history_samples[0])
elapsed_tokens = int(self.history_tokens[-1]) - int(self.history_tokens[0])
elapsed_wct = self.history_wct[-1] - self.history_wct[0]
batches_per_sec = elapsed_batches / elapsed_wct
samples_per_sec = elapsed_samples / elapsed_wct
@@ -277,17 +279,13 @@ def batch_end(self, state: State, logger: Logger):
'throughput/device/batches_per_sec': dev_batches_per_sec,
'throughput/device/samples_per_sec': dev_samples_per_sec,
})

# Compute token stats if dataloader.dataset has max_seq_len. Assumes no padding.
try:
max_seq_len = state.dataloader.dataset.max_seq_len # type: ignore
# Only applicable to seq data / models
if elapsed_tokens > 0:
tokens_per_sec = elapsed_tokens / elapsed_wct
dev_tokens_per_sec = tokens_per_sec / world_size
logger.log_metrics({
'throughput/tokens_per_sec': samples_per_sec * max_seq_len,
'throughput/device/tokens_per_sec': dev_samples_per_sec * max_seq_len,
'throughput/tokens_per_sec': tokens_per_sec,
'throughput/device/tokens_per_sec': dev_tokens_per_sec,
})
except AttributeError:
pass

# Compute flops stats if model has flops_per_batch
composer_model = state.model
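For reference, a worked example of the rolling-window arithmetic that batch_end now performs, using made-up cumulative counts (in the callback these come from state.timestamp.token.value and state.timestamp.total_wct):

from collections import deque

window_size = 2  # small window so the numbers stay readable
history_tokens = deque([1000, 1512, 2048], maxlen=window_size + 1)  # cumulative token count after each batch
history_wct = deque([10.0, 10.5, 11.0], maxlen=window_size + 1)     # cumulative wall-clock seconds

elapsed_tokens = history_tokens[-1] - history_tokens[0]  # 1048 tokens over the last window_size batches
elapsed_wct = history_wct[-1] - history_wct[0]           # 1.0 second
tokens_per_sec = elapsed_tokens / elapsed_wct             # 1048.0
world_size = 8                                            # dist.get_world_size() in the callback
dev_tokens_per_sec = tokens_per_sec / world_size          # 131.0
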
49 changes: 48 additions & 1 deletion tests/callbacks/test_speed_monitor.py
@@ -11,7 +11,8 @@
from composer.core import Time
from composer.loggers import InMemoryLogger
from composer.trainer import Trainer
from tests.common import RandomClassificationDataset, SimpleModel
from tests.common import RandomClassificationDataset, SimpleModel, SimpleTransformerClassifier
from tests.common.datasets import dummy_text_classification_dataloader


def _assert_no_negative_values(logged_values):
@@ -53,6 +54,8 @@ def test_speed_monitor(flops_per_batch: bool):
_assert_no_negative_values(in_memory_logger.data['throughput/samples_per_sec'])
_assert_no_negative_values(in_memory_logger.data['throughput/device/batches_per_sec'])
_assert_no_negative_values(in_memory_logger.data['throughput/device/samples_per_sec'])
assert 'throughput/tokens_per_sec' not in in_memory_logger.data
assert 'throughput/device/tokens_per_sec' not in in_memory_logger.data
if flops_per_batch:
_assert_no_negative_values(in_memory_logger.data['throughput/flops_per_sec'])
_assert_no_negative_values(in_memory_logger.data['throughput/device/flops_per_sec'])
@@ -73,3 +76,47 @@ def test_speed_monitor(flops_per_batch: bool):
assert len(in_memory_logger.data['time/total']) == num_batches
assert len(in_memory_logger.data['time/train']) == num_batches
assert len(in_memory_logger.data['time/val']) == num_batches


def test_speed_monitor_tokens():
model = SimpleTransformerClassifier()
dataloader = dummy_text_classification_dataloader()
dataloader.dataset.max_seq_len = dataloader.dataset.sequence_length # type: ignore
in_memory_logger = InMemoryLogger() # track the logged metrics in the in_memory_logger
speed_monitor = SpeedMonitor(window_size=1)
trainer = Trainer(
model=model,
train_dataloader=dataloader,
callbacks=speed_monitor,
loggers=in_memory_logger,
max_duration='1ep',
)
trainer.fit()

print(list(in_memory_logger.data.keys()))

_assert_no_negative_values(in_memory_logger.data['time/train'])
_assert_no_negative_values(in_memory_logger.data['time/val'])
_assert_no_negative_values(in_memory_logger.data['time/total'])
_assert_no_negative_values(in_memory_logger.data['throughput/batches_per_sec'])
_assert_no_negative_values(in_memory_logger.data['throughput/samples_per_sec'])
_assert_no_negative_values(in_memory_logger.data['throughput/tokens_per_sec'])
_assert_no_negative_values(in_memory_logger.data['throughput/device/batches_per_sec'])
_assert_no_negative_values(in_memory_logger.data['throughput/device/samples_per_sec'])
_assert_no_negative_values(in_memory_logger.data['throughput/device/tokens_per_sec'])

assert isinstance(trainer.state.dataloader, collections.abc.Sized)
assert trainer.state.dataloader_label is not None
assert trainer.state.dataloader_len is not None
expected_step_calls = (trainer.state.dataloader_len - len(speed_monitor.history_samples) + 1) * int(
trainer.state.timestamp.epoch)
assert len(in_memory_logger.data['throughput/batches_per_sec']) == expected_step_calls
assert len(in_memory_logger.data['throughput/samples_per_sec']) == expected_step_calls
assert len(in_memory_logger.data['throughput/tokens_per_sec']) == expected_step_calls
assert len(in_memory_logger.data['throughput/device/batches_per_sec']) == expected_step_calls
assert len(in_memory_logger.data['throughput/device/samples_per_sec']) == expected_step_calls
assert len(in_memory_logger.data['throughput/device/tokens_per_sec']) == expected_step_calls
num_batches = int(trainer.state.timestamp.batch)
assert len(in_memory_logger.data['time/total']) == num_batches
assert len(in_memory_logger.data['time/train']) == num_batches
assert len(in_memory_logger.data['time/val']) == num_batches
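
To run just the new test from a checkout of the repository, one option equivalent to the usual command line is:

import pytest

# Same as: pytest -v tests/callbacks/test_speed_monitor.py::test_speed_monitor_tokens
pytest.main(['-v', 'tests/callbacks/test_speed_monitor.py::test_speed_monitor_tokens'])
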
14 changes: 7 additions & 7 deletions tests/datasets/test_in_context_learning_datasets.py
@@ -1123,7 +1123,7 @@ def test_mc_task_evaluation(device, num_fewshot, dataset_uri, tiny_gpt2_tokenize
@device('gpu')
@world_size(1, 2)
@pytest.mark.parametrize('num_fewshot', [0, 5])
@pytest.mark.filterwarnings(r'ignore:The dataloader_len \(2\) is greater than the length.*:UserWarning')
@pytest.mark.filterwarnings(r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning')
def test_qa_task_evaluation_opt_tokenizer(device, world_size, tiny_opt_tokenizer, tiny_opt_model, num_fewshot,
dataset_uri, tmp_path):
pytest.importorskip('datasets')
@@ -1167,7 +1167,7 @@ def test_qa_task_evaluation_opt_tokenizer(device, world_size, tiny_opt_tokenizer
@device('gpu')
@world_size(1, 2)
@pytest.mark.parametrize('num_fewshot', [5])
@pytest.mark.filterwarnings(r'ignore:The dataloader_len \(2\) is greater than the length.*:UserWarning')
@pytest.mark.filterwarnings(r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning')
def test_qa_task_evaluation_with_cot_opt_tokenizer(device, world_size, tiny_opt_tokenizer, tiny_opt_model, num_fewshot,
dataset_uri, tmp_path):
pytest.importorskip('datasets')
@@ -1212,7 +1212,7 @@ def test_qa_task_evaluation_with_cot_opt_tokenizer(device, world_size, tiny_opt_
@device('gpu')
@world_size(1, 2)
@pytest.mark.parametrize('num_fewshot', [0, 5])
@pytest.mark.filterwarnings(r'ignore:The dataloader_len \(2\) is greater than the length.*:UserWarning')
@pytest.mark.filterwarnings(r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning')
def test_qa_task_evaluation(device, world_size, num_fewshot, dataset_uri, tiny_gpt2_tokenizer, tiny_gpt2_model,
tmp_path):
pytest.importorskip('datasets')
@@ -1256,7 +1256,7 @@ def test_qa_task_evaluation(device, world_size, num_fewshot, dataset_uri, tiny_g
@device('gpu')
@world_size(1, 2)
@pytest.mark.parametrize('num_fewshot', [5])
@pytest.mark.filterwarnings(r'ignore:The dataloader_len \(2\) is greater than the length.*:UserWarning')
@pytest.mark.filterwarnings(r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning')
def test_qa_task_with_cot_evaluation(device, world_size, num_fewshot, dataset_uri, tiny_gpt2_tokenizer, tiny_gpt2_model,
tmp_path):
pytest.importorskip('datasets')
@@ -1314,7 +1314,7 @@ def test_code_eval_requires_valid_envvar(monkeypatch):
@world_size(1, 2)
@pytest.mark.parametrize('num_fewshot', [0])
@pytest.mark.parametrize('generations_per_sample', [1, 2])
@pytest.mark.filterwarnings(r'ignore:The dataloader_len \(2\) is greater than the length.*:UserWarning')
@pytest.mark.filterwarnings(r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning')
def test_code_eval_microbatching(monkeypatch, device, world_size, tiny_opt_tokenizer, tiny_opt_model, num_fewshot,
dataset_uri, tmp_path, generations_per_sample):
pytest.importorskip('datasets')
@@ -1365,7 +1365,7 @@ def test_code_eval_microbatching(monkeypatch, device, world_size, tiny_opt_token
@world_size(1, 2)
@pytest.mark.parametrize('num_fewshot', [0])
@pytest.mark.parametrize('generations_per_sample', [1, 2])
@pytest.mark.filterwarnings(r'ignore:The dataloader_len \(2\) is greater than the length.*:UserWarning')
@pytest.mark.filterwarnings(r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning')
def test_code_eval_sentpiece_evaluation(monkeypatch, device, world_size, num_fewshot, dataset_uri, tiny_t5_tokenizer,
tiny_t5_model, tmp_path, generations_per_sample):
pytest.importorskip('datasets')
@@ -1413,7 +1413,7 @@ def test_code_eval_sentpiece_evaluation(monkeypatch, device, world_size, num_few
@pytest.mark.parametrize('num_fewshot', [0, 2])
@pytest.mark.parametrize('generations_per_sample', [1])
@pytest.mark.filterwarnings(r'ignore: Input length of input_ids is')
@pytest.mark.filterwarnings(r'ignore:The dataloader_len \(2\) is greater than the length.*:UserWarning')
@pytest.mark.filterwarnings(r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning')
def test_code_eval_task_evaluation(monkeypatch, device, world_size, num_fewshot, dataset_uri, tiny_gpt2_tokenizer,
tiny_gpt2_model, tmp_path, generations_per_sample):
pytest.importorskip('datasets')
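The hunks in this file only loosen the filterwarnings message patterns. The message part of a warnings filter is a regex matched against the start of the warning text (re.match semantics), so a leading `.*` is needed whenever the quoted phrase may be preceded by other content. A small self-contained illustration with a made-up message:

import re

pattern = r'The dataloader_len \(2\) is greater than the length'
message = 'Some prefix: The dataloader_len (2) is greater than the length of the dataloader.'

assert re.match(pattern, message) is None              # anchored match fails because of the prefix
assert re.match('.*' + pattern, message) is not None   # the leading .* lets the phrase appear anywhere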
