Parameterize drop_last (#550)
This PR makes `drop_last` configurable. For il model training we do not want to
drop the last batch, because the holdout set is already very small (and il model
training uses no sampling, so keeping the last batch does no harm). For main
model training, we keep the usual default of `drop_last = True`.
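
As a quick illustration of what the flag changes (a plain-PyTorch sketch, not modyn code): with `drop_last=True` a small holdout set loses its trailing partial batch, while `drop_last=False` keeps it.

```python
# Sketch only: effect of drop_last on a small holdout set (plain PyTorch, not modyn code).
import torch
from torch.utils.data import DataLoader, TensorDataset

holdout = TensorDataset(torch.arange(10).float().unsqueeze(1))  # 10 samples

# drop_last=True: only 2 full batches of 4 remain; the last 2 samples are discarded.
print(len(DataLoader(holdout, batch_size=4, drop_last=True)))   # -> 2

# drop_last=False: the trailing partial batch is kept, so all 10 samples are seen.
print(len(DataLoader(holdout, batch_size=4, drop_last=False)))  # -> 3
```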
XianzheMa authored Jun 27, 2024
1 parent 3e86231 commit 59ea026
Showing 9 changed files with 38 additions and 21 deletions.
1 change: 1 addition & 0 deletions integrationtests/config/rho_loss.yaml
@@ -62,6 +62,7 @@ selection_strategy:
dataloader_workers: 1
use_previous_model: False
batch_size: 2
drop_last_batch: False
shuffle: False
optimizers:
- name: "default"
1 change: 1 addition & 0 deletions modyn/common/grpc/grpc_helpers.py
@@ -247,6 +247,7 @@ def prepare_start_training_request(
shuffle=training_config.shuffle,
enable_accurate_gpu_measurements=training_config.enable_accurate_gpu_measurements,
record_loss_every=training_config.record_loss_every,
drop_last_batch=training_config.drop_last_batch,
)

def start_training(
3 changes: 3 additions & 0 deletions modyn/config/schema/pipeline/training/config.py
@@ -98,6 +98,9 @@ class TrainingConfig(ModynBaseModel):
description="The number of data loader workers on the trainer node that fetch data from storage.", ge=1
)
batch_size: int = Field(description="The batch size to be used during training.", ge=1)
drop_last_batch: bool = Field(
default=True, description="Whether to drop the last batch if it is smaller than the batch size."
)
shuffle: bool = Field(
description=(
"If True, we shuffle the order of partitions and the data within each partition at each worker."
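
For illustration, a reduced pydantic stand-in (not modyn's full TrainingConfig; all other required fields are omitted) showing how the new schema field behaves: it defaults to True and can be overridden per pipeline, as the rho_loss.yaml change above does.

```python
# Sketch only: a stand-in for TrainingConfig illustrating the new drop_last_batch field.
from pydantic import BaseModel, Field


class TrainingConfigSketch(BaseModel):  # modyn's real class derives from ModynBaseModel
    batch_size: int = Field(description="The batch size to be used during training.", ge=1)
    drop_last_batch: bool = Field(
        default=True, description="Whether to drop the last batch if it is smaller than the batch size."
    )


print(TrainingConfigSketch(batch_size=2).drop_last_batch)                         # True (default)
print(TrainingConfigSketch(batch_size=2, drop_last_batch=False).drop_last_batch)  # False (as in rho_loss.yaml)
```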
1 change: 1 addition & 0 deletions modyn/protos/trainer_server.proto
@@ -58,6 +58,7 @@ message StartTrainingRequest {
bool shuffle = 24;
bool enable_accurate_gpu_measurements = 25;
int64 record_loss_every = 26;
bool drop_last_batch = 27;
}

message StartTrainingResponse {
@@ -156,6 +156,7 @@ def mock_get_dataloaders(
shuffle,
tokenizer,
log_path,
drop_last,
num_batches: int = 100,
):
mock_train_dataloader = MockDataloader(batch_size, num_batches)
36 changes: 18 additions & 18 deletions modyn/trainer_server/internal/grpc/generated/trainer_server_pb2.py

Some generated files are not rendered by default.

@@ -133,6 +133,7 @@ class StartTrainingRequest(google.protobuf.message.Message):
SHUFFLE_FIELD_NUMBER: builtins.int
ENABLE_ACCURATE_GPU_MEASUREMENTS_FIELD_NUMBER: builtins.int
RECORD_LOSS_EVERY_FIELD_NUMBER: builtins.int
DROP_LAST_BATCH_FIELD_NUMBER: builtins.int
pipeline_id: builtins.int
trigger_id: builtins.int
device: builtins.str
@@ -149,6 +150,7 @@ class StartTrainingRequest(google.protobuf.message.Message):
shuffle: builtins.bool
enable_accurate_gpu_measurements: builtins.bool
record_loss_every: builtins.int
drop_last_batch: builtins.bool
@property
def torch_optimizers_configuration(self) -> global___JsonString: ...
@property
@@ -198,9 +200,10 @@ class StartTrainingRequest(google.protobuf.message.Message):
shuffle: builtins.bool = ...,
enable_accurate_gpu_measurements: builtins.bool = ...,
record_loss_every: builtins.int = ...,
drop_last_batch: builtins.bool = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["_seed", b"_seed", "_tokenizer", b"_tokenizer", "bytes_parser", b"bytes_parser", "checkpoint_info", b"checkpoint_info", "criterion_parameters", b"criterion_parameters", "data_info", b"data_info", "grad_scaler_configuration", b"grad_scaler_configuration", "label_transformer", b"label_transformer", "lr_scheduler", b"lr_scheduler", "seed", b"seed", "tokenizer", b"tokenizer", "torch_optimizers_configuration", b"torch_optimizers_configuration"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["_seed", b"_seed", "_tokenizer", b"_tokenizer", "batch_size", b"batch_size", "bytes_parser", b"bytes_parser", "checkpoint_info", b"checkpoint_info", "criterion_parameters", b"criterion_parameters", "data_info", b"data_info", "device", b"device", "enable_accurate_gpu_measurements", b"enable_accurate_gpu_measurements", "epochs_per_trigger", b"epochs_per_trigger", "grad_scaler_configuration", b"grad_scaler_configuration", "label_transformer", b"label_transformer", "load_optimizer_state", b"load_optimizer_state", "lr_scheduler", b"lr_scheduler", "num_prefetched_partitions", b"num_prefetched_partitions", "num_samples_to_pass", b"num_samples_to_pass", "parallel_prefetch_requests", b"parallel_prefetch_requests", "pipeline_id", b"pipeline_id", "pretrained_model_id", b"pretrained_model_id", "record_loss_every", b"record_loss_every", "seed", b"seed", "shuffle", b"shuffle", "tokenizer", b"tokenizer", "torch_criterion", b"torch_criterion", "torch_optimizers_configuration", b"torch_optimizers_configuration", "transform_list", b"transform_list", "trigger_id", b"trigger_id", "use_pretrained_model", b"use_pretrained_model"]) -> None: ...
def ClearField(self, field_name: typing.Literal["_seed", b"_seed", "_tokenizer", b"_tokenizer", "batch_size", b"batch_size", "bytes_parser", b"bytes_parser", "checkpoint_info", b"checkpoint_info", "criterion_parameters", b"criterion_parameters", "data_info", b"data_info", "device", b"device", "drop_last_batch", b"drop_last_batch", "enable_accurate_gpu_measurements", b"enable_accurate_gpu_measurements", "epochs_per_trigger", b"epochs_per_trigger", "grad_scaler_configuration", b"grad_scaler_configuration", "label_transformer", b"label_transformer", "load_optimizer_state", b"load_optimizer_state", "lr_scheduler", b"lr_scheduler", "num_prefetched_partitions", b"num_prefetched_partitions", "num_samples_to_pass", b"num_samples_to_pass", "parallel_prefetch_requests", b"parallel_prefetch_requests", "pipeline_id", b"pipeline_id", "pretrained_model_id", b"pretrained_model_id", "record_loss_every", b"record_loss_every", "seed", b"seed", "shuffle", b"shuffle", "tokenizer", b"tokenizer", "torch_criterion", b"torch_criterion", "torch_optimizers_configuration", b"torch_optimizers_configuration", "transform_list", b"transform_list", "trigger_id", b"trigger_id", "use_pretrained_model", b"use_pretrained_model"]) -> None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing.Literal["_seed", b"_seed"]) -> typing.Literal["seed"] | None: ...
@typing.overload
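
For illustration only (not part of the diff): once the stubs are regenerated, the new field is set like any other scalar when building the request. The import path below is assumed from the generated file's location, and all other required request fields are omitted.

```python
# Sketch: setting the new drop_last_batch field on a StartTrainingRequest.
# Import path assumed from the generated module's location in this repository.
from modyn.trainer_server.internal.grpc.generated import trainer_server_pb2

request = trainer_server_pb2.StartTrainingRequest(
    batch_size=2,
    shuffle=False,
    drop_last_batch=False,  # new scalar field (number 27) added by this commit
)
print(request.drop_last_batch)  # -> False
```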
10 changes: 8 additions & 2 deletions modyn/trainer_server/internal/trainer/pytorch_trainer.py
@@ -123,6 +123,7 @@ def __init__(
self.epochs_per_trigger = training_info.epochs_per_trigger
self.num_samples_to_pass = training_info.num_samples_to_pass
self._log_file_path = training_info.log_file_path
self._drop_last_batch = training_info.drop_last_batch
self._dataset_log_path = pathlib.Path(tempfile.mkdtemp(prefix=f"pl{self.pipeline_id}"))

if not self._checkpoint_path.is_dir():
@@ -184,6 +185,7 @@ def __init__(
training_info.shuffle,
training_info.tokenizer,
self._dataset_log_path,
drop_last=self._drop_last_batch,
)

# Create callbacks
@@ -288,7 +290,7 @@ def train(self) -> None: # pylint: disable=too-many-locals, too-many-branches

data, sample_ids, target, weights = batch_accumulator.get_accumulated_batch()

self._assert_data_size(self._batch_size, data, sample_ids, target)
self._assert_data_size(self._batch_size, data, sample_ids, target)

with GPUMeasurement(self._measure_gpu_ops, "Forward", self._device, stopw, resume=True):
output = self._model.model(data)
@@ -450,7 +452,11 @@ def downsample_trigger_training_set(self) -> None:
for label in available_labels:
if first_label:
per_class_dataloader = prepare_per_class_dataloader_from_online_dataset(
self._train_dataloader.dataset, self._batch_size, self._num_dataloaders, label
self._train_dataloader.dataset,
self._batch_size,
self._num_dataloaders,
label,
drop_last=self._drop_last_batch,
)
first_label = False
else:
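
As a hedged sketch of how the flag typically reaches PyTorch (the helper name and signature below are illustrative, not modyn's actual dataloader API): the trainer forwards `drop_last` straight into `torch.utils.data.DataLoader`.

```python
# Illustrative sketch: forwarding a drop_last flag into a PyTorch DataLoader.
# build_train_dataloader is a made-up name; modyn's real helpers differ.
from torch.utils.data import DataLoader, Dataset


def build_train_dataloader(dataset: Dataset, batch_size: int, num_workers: int, drop_last: bool) -> DataLoader:
    # With drop_last=False the final, smaller batch is still yielded to the training loop.
    return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, drop_last=drop_last)
```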
1 change: 1 addition & 0 deletions modyn/trainer_server/internal/utils/training_info.py
@@ -64,6 +64,7 @@ def __init__(
), "Inconsistent pretrained model configuration"

self.batch_size = request.batch_size
self.drop_last_batch = request.drop_last_batch
self.torch_criterion = request.torch_criterion
self.amp = amp

