Record training loss (#539)
Per the PR title, this PR solves issue #488.
XianzheMa authored Jun 24, 2024
1 parent 7d694d0 commit b9e255e
Showing 9 changed files with 68 additions and 50 deletions.
60 changes: 29 additions & 31 deletions modyn/common/grpc/grpc_helpers.py
@@ -216,40 +216,38 @@ def prepare_start_training_request(

grad_scaler_config = training_config.grad_scaler_config if training_config.grad_scaler_config else {}

-        start_training_kwargs = {
-            "pipeline_id": pipeline_id,
-            "trigger_id": trigger_id,
-            "device": training_config.device,
-            "use_pretrained_model": previous_model_id is not None,
-            "pretrained_model_id": previous_model_id or -1,
-            "load_optimizer_state": False,  # TODO(#137): Think about this.
-            "batch_size": training_config.batch_size,
-            "torch_optimizers_configuration": TrainerServerJsonString(value=json.dumps(optimizers_config)),
-            "torch_criterion": training_config.optimization_criterion.name,
-            "criterion_parameters": TrainerServerJsonString(value=criterion_config),
-            "data_info": Data(
+        return StartTrainingRequest(
+            pipeline_id=pipeline_id,
+            trigger_id=trigger_id,
+            device=training_config.device,
+            use_pretrained_model=previous_model_id is not None,
+            pretrained_model_id=previous_model_id or -1,
+            load_optimizer_state=False,  # TODO(#137): Think about this.
+            batch_size=training_config.batch_size,
+            torch_optimizers_configuration=TrainerServerJsonString(value=json.dumps(optimizers_config)),
+            torch_criterion=training_config.optimization_criterion.name,
+            criterion_parameters=TrainerServerJsonString(value=criterion_config),
+            data_info=Data(
                 dataset_id=data_config.dataset_id,
                 num_dataloaders=training_config.dataloader_workers,
             ),
-            "checkpoint_info": checkpoint_info,
-            "transform_list": data_config.transformations,
-            "bytes_parser": PythonString(value=data_config.bytes_parser_function),
-            "label_transformer": PythonString(value=data_config.label_transformer_function),
-            "lr_scheduler": TrainerServerJsonString(value=json.dumps(lr_scheduler_configs)),
-            "grad_scaler_configuration": TrainerServerJsonString(value=json.dumps(grad_scaler_config)),
-            "epochs_per_trigger": training_config.epochs_per_trigger,
-            "num_prefetched_partitions": training_config.num_prefetched_partitions,
-            "parallel_prefetch_requests": training_config.parallel_prefetch_requests,
-            "seed": training_config.seed,
-            "tokenizer": PythonString(value=tokenizer) if tokenizer is not None else None,
-            "num_samples_to_pass": num_samples_to_pass,
-            "shuffle": training_config.shuffle,
-            "enable_accurate_gpu_measurements": training_config.enable_accurate_gpu_measurements,
-        }
-
-        cleaned_kwargs: dict[str, Any] = {k: v for k, v in start_training_kwargs.items() if v is not None}
-
-        return StartTrainingRequest(**cleaned_kwargs)
+            checkpoint_info=checkpoint_info,
+            transform_list=data_config.transformations,
+            bytes_parser=PythonString(value=data_config.bytes_parser_function),
+            label_transformer=PythonString(value=data_config.label_transformer_function),
+            lr_scheduler=TrainerServerJsonString(value=json.dumps(lr_scheduler_configs)),
+            grad_scaler_configuration=TrainerServerJsonString(value=json.dumps(grad_scaler_config)),
+            epochs_per_trigger=training_config.epochs_per_trigger,
+            num_prefetched_partitions=training_config.num_prefetched_partitions,
+            parallel_prefetch_requests=training_config.parallel_prefetch_requests,
+            seed=training_config.seed,  # seed is an optional field which can accept None
+            # tokenizer is an optional field which can accept None
+            tokenizer=PythonString(value=tokenizer) if tokenizer is not None else None,
+            num_samples_to_pass=num_samples_to_pass if num_samples_to_pass is not None else 0,
+            shuffle=training_config.shuffle,
+            enable_accurate_gpu_measurements=training_config.enable_accurate_gpu_measurements,
+            record_loss_every=training_config.record_loss_every,
+        )

def start_training(
self,
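The dict-plus-cleaned_kwargs construction was dropped because the Python protobuf constructor treats field=None the same as not setting the field at all, and unset proto3 scalars simply report their type defaults (0, False, empty string). A minimal sketch of that behavior, assuming the generated stubs are importable from the package path of the regenerated trainer_server_pb2.py shown further down; the field values are placeholders:

# Sketch only; illustrates proto3 default handling, not code from this PR.
from modyn.trainer_server.internal.grpc.generated.trainer_server_pb2 import StartTrainingRequest

# Leaving a scalar field out (or passing None) falls back to the proto3 default.
req = StartTrainingRequest(pipeline_id=1, trigger_id=2, device="cuda:0", seed=None)
assert req.record_loss_every == 0  # int64 default: loss recording disabled
assert req.shuffle is False        # bool default
assert req.seed == 0               # None in the constructor means "not set"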
4 changes: 4 additions & 0 deletions modyn/config/schema/pipeline/training/config.py
@@ -131,6 +131,10 @@ class TrainingConfig(ModynBaseModel):
description="The ID of the model that should be used as the initial model.",
)
checkpointing: CheckpointingConfig = Field(description="Configuration of checkpointing during training")
record_loss_every: int = Field(
    default=0,
    description="Record the training loss in the trainer_log every n-th batch/step. If 0, loss is not recorded.",
)
optimizers: List[OptimizerConfig] = Field(
description="An array of the optimizers for the training",
min_length=1,
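The defaulting behavior of the new option can be checked in isolation with a throwaway pydantic model; this is a sketch, not the real TrainingConfig (which has many other required fields):

# Standalone sketch mirroring the Field definition above; LossRecordingConfig is a
# hypothetical stand-in for TrainingConfig.
from pydantic import BaseModel, Field

class LossRecordingConfig(BaseModel):
    record_loss_every: int = Field(
        default=0,
        description="Record the training loss in the trainer_log every n-th batch/step. If 0, loss is not recorded.",
    )

assert LossRecordingConfig().record_loss_every == 0  # omitted: recording disabled
assert LossRecordingConfig(record_loss_every=10).record_loss_every == 10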
1 change: 1 addition & 0 deletions modyn/protos/trainer_server.proto
@@ -57,6 +57,7 @@ message StartTrainingRequest {
int64 num_samples_to_pass = 23;
bool shuffle = 24;
bool enable_accurate_gpu_measurements = 25;
int64 record_loss_every = 26;
}

message StartTrainingResponse {
3 changes: 3 additions & 0 deletions modyn/tests/common/grpc/test_grpc_helpers.py
@@ -89,6 +89,8 @@ def test_prepare_start_training_request(
# False is the default value for bool fields, so we don't need to test it explicitly
pipeline_training_config.shuffle = True
pipeline_training_config.enable_accurate_gpu_measurements = True
# 0 is the default value for int fields, so we don't need to test it explicitly
pipeline_training_config.record_loss_every = 10
pipeline_training_config.optimization_criterion.config = {"key": "value"}
pipeline_training_config.use_previous_model = previous_model_id is not None

@@ -136,6 +138,7 @@
assert req.num_samples_to_pass == (num_samples_to_pass if num_samples_to_pass is not None else 0)
assert req.shuffle
assert req.enable_accurate_gpu_measurements
assert req.record_loss_every == 10


@patch("modyn.common.grpc.grpc_helpers.grpc_connection_established", return_value=True)
@@ -737,6 +737,8 @@ def test_train(
status = status_queue.get()
assert status["num_batches"] == 0
assert status["num_samples"] == 0
# we didn't enable recording the training loss
assert len(trainer._log["training_loss"]) == 0
status_state = torch.load(io.BytesIO(status_queue.get()))
checkpointed_state = {
"model": OrderedDict(
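The new assertion checks the training_loss list in the trainer's log, which stays empty because record_loss_every was left at 0 in this test. As a rough sketch of the semantics the option is meant to have (record the loss on every n-th batch, 0 disables recording), here is a small illustrative helper; it is not the actual PytorchTrainer implementation:

# Hypothetical helper; only record_loss_every and the _log["training_loss"] list
# come from the diff, everything else is illustrative.
def maybe_record_loss(log: dict, batch_number: int, loss: float, record_loss_every: int) -> None:
    if record_loss_every > 0 and batch_number % record_loss_every == 0:
        log["training_loss"].append({"batch": batch_number, "loss": loss})

trainer_log = {"training_loss": []}
for batch_number, loss in enumerate([0.91, 0.87, 0.85, 0.80], start=1):
    maybe_record_loss(trainer_log, batch_number, loss, record_loss_every=2)

# With record_loss_every=2, only batches 2 and 4 are recorded.
assert [entry["batch"] for entry in trainer_log["training_loss"]] == [2, 4]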
36 changes: 18 additions & 18 deletions modyn/trainer_server/internal/grpc/generated/trainer_server_pb2.py

Some generated files are not rendered by default.

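trainer_server_pb2.py is the regenerated stub that picks up the new record_loss_every field. The commit does not show how the stubs are regenerated; one common approach, sketched here with assumed include and output paths, is to invoke grpc_tools.protoc directly:

# Sketch only: the paths are assumptions based on the file locations in this commit,
# not a command taken from the repository.
from grpc_tools import protoc

protoc.main([
    "grpc_tools.protoc",
    "--proto_path=modyn/protos",
    "--python_out=modyn/trainer_server/internal/grpc/generated",
    "--grpc_python_out=modyn/trainer_server/internal/grpc/generated",
    "trainer_server.proto",
])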
