Forward shuffle to trainer server and measure GPU (#526)

1. Add a unit test for how we pack the `StartTrainingRequest`, thereby
catching the previously unspecified `shuffle` field.
2. Enable accurate measurement of GPU operations.

XianzheMa authored Jun 19, 2024
1 parent 84dac1d commit b046892
Showing 12 changed files with 422 additions and 163 deletions.
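Context for the first change: proto3 scalar fields carry implicit defaults, so a request field the supervisor never sets is indistinguishable from one explicitly set to its zero value; omitting `shuffle` simply made the trainer server see False. A minimal, hypothetical illustration (not part of this diff):

from modyn.trainer_server.internal.grpc.generated.trainer_server_pb2 import StartTrainingRequest

# proto3 bools default to False when unset, so a forgotten field fails silently.
req = StartTrainingRequest(pipeline_id=42)  # `shuffle` never specified
assert req.shuffle is False  # indistinguishable from an explicit False

The new packing unit test below guards against exactly this kind of silent omission.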
57 changes: 34 additions & 23 deletions modyn/common/grpc/grpc_helpers.py
@@ -177,20 +177,16 @@ def stop_training_at_trainer_server(self, training_id: int) -> None:
         # TODO(#130): Implement this at trainer server.
         logger.error("The trainer server currently does not support remotely stopping training, ignoring.")
 
-    # pylint: disable=too-many-branches,too-many-locals,too-many-statements
-    def start_training(
-        self,
+    # pylint: disable=too-many-locals, too-many-branches
+    @staticmethod
+    def prepare_start_training_request(
         pipeline_id: int,
         trigger_id: int,
         training_config: TrainingConfig,
         data_config: DataConfig,
         previous_model_id: Optional[int],
         num_samples_to_pass: Optional[int] = None,
-    ) -> int:
-        assert self.trainer_server is not None
-        if not self.connected_to_trainer_server:
-            raise ConnectionError("Tried to start training at trainer server, but there is no gRPC connection.")
-
+    ) -> StartTrainingRequest:
         optimizers_config = {}
         for optimizer in training_config.optimizers:
             optimizer_config: dict[str, Any] = {
@@ -207,15 +203,7 @@ def start_training(
             lr_scheduler_configs = training_config.lr_scheduler.model_dump(by_alias=True)
 
         criterion_config = json.dumps(training_config.optimization_criterion.config)
-
-        epochs_per_trigger = training_config.epochs_per_trigger
-        num_prefetched_partitions = training_config.num_prefetched_partitions
-        parallel_prefetch_requests = training_config.parallel_prefetch_requests
-
-        seed = training_config.seed
         tokenizer = data_config.tokenizer
-        transform_list = data_config.transformations
-        label_transformer = data_config.label_transformer_function
 
         if training_config.checkpointing.activated:
             # the None-ness of the two fields is already validated by pydantic
@@ -244,23 +232,46 @@
                 num_dataloaders=training_config.dataloader_workers,
             ),
             "checkpoint_info": checkpoint_info,
-            "transform_list": transform_list,
+            "transform_list": data_config.transformations,
             "bytes_parser": PythonString(value=data_config.bytes_parser_function),
-            "label_transformer": PythonString(value=label_transformer),
+            "label_transformer": PythonString(value=data_config.label_transformer_function),
             "lr_scheduler": TrainerServerJsonString(value=json.dumps(lr_scheduler_configs)),
             "grad_scaler_configuration": TrainerServerJsonString(value=json.dumps(grad_scaler_config)),
-            "epochs_per_trigger": epochs_per_trigger,
-            "num_prefetched_partitions": num_prefetched_partitions,
-            "parallel_prefetch_requests": parallel_prefetch_requests,
-            "seed": seed,
+            "epochs_per_trigger": training_config.epochs_per_trigger,
+            "num_prefetched_partitions": training_config.num_prefetched_partitions,
+            "parallel_prefetch_requests": training_config.parallel_prefetch_requests,
+            "seed": training_config.seed,
             "tokenizer": PythonString(value=tokenizer) if tokenizer is not None else None,
             "num_samples_to_pass": num_samples_to_pass,
+            "shuffle": training_config.shuffle,
+            "enable_accurate_gpu_measurements": training_config.enable_accurate_gpu_measurements,
         }
 
         cleaned_kwargs: dict[str, Any] = {k: v for k, v in start_training_kwargs.items() if v is not None}
 
-        req = StartTrainingRequest(**cleaned_kwargs)
+        return StartTrainingRequest(**cleaned_kwargs)
 
+    def start_training(
+        self,
+        pipeline_id: int,
+        trigger_id: int,
+        training_config: TrainingConfig,
+        data_config: DataConfig,
+        previous_model_id: Optional[int],
+        num_samples_to_pass: Optional[int] = None,
+    ) -> int:
+        assert self.trainer_server is not None
+        if not self.connected_to_trainer_server:
+            raise ConnectionError("Tried to start training at trainer server, but there is no gRPC connection.")
+
+        req = self.prepare_start_training_request(
+            pipeline_id,
+            trigger_id,
+            training_config,
+            data_config,
+            previous_model_id,
+            num_samples_to_pass,
+        )
         response: StartTrainingResponse = self.trainer_server.start_training(req)
 
         if not response.training_started:
5 changes: 5 additions & 0 deletions modyn/config/schema/pipeline/training/config.py
@@ -104,6 +104,11 @@ class TrainingConfig(ModynBaseModel):
             "Otherwise, the output order is deterministic."
         )
     )
+    enable_accurate_gpu_measurements: bool = Field(
+        default=False,
+        description="If True, we measure the time of individual GPU-related operations within a training process "
+        "more accurately via CUDA synchronization. Note this can have a significant impact on training performance.",
+    )
     use_previous_model: bool = Field(
         description=(
             "If True, on trigger, we continue training on the model outputted by the previous trigger. If False, "
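For background on the new flag: CUDA kernels are launched asynchronously, so a host-side timer that never synchronizes measures only the launch overhead rather than the kernel itself. A standalone sketch of the effect (assumes a CUDA-capable machine; not part of this diff):

import time

import torch

x = torch.randn(4096, 4096, device="cuda")

start = time.perf_counter()
y = x @ x  # the matmul kernel is only launched here, not finished
launch_ms = (time.perf_counter() - start) * 1000

torch.cuda.synchronize()  # block until the kernel has actually completed
total_ms = (time.perf_counter() - start) * 1000

print(f"launch only: {launch_ms:.2f} ms, after sync: {total_ms:.2f} ms")

This is also why the flag defaults to False: every synchronization stalls the host until the device queue drains, which can noticeably slow training when wrapped around many small operations.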
1 change: 1 addition & 0 deletions modyn/protos/trainer_server.proto
@@ -56,6 +56,7 @@ message StartTrainingRequest {
   optional PythonString tokenizer = 22;
   int64 num_samples_to_pass = 23;
   bool shuffle = 24;
+  bool enable_accurate_gpu_measurements = 25;
 }
 
 message StartTrainingResponse {
108 changes: 107 additions & 1 deletion modyn/tests/common/grpc/test_grpc_helpers.py
@@ -1,9 +1,20 @@
+import json
 import multiprocessing as mp
+from pathlib import Path
+from typing import Optional
+from unittest.mock import MagicMock, patch
 
 import pytest
 from modyn.common.grpc import GenericGRPCServer
 from modyn.common.grpc.grpc_helpers import TrainerServerGRPCHandlerMixin
-from modyn.config import ModynConfig, ModynPipelineConfig
+from modyn.config import (
+    CheckpointingConfig,
+    DataConfig,
+    LrSchedulerConfig,
+    ModynConfig,
+    ModynPipelineConfig,
+    TrainingConfig,
+)
 from modyn.supervisor.internal.utils import TrainingStatusReporter
 from modyn.trainer_server.internal.grpc.generated.trainer_server_pb2 import JsonString as TrainerJsonString
 from modyn.trainer_server.internal.grpc.generated.trainer_server_pb2 import (
@@ -41,8 +52,100 @@ def test_init_and_trainer_server_available(
     avail_method.assert_called_once()
 
 
+@pytest.fixture()
+def pipeline_data_config():
+    return DataConfig(
+        dataset_id="test",
+        bytes_parser_function="def bytes_parser_function(x):\n\treturn x",
+        label_transformer_function="def label_transformer_function(x):\n\treturn x",
+        transformations=["transformation1", "transformation2"],
+    )
+
+
+@pytest.fixture()
+def lr_scheduler_config():
+    return LrSchedulerConfig(
+        name="CosineAnnealingLR",
+        source="PyTorch",
+        step_every="batch",
+        optimizers=["default"],
+        config={"T_max": "MODYN_NUM_BATCHES", "eta_min": 0.001},
+    )
+
+
+@pytest.mark.parametrize("previous_model_id", [1, None])
+@pytest.mark.parametrize("num_samples_to_pass", [5, None])
+@pytest.mark.parametrize("set_lr_scheduler_to_none", [True, False])
+@pytest.mark.parametrize("disable_checkpointing", [True, False])
+def test_prepare_start_training_request(
+    disable_checkpointing: bool,
+    set_lr_scheduler_to_none: bool,
+    num_samples_to_pass: Optional[int],
+    previous_model_id: Optional[int],
+    pipeline_training_config: TrainingConfig,
+    pipeline_data_config: DataConfig,
+    lr_scheduler_config: LrSchedulerConfig,
+):
+    # for bool fields, False is the default value, so we don't need to test it
+    pipeline_training_config.shuffle = True
+    pipeline_training_config.enable_accurate_gpu_measurements = True
+    pipeline_training_config.optimization_criterion.config = {"key": "value"}
+    pipeline_training_config.use_previous_model = previous_model_id is not None
+
+    pipeline_training_config.lr_scheduler = None if set_lr_scheduler_to_none else lr_scheduler_config
+    if set_lr_scheduler_to_none:
+        expected_lr_scheduler_config = {}
+    else:
+        expected_lr_scheduler_config = lr_scheduler_config.model_dump(by_alias=True)
+    if disable_checkpointing:
+        pipeline_training_config.checkpointing = CheckpointingConfig(activated=False)
+    else:
+        pipeline_training_config.checkpointing = CheckpointingConfig(activated=True, interval=1, path=Path("test"))
+
+    pipeline_id = 42
+    trigger_id = 21
+
+    req = TrainerServerGRPCHandlerMixin.prepare_start_training_request(
+        pipeline_id, trigger_id, pipeline_training_config, pipeline_data_config, previous_model_id, num_samples_to_pass
+    )
+
+    assert req.pipeline_id == pipeline_id
+    assert req.trigger_id == trigger_id
+    assert req.device == pipeline_training_config.device
+    assert req.use_pretrained_model == pipeline_training_config.use_previous_model
+    assert req.pretrained_model_id == (previous_model_id if previous_model_id is not None else -1)
+    assert req.batch_size == pipeline_training_config.batch_size
+    assert req.torch_criterion == pipeline_training_config.optimization_criterion.name
+    assert json.loads(req.criterion_parameters.value) == pipeline_training_config.optimization_criterion.config
+    assert req.data_info.dataset_id == pipeline_data_config.dataset_id
+    assert req.data_info.num_dataloaders == pipeline_training_config.dataloader_workers
+    if disable_checkpointing:
+        assert req.checkpoint_info.checkpoint_path == ""
+        assert req.checkpoint_info.checkpoint_interval == 0
+    else:
+        assert req.checkpoint_info.checkpoint_path == str(pipeline_training_config.checkpointing.path)
+        assert req.checkpoint_info.checkpoint_interval == pipeline_training_config.checkpointing.interval
+    assert req.bytes_parser.value == pipeline_data_config.bytes_parser_function
+    assert req.transform_list == pipeline_data_config.transformations
+    assert req.label_transformer.value == pipeline_data_config.label_transformer_function
+    assert json.loads(req.lr_scheduler.value) == expected_lr_scheduler_config
+    assert req.epochs_per_trigger == pipeline_training_config.epochs_per_trigger
+    assert req.num_prefetched_partitions == pipeline_training_config.num_prefetched_partitions
+    assert req.parallel_prefetch_requests == pipeline_training_config.parallel_prefetch_requests
+    assert req.seed == 0
+    assert req.num_samples_to_pass == (num_samples_to_pass if num_samples_to_pass is not None else 0)
+    assert req.shuffle
+    assert req.enable_accurate_gpu_measurements
+
+
 @patch("modyn.common.grpc.grpc_helpers.grpc_connection_established", return_value=True)
+@patch.object(
+    TrainerServerGRPCHandlerMixin,
+    "prepare_start_training_request",
+    wraps=TrainerServerGRPCHandlerMixin.prepare_start_training_request,
+)
 def test_start_training(
+    test_prepare_start_training_request: MagicMock,
     test_common_connection_established: MagicMock,
     dummy_pipeline_config: ModynPipelineConfig,
     dummy_system_config: ModynConfig,
@@ -65,6 +168,9 @@ def test_start_training(
         == 42
     )
     avail_method.assert_called_once()
+    test_prepare_start_training_request.assert_called_once_with(
+        pipeline_id, trigger_id, dummy_pipeline_config.training, dummy_pipeline_config.data, None, None
+    )
 
 
 @patch("modyn.common.grpc.grpc_helpers.grpc_connection_established", return_value=True)
34 changes: 34 additions & 0 deletions modyn/tests/trainer_server/internal/test_gpu_measurement.py
@@ -0,0 +1,34 @@
+import time
+from unittest.mock import patch
+
+from modyn.common.benchmark import Stopwatch
+from modyn.trainer_server.internal.trainer.gpu_measurement import GPUMeasurement
+
+
+@patch("modyn.trainer_server.internal.trainer.pytorch_trainer.torch.cuda.synchronize")
+def test_gpu_measurement(cuda_synchronize_mock):
+    stopwatch = Stopwatch()
+    with (
+        patch.object(Stopwatch, "stop", wraps=stopwatch.stop) as stopwatch_stop_mock,
+        patch.object(Stopwatch, "start", wraps=stopwatch.start) as stopwatch_start_mock,
+    ):
+
+        with GPUMeasurement(True, "measure", "cpu", stopwatch, resume=True):
+            time.sleep(1)
+
+        stopwatch_start_mock.assert_called_once_with(name="measure", resume=True)
+        stopwatch_stop_mock.assert_called_once_with(name="measure")
+        assert cuda_synchronize_mock.call_count == 2
+        assert 1000 <= stopwatch.measurements["measure"] <= 1100
+
+        stopwatch_start_mock.reset_mock()
+        stopwatch_stop_mock.reset_mock()
+        cuda_synchronize_mock.reset_mock()
+        with GPUMeasurement(False, "measure2", "cpu", stopwatch, overwrite=False):
+            pass
+
+        stopwatch_start_mock.assert_called_once_with(name="measure2", overwrite=False)
+        stopwatch_stop_mock.assert_called_once_with(name="measure2")
+        assert cuda_synchronize_mock.call_count == 0
+        # we still want to take the (inaccurate) measurement
+        assert "measure2" in stopwatch.measurements