diff --git a/modyn/selector/internal/selector_strategies/downsampling_strategies/abstract_downsampling_strategy.py b/modyn/selector/internal/selector_strategies/downsampling_strategies/abstract_downsampling_strategy.py
index a8dae1333..a71879fbb 100644
--- a/modyn/selector/internal/selector_strategies/downsampling_strategies/abstract_downsampling_strategy.py
+++ b/modyn/selector/internal/selector_strategies/downsampling_strategies/abstract_downsampling_strategy.py
@@ -43,19 +43,11 @@ def __init__(
         self.requires_remote_computation = True
         self.maximum_keys_in_memory = maximum_keys_in_memory
         self.downsampling_config = downsampling_config
-        self.status_bar_scale = self._compute_status_bar_scale()
-
-    def _compute_status_bar_scale(self) -> int:
-        """
-        This function is used to create the downsampling status bar and handle the training one accordingly.
-
-        For BTS, we return 100 since the training status bar sees all the samples
-        For STB, we return the downsampling_ratio since the training status bar sees only a fraction of points
-        (while the downsampling status bas sees all the points)
-        """
-        if self.downsampling_mode == DownsamplingMode.BATCH_THEN_SAMPLE:
-            return 100
-        return self.downsampling_ratio
+        # The status bar scale is used in conjunction with the total number of samples (after presampling)
+        # and the number of already trained samples to show the current training progress.
+        # Whether we run BtS or StB, the number of trained samples must be compared against the total number
+        # of samples scaled by the downsampling ratio, so the status bar scale is always the downsampling ratio.
+        self.status_bar_scale = self.downsampling_ratio
 
     @property
     def downsampling_params(self) -> dict:
diff --git a/modyn/tests/selector/internal/selector_strategies/downsampling_strategies/test_scheduler.py b/modyn/tests/selector/internal/selector_strategies/downsampling_strategies/test_scheduler.py
index e3783c66b..7b993cb7d 100644
--- a/modyn/tests/selector/internal/selector_strategies/downsampling_strategies/test_scheduler.py
+++ b/modyn/tests/selector/internal/selector_strategies/downsampling_strategies/test_scheduler.py
@@ -102,7 +102,7 @@ def test_switch_functions():
         "ratio_max": 100,
     }
     assert downs.downsampling_strategy == "RemoteGradNormDownsampling"
-    assert downs.training_status_bar_scale == 100
+    assert downs.training_status_bar_scale == 25
 
 
 def test_wrong_number_threshold():
@@ -158,7 +158,7 @@ def test_double_threshold():
         "ratio_max": 100,
     }
     assert downs.downsampling_strategy == "RemoteGradNormDownsampling"
-    assert downs.training_status_bar_scale == 100
+    assert downs.training_status_bar_scale == 25
 
     # above the last threshold
     for i in range(15, 25):
@@ -203,7 +203,7 @@ def test_wrong_trigger():
         "ratio_max": 100,
     }
     assert downs.downsampling_strategy == "RemoteGradNormDownsampling"
-    assert downs.training_status_bar_scale == 100
+    assert downs.training_status_bar_scale == 25
 
 
 def test_instantiate_scheduler_just_one():
diff --git a/modyn/tests/trainer_server/internal/trainer/test_pytorch_trainer.py b/modyn/tests/trainer_server/internal/trainer/test_pytorch_trainer.py
index 52527184d..aceb19b3b 100644
--- a/modyn/tests/trainer_server/internal/trainer/test_pytorch_trainer.py
+++ b/modyn/tests/trainer_server/internal/trainer/test_pytorch_trainer.py
@@ -10,7 +10,7 @@
 from collections import OrderedDict
 from io import BytesIO
 from time import sleep
-from unittest.mock import MagicMock, patch
+from unittest.mock import MagicMock, call, patch
 
 import grpc
 import pytest
@@ -29,6 +29,7 @@
 from modyn.trainer_server.internal.metadata_collector.metadata_collector import MetadataCollector
 from modyn.trainer_server.internal.trainer.metadata_pytorch_callbacks.base_callback import BaseCallback
 from modyn.trainer_server.internal.trainer.pytorch_trainer import PytorchTrainer, train
+from modyn.trainer_server.internal.trainer.remote_downsamplers import RemoteGradMatchDownsamplingStrategy
 from modyn.trainer_server.internal.utils.trainer_messages import TrainerMessages
 from modyn.trainer_server.internal.utils.training_info import TrainingInfo
 from modyn.utils import DownsamplingMode
@@ -117,6 +118,28 @@ def get_mock_label_transformer():
     )
 
 
+class MockDataloader:
+    def __init__(self, batch_size, num_batches):
+        self.batch_size = batch_size
+        self.num_batches = num_batches
+        self.dataset = MagicMock()
+
+    def __iter__(self):
+        return iter(
+            [
+                (
+                    ("1",) * self.batch_size,
+                    torch.ones(self.batch_size, 10, requires_grad=True),
+                    torch.ones(self.batch_size, dtype=torch.uint8),
+                )
+                for _ in range(self.num_batches)
+            ]
+        )
+
+    def __len__(self):
+        return self.num_batches
+
+
 def mock_get_dataloaders(
     pipeline_id,
     trigger_id,
@@ -135,12 +158,7 @@ def mock_get_dataloaders(
     log_path,
     num_batches: int = 100,
 ):
-    mock_train_dataloader = iter(
-        [
-            (("1",) * batch_size, torch.ones(batch_size, 10, requires_grad=True), torch.ones(batch_size, dtype=int))
-            for _ in range(num_batches)
-        ]
-    )
+    mock_train_dataloader = MockDataloader(batch_size, num_batches)
 
     return mock_train_dataloader, None
 
@@ -257,6 +275,7 @@ def get_training_info(
 
 @patch.object(StorageStub, "__init__", noop_constructor_mock)
 @patch.object(SelectorStub, "__init__", noop_constructor_mock)
+@patch("modyn.trainer_server.internal.trainer.pytorch_trainer.prepare_dataloaders", mock_get_dataloaders)
 @patch("modyn.trainer_server.internal.dataset.online_dataset.grpc_connection_established", return_value=True)
 @patch(
     "modyn.trainer_server.internal.dataset.key_sources.selector_key_source.grpc_connection_established",
@@ -266,13 +285,13 @@ def get_training_info(
 @patch("modyn.trainer_server.internal.utils.training_info.dynamic_module_import")
 @patch("modyn.trainer_server.internal.trainer.pytorch_trainer.dynamic_module_import")
 @patch.object(PytorchTrainer, "connect_to_selector", return_value=None)
-@patch.object(PytorchTrainer, "get_selection_strategy", return_value=(False, "", {}))
+@patch.object(PytorchTrainer, "get_selection_strategy")
 @patch.object(PytorchTrainer, "get_num_samples_in_trigger")
 @patch.object(SelectorKeySource, "uses_weights", return_value=False)
 def get_mock_trainer(
     modyn_config: ModynConfig,
-    query_queue: mp.Queue,
-    response_queue: mp.Queue,
+    query_queue_training: mp.Queue,
+    response_queue_training: mp.Queue,
     use_pretrained: bool,
     load_optimizer_state: bool,
     pretrained_model_path: pathlib.Path,
@@ -289,22 +308,13 @@ def get_mock_trainer(
     test_grpc_connection_established_selector: MagicMock,
     test_grpc_connection_established: MagicMock,
     batch_size: int = 32,
-    downsampling_mode: DownsamplingMode = DownsamplingMode.DISABLED,
-    downsampling_ratio: int = 25,
-    ratio_max: int = 100,
+    selection_strategy: tuple[bool, str, dict] = (False, "", {}),
 ):
     model_dynamic_module_patch.return_value = MockModule(num_optimizers)
     lr_scheduler_dynamic_module_patch.return_value = MockLRSchedulerModule()
     mock_get_num_samples.return_value = batch_size * 100
 
-    if downsampling_mode == DownsamplingMode.BATCH_THEN_SAMPLE:
-        mock_selection_strategy.return_value = (
-            True,
-            "RemoteGradNormDownsampling",
downsampling_ratio, "ratio_max": ratio_max, "sample_then_batch": False}, - ) - elif downsampling_mode == DownsamplingMode.SAMPLE_THEN_BATCH: - raise NotImplementedError() + mock_selection_strategy.return_value = selection_strategy training_info = get_training_info( 0, @@ -323,8 +333,8 @@ def get_mock_trainer( modyn_config.model_dump(by_alias=True), training_info, "cpu", - query_queue, - response_queue, + query_queue_training, + response_queue_training, mp.Queue(), mp.Queue(), logging.getLogger(__name__), @@ -621,7 +631,6 @@ def test_send_model_state_to_server(dummy_system_config: ModynConfig): } -@patch("modyn.trainer_server.internal.trainer.pytorch_trainer.prepare_dataloaders", mock_get_dataloaders) @patch.object(PytorchTrainer, "weights_handling", return_value=(False, False)) def test_train_invalid_query_message(test_weight_handling, dummy_system_config: ModynConfig): query_status_queue = mp.Queue() @@ -652,7 +661,6 @@ def test_train_invalid_query_message(test_weight_handling, dummy_system_config: # # pylint: disable=too-many-locals -@patch("modyn.trainer_server.internal.trainer.pytorch_trainer.prepare_dataloaders", mock_get_dataloaders) @patch.object(BaseCallback, "on_train_begin", return_value=None) @patch.object(BaseCallback, "on_train_end", return_value=None) @patch.object(BaseCallback, "on_batch_begin", return_value=None) @@ -870,7 +878,6 @@ def test_create_trainer_with_exception( @pytest.mark.parametrize("downsampling_ratio, ratio_max", [(25, 100), (50, 100), (250, 1000), (125, 1000)]) -@patch("modyn.trainer_server.internal.trainer.pytorch_trainer.prepare_dataloaders", mock_get_dataloaders) @patch.object(BaseCallback, "on_train_begin", return_value=None) @patch.object(BaseCallback, "on_train_end", return_value=None) @patch.object(BaseCallback, "on_batch_begin", return_value=None) @@ -914,9 +921,11 @@ def test_train_batch_then_sample_accumulation( "custom", False, batch_size=batch_size, - downsampling_mode=DownsamplingMode.BATCH_THEN_SAMPLE, - downsampling_ratio=downsampling_ratio, - ratio_max=ratio_max, + selection_strategy=( + True, + "RemoteGradNormDownsampling", + {"downsampling_ratio": downsampling_ratio, "sample_then_batch": False, "ratio_max": ratio_max}, + ), ) assert trainer._downsampling_mode == DownsamplingMode.BATCH_THEN_SAMPLE @@ -949,6 +958,7 @@ def mock_forward(data): assert trainer._num_samples == batch_size * num_batches assert trainer._log["num_samples"] == batch_size * num_batches + assert trainer._log["num_batches"] == num_batches # We only train on whole batches, hence we have to scale by batch size assert trainer._log["num_samples_trained"] == ((expected_bts_size * num_batches) // batch_size) * batch_size assert test_on_batch_begin.call_count == len(trainer._callbacks) * num_batches @@ -970,7 +980,6 @@ def mock_forward(data): assert torch.allclose(data, expected_data) -@patch("modyn.trainer_server.internal.trainer.pytorch_trainer.prepare_dataloaders", mock_get_dataloaders) @patch.object(MetadataCollector, "send_metadata", return_value=None) @patch.object(MetadataCollector, "cleanup", return_value=None) @patch.object(CustomLRScheduler, "step", return_value=None) @@ -1003,3 +1012,130 @@ def test_lr_scheduler_init( ) assert trainer._lr_scheduler.T_max == 100 + + +@patch("modyn.trainer_server.internal.trainer.pytorch_trainer.SelectorKeySource") +@patch.object(PytorchTrainer, "get_available_labels_from_selector") +@patch("modyn.trainer_server.internal.trainer.pytorch_trainer.prepare_per_class_dataloader_from_online_dataset") 
+@patch("modyn.trainer_server.internal.trainer.pytorch_trainer.LocalDatasetWriter") +@patch("modyn.trainer_server.internal.trainer.pytorch_trainer.LocalKeySource") +@patch.object(PytorchTrainer, "start_embedding_recording_if_needed") +@patch.object(PytorchTrainer, "end_embedding_recorder_if_needed") +@patch.object(PytorchTrainer, "get_embeddings_if_recorded") +@patch.object(RemoteGradMatchDownsamplingStrategy, "inform_samples") +@patch.object(RemoteGradMatchDownsamplingStrategy, "inform_end_of_current_label") +@patch.object(PytorchTrainer, "update_queue") +def test_downsample_trigger_training_set_label_by_label( + test_update_queue, + test_inform_end_of_current_label, + test_inform_samples, + test_get_embeddings, + test_end_embedding_recording, + test_start_embedding_recording, + test_local_key_source, + test_local_dataset_writer, + test_prepare_per_class_dataloader, + test_get_available_labels, + test_selector_key_source, + dummy_system_config: ModynConfig, +): + batch_size = 4 + available_labels = [0, 1, 2, 3, 4, 5] + test_prepare_per_class_dataloader.return_value = MockDataloader(batch_size, 100) + test_get_available_labels.return_value = available_labels + num_batches = 100 # hardcoded into mock dataloader + query_status_queue_training = mp.Queue() + status_queue_training = mp.Queue() + trainer = get_mock_trainer( + dummy_system_config, + query_status_queue_training, + status_queue_training, + False, + False, + None, + 2, + "custom", + False, + batch_size=batch_size, + selection_strategy=( + True, + "RemoteGradMatchDownsamplingStrategy", + { + "downsampling_ratio": 25, + "downsampling_period": 1, + "sample_then_batch": True, + "balance": True, + "ratio_max": 100, + }, + ), + ) + assert trainer._downsampling_mode == DownsamplingMode.SAMPLE_THEN_BATCH + assert trainer._downsampler.requires_data_label_by_label + trainer.downsample_trigger_training_set() + assert test_prepare_per_class_dataloader.call_count == 1 + assert test_update_queue.call_count == len(available_labels) * num_batches + 1 + # check the args of the last call + last_call_args = test_update_queue.call_args_list[-1] + expected_batch_number = len(available_labels) * num_batches + expected_num_samples = expected_batch_number * batch_size + assert last_call_args == call("DOWNSAMPLING", expected_batch_number, expected_num_samples, training_active=True) + assert test_inform_end_of_current_label.call_count == len(available_labels) + + +@patch("modyn.trainer_server.internal.trainer.pytorch_trainer.SelectorKeySource") +@patch("modyn.trainer_server.internal.trainer.pytorch_trainer.LocalDatasetWriter") +@patch("modyn.trainer_server.internal.trainer.pytorch_trainer.LocalKeySource") +@patch.object(PytorchTrainer, "start_embedding_recording_if_needed") +@patch.object(PytorchTrainer, "end_embedding_recorder_if_needed") +@patch.object(PytorchTrainer, "get_embeddings_if_recorded") +@patch.object(RemoteGradMatchDownsamplingStrategy, "inform_samples") +@patch.object(RemoteGradMatchDownsamplingStrategy, "select_points", return_value=([1, 2], torch.ones(2))) +@patch.object(PytorchTrainer, "update_queue") +def test_downsample_trigger_training_set( + test_update_queue, + test_select_points, + test_inform_samples, + test_get_embeddings, + test_end_embedding_recording, + test_start_embedding_recording, + test_local_key_source, + test_local_dataset_writer, + test_selector_key_source, + dummy_system_config: ModynConfig, +): + batch_size = 4 + num_batches = 100 # hardcoded into mock dataloader + query_status_queue_training = mp.Queue() + 
+    status_queue_training = mp.Queue()
+    trainer = get_mock_trainer(
+        dummy_system_config,
+        query_status_queue_training,
+        status_queue_training,
+        False,
+        False,
+        None,
+        2,
+        "custom",
+        False,
+        batch_size=batch_size,
+        selection_strategy=(
+            True,
+            "RemoteGradMatchDownsamplingStrategy",
+            {
+                "downsampling_ratio": 25,
+                "downsampling_period": 1,
+                "sample_then_batch": True,
+                "balance": False,
+                "ratio_max": 100,
+            },
+        ),
+    )
+    assert trainer._downsampling_mode == DownsamplingMode.SAMPLE_THEN_BATCH
+    assert not trainer._downsampler.requires_data_label_by_label
+    trainer.downsample_trigger_training_set()
+    assert test_update_queue.call_count == num_batches + 1
+    # check the args of the last call
+    last_call_args = test_update_queue.call_args_list[-1]
+    expected_batch_number = num_batches
+    expected_num_samples = expected_batch_number * batch_size
+    assert last_call_args == call("DOWNSAMPLING", expected_batch_number, expected_num_samples, training_active=True)
diff --git a/modyn/trainer_server/internal/trainer/pytorch_trainer.py b/modyn/trainer_server/internal/trainer/pytorch_trainer.py
index b4a755765..82d3da8d6 100644
--- a/modyn/trainer_server/internal/trainer/pytorch_trainer.py
+++ b/modyn/trainer_server/internal/trainer/pytorch_trainer.py
@@ -211,7 +211,6 @@ def train(self) -> None:  # pylint: disable=too-many-locals, too-many-branches
         self._info("Handled OnBegin Callbacks.")
 
         self._log["epochs"] = []
-        batch_number = -1
 
         if self.num_samples_to_pass == 0:
             epoch_num_generator: Iterable[int] = range(self.epochs_per_trigger)
         else:
@@ -236,30 +235,33 @@
             batch_accumulator = BatchAccumulator(self._batch_size // post_downsampling_size, self._device)
 
         trained_batches = 0
+        passed_batches = 0
 
         for epoch in epoch_num_generator:
            stopw = Stopwatch()  # Reset timings per epoch
            self._log["epochs"].append({})
            batch_timings = []
 
            if self._sample_then_batch_this_epoch(epoch):
-                self.update_queue("TRAINING", batch_number, self._num_samples, training_active=False)
+                self.update_queue(
+                    "TRAINING", trained_batches, trained_batches * self._batch_size, training_active=False
+                )
                with GPUMeasurement(self._measure_gpu_ops, "DownsampleSTB", self._device, stopw):
                    self.downsample_trigger_training_set()
 
            stopw.start("IndivFetchBatch", overwrite=True)
            stopw.start("FetchBatch", resume=True)
-            for batch_number, batch in enumerate(self._train_dataloader):
+            for batch in self._train_dataloader:
                stopw.stop("FetchBatch")
                batch_timings.append(stopw.stop("IndivFetchBatch"))
                retrieve_weights_from_dataloader, weighted_optimization = self.weights_handling(len(batch))
 
                stopw.start("OnBatchBeginCallbacks", resume=True)
                for _, callback in self._callbacks.items():
-                    callback.on_batch_begin(self._model.model, self._optimizers, batch, batch_number)
+                    callback.on_batch_begin(self._model.model, self._optimizers, batch, passed_batches)
                stopw.stop()
 
-                self.update_queue("TRAINING", batch_number, self._num_samples, training_active=True)
-
+                self.update_queue("TRAINING", trained_batches, trained_batches * self._batch_size, training_active=True)
+                passed_batches += 1
                with GPUMeasurement(self._measure_gpu_ops, "PreprocessBatch", self._device, stopw, resume=True):
                    sample_ids, target, data = self.preprocess_batch(batch, stopw)
@@ -285,6 +287,7 @@
                     data, sample_ids, target, weights = batch_accumulator.get_accumulated_batch()
 
                 self._assert_data_size(self._batch_size, data, sample_ids, target)
+
                 with GPUMeasurement(self._measure_gpu_ops, "Forward", self._device, stopw, resume=True):
                     output = self._model.model(data)
 
@@ -299,7 +302,7 @@
                 stopw.start("OnBatchBeforeUpdate", resume=True)
                 for _, callback in self._callbacks.items():
                     callback.on_batch_before_update(
-                        self._model.model, self._optimizers, batch_number, sample_ids, data, target, output, loss
+                        self._model.model, self._optimizers, trained_batches, sample_ids, data, target, output, loss
                     )
                 stopw.stop()
 
@@ -315,10 +318,10 @@
 
                 self._step_lr_if_necessary(True)
 
-                if self._checkpoint_interval > 0 and batch_number % self._checkpoint_interval == 0:
+                if self._checkpoint_interval > 0 and trained_batches % self._checkpoint_interval == 0:
                     stopw.start("Checkpoint", resume=True)
-                    checkpoint_file_name = self._checkpoint_path / f"model_{batch_number}.modyn"
-                    self.save_state(checkpoint_file_name, batch_number)
+                    checkpoint_file_name = self._checkpoint_path / f"model_{trained_batches}.modyn"
+                    self.save_state(checkpoint_file_name, trained_batches)
                     stopw.stop("Checkpoint")
 
                 self._num_samples += self._batch_size
@@ -326,7 +329,7 @@
                 stopw.start("OnBatchEnd", resume=True)
                 for _, callback in self._callbacks.items():
                     callback.on_batch_end(
-                        self._model.model, self._optimizers, batch_number, sample_ids, data, target, output, loss
+                        self._model.model, self._optimizers, trained_batches, sample_ids, data, target, output, loss
                     )
                 stopw.stop()
                 if 0 < self.num_samples_to_pass <= self._num_samples:
@@ -376,10 +379,11 @@
 
         total_stopw.stop("TotalTrain")
 
-        self._info(f"Finished training: {self._num_samples} samples, {batch_number + 1} batches.")
+        self._info(f"Finished training: {self._num_samples} samples, {passed_batches} batches.")
         self._log["num_samples"] = self._num_samples
         self._log["num_samples_trained"] = trained_batches * self._batch_size
-        self._log["num_batches"] = batch_number + 1
+        self._log["num_batches"] = passed_batches
+        self._log["num_batches_trained"] = trained_batches
         self._log["total_train"] = total_stopw.measurements.get("TotalTrain", 0)
 
         self._assert_training_size(epoch, trained_batches)
@@ -387,7 +391,7 @@
         self._persist_pipeline_log()
 
         for _, callback in self._callbacks.items():
-            callback.on_train_end(self._model.model, self._optimizers, self._num_samples, batch_number)
+            callback.on_train_end(self._model.model, self._optimizers, self._num_samples, passed_batches)
 
         for metric in self._callbacks:
             self._metadata_collector.send_metadata(metric)
@@ -435,7 +439,7 @@ def downsample_trigger_training_set(self) -> None:
 
         available_labels = self.get_available_labels_from_selector()
         number_of_samples = 0
-        batch_number = 0
+        batch_number = -1
         first_label = True
         for label in available_labels:
             if first_label:
@@ -480,7 +484,7 @@
             )
         self._train_dataloader.dataset.change_key_source(new_key_source)
 
-        self.update_queue("DOWNSAMPLING", batch_number, number_of_samples, training_active=True)
+        self.update_queue("DOWNSAMPLING", batch_number + 1, number_of_samples, training_active=True)
 
         # set the model to train
         self._model.model.train()
@@ -863,16 +867,16 @@ def _sample_then_batch_this_epoch(self, epoch: int) -> bool:
     def _iterate_dataloader_and_compute_scores(
         self,
         dataloader: torch.utils.data.DataLoader,
-        previous_batch_number: int = 0,
+        previous_batch_number: int = -1,
         previous_number_of_samples: int = 0,
     ) -> Tuple[int, int]:
         """
         Function to iterate a dataloader, compute the forward pass and send the forward output to the downsampler.
         Args:
             dataloader: torch.dataloader to get the data
-            previous_batch_number: number of batches processed before calling this function. Useful when this function
-                is called several times to keep track of previous invocations (ex label by label dataloader). We need to
-                have a total to correctly update the queue and show the progress in the supervisor counter.
+            previous_batch_number: The batch number returned by the last call to this method. Useful when this
+                function is called several times (e.g., the label-by-label dataloader) so that we keep a running
+                total to correctly update the queue and show the progress in the supervisor counter.
             previous_number_of_samples: number of samples processed before calling this function. See above for the use.
 
         Returns:
@@ -880,9 +884,9 @@ def _iterate_dataloader_and_compute_scores(
         """
         number_of_samples = previous_number_of_samples
         batch_number = previous_batch_number
-        for batch_number, batch in enumerate(dataloader):
+        for batch in dataloader:
             self.update_queue("DOWNSAMPLING", batch_number, number_of_samples, training_active=False)
-
+            batch_number += 1
             sample_ids, target, data = self.preprocess_batch(batch)
 
             number_of_samples += len(sample_ids)
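---

A note on the `-1` counting convention this patch introduces (illustrative sketch, not part of the patch): `_iterate_dataloader_and_compute_scores` now starts `batch_number` at `-1` and increments it only after a batch has been consumed, so the value returned by one call can be fed directly into the next (as the per-label loop in `downsample_trigger_training_set` does), and `batch_number + 1` is always the number of fully processed batches. A minimal, self-contained Python sketch with hypothetical names:

```python
# Standalone sketch of the batch-counting convention (hypothetical names, not Modyn's API).
from typing import Iterable, List, Tuple


def iterate_and_count(
    dataloader: Iterable[List[int]],
    previous_batch_number: int = -1,
    previous_number_of_samples: int = 0,
) -> Tuple[int, int]:
    """Returns the last processed batch number and the running sample count."""
    batch_number = previous_batch_number
    number_of_samples = previous_number_of_samples
    for batch in dataloader:
        batch_number += 1  # incremented once per consumed batch, never reset by enumerate()
        number_of_samples += len(batch)
    return batch_number, number_of_samples


# Two per-label "dataloaders" with 3 and 2 batches of size 4 each:
label_a = [[0] * 4 for _ in range(3)]
label_b = [[0] * 4 for _ in range(2)]

batch_number, num_samples = iterate_and_count(label_a)  # -> (2, 12)
batch_number, num_samples = iterate_and_count(label_b, batch_number, num_samples)  # -> (4, 20)

# The progress counter reports batch_number + 1 completed batches in total:
assert (batch_number + 1, num_samples) == (5, 20)
```

This is why the final queue update uses `batch_number + 1`, and why the new tests expect `num_batches + 1` calls to `update_queue` with `expected_batch_number = num_batches` in the last call.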