From cdf12802e33be2917d21ddca6825b111c61f6dc9 Mon Sep 17 00:00:00 2001 From: Xianzhe Ma Date: Mon, 1 Jul 2024 14:23:38 +0200 Subject: [PATCH] Twin RHO Model Step 2: split the training set and train the twin model (#552) Per the title. The major change is in file [modyn/selector/internal/selector_strategies/downsampling_strategies/rho_loss_downsampling_strategy.py](https://github.com/eth-easl/modyn/pull/552/files#diff-1628acbf60458002fa86b1b4ad8b71fb21deedc12390e21bc4620d6197cdab04) --- .../supervisor/integrationtest_supervisor.py | 5 - .../pipeline/sampling/downsampling_config.py | 8 + .../models/selector_state_metadata.py | 6 + .../rho_loss_twin_model.py | 9 +- .../rho_loss_downsampling_strategy.py | 127 +++++-- .../test_rho_loss_downsampling_strategy.py | 336 +++++++++++++----- .../test_irreducible_loss_producer.py | 2 +- .../irreducible_loss_producer.py | 2 +- 8 files changed, 354 insertions(+), 141 deletions(-) diff --git a/integrationtests/supervisor/integrationtest_supervisor.py b/integrationtests/supervisor/integrationtest_supervisor.py index 7560afe12..52e6292cf 100644 --- a/integrationtests/supervisor/integrationtest_supervisor.py +++ b/integrationtests/supervisor/integrationtest_supervisor.py @@ -139,11 +139,6 @@ def test_rho_loss_pipeline_with_two_triggers() -> None: model_ids = [model.model_id for model in models] assert latest_il_model_id == max(model_ids) - # validate the holdout set is in correct proportion - # in rho_loss.yaml, each trigger contains 10 images. The holdout set ratio is 30% - # so the holdout set should contain 3 images - assert all(trigger.num_keys == 3 for trigger in triggers) - if __name__ == "__main__": tiny_dataset_helper = TinyDatasetHelper(dataset_size=20) diff --git a/modyn/config/schema/pipeline/sampling/downsampling_config.py b/modyn/config/schema/pipeline/sampling/downsampling_config.py index 0dc57dadb..2b1a891e1 100644 --- a/modyn/config/schema/pipeline/sampling/downsampling_config.py +++ b/modyn/config/schema/pipeline/sampling/downsampling_config.py @@ -165,12 +165,20 @@ class RHOLossDownsamplingConfig(BaseDownsamplingConfig): min=0, max=100, ) + holdout_set_ratio_max: int = Field( + description="Reference maximum holdout_set_ratio value. Defaults to 100, which implies percent." + " If you set this to 1000, holdout_set_ratio describes promille instead.", + default=100, + min=1, + ) il_training_config: ILTrainingConfig = Field(description="The configuration for the IL training.") @model_validator(mode="after") def validate_holdout_set_ratio(self) -> Self: if self.holdout_set_strategy == "Twin" and self.holdout_set_ratio != 50: raise ValueError("holdout_set_ratio should be 100 for the Twin strategy.") + if self.holdout_set_ratio > self.holdout_set_ratio_max: + raise ValueError("holdout_set_ratio cannot be greater than holdout_set_ratio_max.") return self diff --git a/modyn/metadata_database/models/selector_state_metadata.py b/modyn/metadata_database/models/selector_state_metadata.py index ce14bf0be..c581e92cb 100644 --- a/modyn/metadata_database/models/selector_state_metadata.py +++ b/modyn/metadata_database/models/selector_state_metadata.py @@ -22,6 +22,12 @@ class SelectorStateMetadataMixin: sample_key = Column("sample_key", BIGINT, primary_key=True) seen_in_trigger_id = Column("seen_in_trigger_id", Integer, primary_key=True) used = Column("used", Boolean, default=False) + # This is a field to help selection strategies to track the state of the samples during preparing the + # post-presampling training set. It should be reset to 0 after the training set is prepared. + # For example, the RHOLossDownsamplingStrategy uses this field to split the training set into two halves + # by first sampling 50% and setting their `tmp_version` field to 1, + # Then it generates the remaining 50% by querying samples with `tmp_version` as 0. + tmp_version = Column("tmp_version", Integer, default=0) timestamp = Column("timestamp", BigInteger) label = Column("label", Integer) last_used_in_trigger = Column("last_used_in_trigger", Integer, default=-1) diff --git a/modyn/models/rho_loss_twin_model/rho_loss_twin_model.py b/modyn/models/rho_loss_twin_model/rho_loss_twin_model.py index 7037d2c29..64485f0b6 100644 --- a/modyn/models/rho_loss_twin_model/rho_loss_twin_model.py +++ b/modyn/models/rho_loss_twin_model/rho_loss_twin_model.py @@ -19,7 +19,6 @@ class RHOLOSSTwinModelModyn(nn.Module): def __init__(self, model_configuration: dict[str, Any], device: str, amp: bool) -> None: super().__init__() - self.device = device model_module = dynamic_module_import("modyn.models") rho_model_class = model_configuration["rho_real_model_class"] rho_model_config = model_configuration["rho_real_model_config"] @@ -59,11 +58,11 @@ def _training_forward(self, sample_ids: list[int], data: torch.Tensor) -> torch. return self._models[self._current_model](data) def _eval_forward(self, sample_ids: list[int], data: torch.Tensor) -> torch.Tensor: - seen_by_model0 = torch.BoolTensor( - [sample_id in self._models_seen_ids[0] for sample_id in sample_ids], device=self.device + seen_by_model0 = torch.tensor( + [sample_id in self._models_seen_ids[0] for sample_id in sample_ids], dtype=torch.bool, device=data.device ) - seen_by_model1 = torch.BoolTensor( - [sample_id in self._models_seen_ids[1] for sample_id in sample_ids], device=self.device + seen_by_model1 = torch.tensor( + [sample_id in self._models_seen_ids[1] for sample_id in sample_ids], dtype=torch.bool, device=data.device ) # if model 0 did not see any sample, we route all samples to model 0 diff --git a/modyn/selector/internal/selector_strategies/downsampling_strategies/rho_loss_downsampling_strategy.py b/modyn/selector/internal/selector_strategies/downsampling_strategies/rho_loss_downsampling_strategy.py index 79771eb6a..7077489cb 100644 --- a/modyn/selector/internal/selector_strategies/downsampling_strategies/rho_loss_downsampling_strategy.py +++ b/modyn/selector/internal/selector_strategies/downsampling_strategies/rho_loss_downsampling_strategy.py @@ -9,10 +9,10 @@ from modyn.metadata_database.utils import ModelStorageStrategyConfig from modyn.selector.internal.selector_strategies import AbstractSelectionStrategy from modyn.selector.internal.selector_strategies.downsampling_strategies import AbstractDownsamplingStrategy -from modyn.selector.internal.selector_strategies.utils import get_trigger_dataset_size from modyn.selector.internal.storage_backend import AbstractStorageBackend from modyn.selector.internal.storage_backend.database import DatabaseStorageBackend from sqlalchemy import Select, func, select +from sqlalchemy.orm.session import Session logger = logging.getLogger(__name__) @@ -34,6 +34,7 @@ def __init__( maximum_keys_in_memory=maximum_keys_in_memory, tail_triggers=0 ) self.holdout_set_ratio = downsampling_config.holdout_set_ratio + self.holdout_set_ratio_max = downsampling_config.holdout_set_ratio_max self.holdout_set_strategy = downsampling_config.holdout_set_strategy self.il_training_config = downsampling_config.il_training_config self.grpc = TrainerServerGRPCHandlerMixin(modyn_config) @@ -46,24 +47,80 @@ def __init__( def inform_next_trigger(self, next_trigger_id: int, selector_storage_backend: AbstractStorageBackend) -> None: if not isinstance(selector_storage_backend, DatabaseStorageBackend): raise ValueError("RHOLossDownsamplingStrategy requires a DatabaseStorageBackend") + + probability = self.holdout_set_ratio / self.holdout_set_ratio_max + + query = self._get_sampling_query(self._pipeline_id, next_trigger_id, probability, selector_storage_backend) + rho_state = self._get_latest_rho_state(self.rho_pipeline_id, self._modyn_config) + if rho_state is not None: + next_rho_trigger_id = rho_state[0] + 1 + last_model_id = rho_state[1] + else: + logger.info(f"No previous state found for pipeline {self.rho_pipeline_id}.") + next_rho_trigger_id = 0 + last_model_id = None + self._persist_holdout_set(query, next_rho_trigger_id, selector_storage_backend) + if self.il_training_config.use_previous_model: + if self.holdout_set_strategy == "Twin": + raise NotImplementedError("Use previous model currently is not supported for Twin strategy") + previous_model_id = last_model_id + else: + previous_model_id = None + + model_id = self._train_il_model(next_rho_trigger_id, previous_model_id) + logger.info( + f"Trained IL model {model_id} for trigger {next_rho_trigger_id} in rho pipeline" + f"{self.rho_pipeline_id} with rho trigger id {next_trigger_id}." + ) if self.holdout_set_strategy == "Twin": - raise NotImplementedError("Twin holdout set strategy is not implemented yet.") - self._prepare_holdout_set(next_trigger_id, selector_storage_backend) - self._train_il_model(next_trigger_id) + second_next_trigger_id = next_rho_trigger_id + 1 + second_query = self._get_rest_data_query(self._pipeline_id, next_trigger_id) + self._persist_holdout_set(second_query, second_next_trigger_id, selector_storage_backend) + self._train_il_model(second_next_trigger_id, model_id) + logger.info( + f"Twin strategy: Trained second IL model for trigger {next_trigger_id} in rho pipeline " + f"{self.rho_pipeline_id} with rho trigger id {next_trigger_id}." + ) + self._clean_tmp_version(self._pipeline_id, next_trigger_id, selector_storage_backend) + + @staticmethod + def _clean_tmp_version( + main_pipeline_id: int, trigger_id: int, selector_storage_backend: AbstractStorageBackend + ) -> None: + assert isinstance(selector_storage_backend, DatabaseStorageBackend) + + def _session_callback(session: Session) -> None: + session.query(SelectorStateMetadata).filter( + SelectorStateMetadata.pipeline_id == main_pipeline_id, + SelectorStateMetadata.seen_in_trigger_id == trigger_id, + ).update({"tmp_version": 0}) + session.commit() + + selector_storage_backend._execute_on_session(_session_callback) + + @staticmethod + def _get_rest_data_query(main_pipeline_id: int, trigger_id: int) -> Select: + stmt = select(SelectorStateMetadata.sample_key).filter( + SelectorStateMetadata.pipeline_id == main_pipeline_id, + SelectorStateMetadata.seen_in_trigger_id == trigger_id, + SelectorStateMetadata.tmp_version == 0, + ) + + return stmt @property def downsampling_params(self) -> dict: config = super().downsampling_params config["rho_pipeline_id"] = self.rho_pipeline_id - il_model_id = self._get_latest_il_model_id(self.rho_pipeline_id, self._modyn_config) - assert il_model_id is not None - config["il_model_id"] = il_model_id + state = self._get_latest_rho_state(self.rho_pipeline_id, self._modyn_config) + assert state is not None + config["il_model_id"] = state[1] return config @staticmethod - def _get_latest_il_model_id(rho_pipeline_id: int, modyn_config: dict) -> Optional[int]: + def _get_latest_rho_state(rho_pipeline_id: int, modyn_config: dict) -> Optional[Tuple[int, int]]: with MetadataDatabaseConnection(modyn_config) as database: - # find the maximal trigger id. This is the current trigger id. + # find the maximal trigger id max_trigger_id = ( database.session.query(func.max(Trigger.trigger_id)) .filter(Trigger.pipeline_id == rho_pipeline_id) @@ -79,13 +136,9 @@ def _get_latest_il_model_id(rho_pipeline_id: int, modyn_config: dict) -> Optiona .scalar() ) assert il_model_id is not None - return il_model_id + return max_trigger_id, il_model_id - def _train_il_model(self, trigger_id: int) -> int: - if self.il_training_config.use_previous_model: - previous_model_id = self._get_latest_il_model_id(self.rho_pipeline_id, self._modyn_config) - else: - previous_model_id = None + def _train_il_model(self, trigger_id: int, previous_model_id: Optional[int]) -> int: training_id = self.grpc.start_training( pipeline_id=self.rho_pipeline_id, trigger_id=trigger_id, @@ -135,16 +188,10 @@ def _create_rho_pipeline_id(self, database: MetadataDatabaseConnection, data_con ) return rho_pipeline_id - def _prepare_holdout_set(self, next_trigger_id: int, selector_storage_backend: AbstractStorageBackend) -> None: - current_trigger_dataset_size = get_trigger_dataset_size( - selector_storage_backend, self._pipeline_id, next_trigger_id, tail_triggers=0 - ) - - holdout_set_size = max(int(current_trigger_dataset_size * self.holdout_set_ratio / 100), 1) - - stmt = self._get_holdout_sampling_query(self._pipeline_id, next_trigger_id, holdout_set_size).execution_options( - yield_per=self.maximum_keys_in_memory - ) + def _persist_holdout_set( + self, query: Select, target_trigger_id: int, selector_storage_backend: AbstractStorageBackend + ) -> None: + stmt = query.execution_options(yield_per=self.maximum_keys_in_memory) def training_set_producer() -> Iterable[tuple[list[tuple[int, float]], dict[str, Any]]]: with MetadataDatabaseConnection(self._modyn_config) as database: @@ -154,24 +201,36 @@ def training_set_producer() -> Iterable[tuple[list[tuple[int, float]], dict[str, total_keys_in_trigger, *_ = AbstractSelectionStrategy.store_training_set( self.rho_pipeline_id, - next_trigger_id, + target_trigger_id, self._modyn_config, training_set_producer, selector_storage_backend.insertion_threads, ) logger.info( - f"Stored {total_keys_in_trigger} keys in the holdout set for trigger {next_trigger_id} " + f"Stored {total_keys_in_trigger} keys in the holdout set for trigger {target_trigger_id} " f"in rho pipeline {self.rho_pipeline_id}" ) @staticmethod - def _get_holdout_sampling_query(main_pipeline_id: int, trigger_id: int, target_size: int) -> Select: - return ( - select(SelectorStateMetadata.sample_key) - .filter( + def _get_sampling_query( + main_pipeline_id: int, trigger_id: int, probability: float, selector_storage_backend: AbstractStorageBackend + ) -> Select: + assert isinstance(selector_storage_backend, DatabaseStorageBackend) + + def _session_callback(session: Session) -> None: + session.query(SelectorStateMetadata).filter( SelectorStateMetadata.pipeline_id == main_pipeline_id, SelectorStateMetadata.seen_in_trigger_id == trigger_id, - ) - .order_by(func.random()) # pylint: disable=E1102 - .limit(target_size) + func.random() < probability, # pylint: disable=not-callable + ).update({"tmp_version": 1}) + session.commit() + + selector_storage_backend._execute_on_session(_session_callback) + + stmt = select(SelectorStateMetadata.sample_key).filter( + SelectorStateMetadata.pipeline_id == main_pipeline_id, + SelectorStateMetadata.seen_in_trigger_id == trigger_id, + SelectorStateMetadata.tmp_version == 1, ) + + return stmt diff --git a/modyn/tests/selector/internal/selector_strategies/downsampling_strategies/test_rho_loss_downsampling_strategy.py b/modyn/tests/selector/internal/selector_strategies/downsampling_strategies/test_rho_loss_downsampling_strategy.py index 3d1241a6b..efe02ba38 100644 --- a/modyn/tests/selector/internal/selector_strategies/downsampling_strategies/test_rho_loss_downsampling_strategy.py +++ b/modyn/tests/selector/internal/selector_strategies/downsampling_strategies/test_rho_loss_downsampling_strategy.py @@ -3,8 +3,8 @@ import pathlib import shutil import tempfile -from typing import List, Literal, Optional, Tuple -from unittest.mock import ANY, MagicMock, patch +from typing import Any, Callable, List, Literal, Optional, Tuple +from unittest.mock import ANY, MagicMock, Mock, call, patch import pytest from modyn.common.grpc.grpc_helpers import TrainerServerGRPCHandlerMixin @@ -26,6 +26,7 @@ ) from modyn.tests.selector.internal.storage_backend.utils import MockStorageBackend from pydantic import TypeAdapter +from sqlalchemy import select database_path = pathlib.Path(os.path.abspath(__file__)).parent / "test_storage.db" @@ -93,7 +94,9 @@ def noop_init_trainer_server(self): return -def store_samples(pipeline_id: int, trigger_id: int, key_ts_label_tuples: List[Tuple[int, int, int]]) -> None: +def store_samples( + pipeline_id: int, trigger_id: int, key_ts_label_tuples: List[Tuple[int, int, int]], tmp_version=0 +) -> None: with MetadataDatabaseConnection(get_minimal_modyn_config()) as database: for key, timestamp, label in key_ts_label_tuples: database.session.add( @@ -103,6 +106,7 @@ def store_samples(pipeline_id: int, trigger_id: int, key_ts_label_tuples: List[T timestamp=timestamp, label=label, seen_in_trigger_id=trigger_id, + tmp_version=tmp_version, ) ) database.session.commit() @@ -126,12 +130,7 @@ def register_pipeline(auxiliary_pipeline_id: Optional[int], data_config: DataCon @patch.object(TrainerServerGRPCHandlerMixin, "init_trainer_server", noop_init_trainer_server) @patch.object(AbstractSelectionStrategy, "store_training_set", return_value=(42, 42, {})) -@patch( - "modyn.selector.internal.selector_strategies.downsampling_strategies.rho_loss_downsampling_strategy" - ".get_trigger_dataset_size" -) -def test__prepare_holdout_set( - mock_get_trigger_dataset_size, +def test__persist_holdout_set( mock_store_training_set, il_training_config: ILTrainingConfig, data_config: DataConfig, @@ -147,70 +146,56 @@ def test__prepare_holdout_set( holdout_set_strategy="Simple", ) maximum_keys_in_memory = 4 - trigger_id2dataset_size = [13, 24, 5] - - trigger_id2range = [(0, 13), (13, 37), (37, 42)] - store_samples( - pipeline_id=pipeline_id, - trigger_id=0, - key_ts_label_tuples=[(i, i, 0) for i in range(*trigger_id2range[0])], - ) - + trigger_id = 3 + dataset_range = (13, 37) store_samples( pipeline_id=pipeline_id, - trigger_id=1, - key_ts_label_tuples=[(i, i, 0) for i in range(*trigger_id2range[1])], - ) - - store_samples( - pipeline_id=pipeline_id, - trigger_id=2, - key_ts_label_tuples=[(i, i, 0) for i in range(*trigger_id2range[2])], + trigger_id=trigger_id, + key_ts_label_tuples=[(i, i, 0) for i in range(*dataset_range)], ) strategy = RHOLossDownsamplingStrategy(downsampling_config, modyn_config, pipeline_id, maximum_keys_in_memory) rho_pipeline_id = strategy.rho_pipeline_id storage_backend = MockStorageBackend(pipeline_id, modyn_config, maximum_keys_in_memory) - def validate_training_set_producer(producer, trigger_id): - chunks = list(producer()) - # verify the partition size - for chunk_id, (chunk, _) in enumerate(chunks): - # only the last chunk can have less than maximum_keys_in_memory number of samples - if chunk_id == len(chunks) - 1: - assert len(chunk) <= maximum_keys_in_memory - else: - assert len(chunk) == maximum_keys_in_memory - # verify the number of partitions - # expected value: ceil(floor(trigger_dataset_size / holdout_set_ratio) / maximum_keys_in_memory) - expected_num_partitions = [2, 3, 1] - assert len(chunks) == expected_num_partitions[trigger_id] - # verify the samples - samples = [sample for (chunk, _) in chunks for sample in chunk] - expected_num_samples = [6, 12, 2] - assert len(samples) == expected_num_samples[trigger_id], "Number of samples is not as expected." - for sample in samples: - # verify key - assert trigger_id2range[trigger_id][0] <= sample[0] < trigger_id2range[trigger_id][1] - # verify weight - assert sample[1] == pytest.approx(1.0) - - for trigger_id in range(3): - mock_get_trigger_dataset_size.reset_mock() - mock_store_training_set.reset_mock() - mock_get_trigger_dataset_size.return_value = trigger_id2dataset_size[trigger_id] - - strategy._prepare_holdout_set(trigger_id, storage_backend) - mock_get_trigger_dataset_size.assert_called_once_with(storage_backend, pipeline_id, trigger_id, tail_triggers=0) - mock_store_training_set.assert_called_once_with( - rho_pipeline_id, - trigger_id, - modyn_config, - ANY, - ANY, + mock_store_training_set.reset_mock() + test_query = ( + select(SelectorStateMetadata.sample_key) + .filter( + SelectorStateMetadata.pipeline_id == pipeline_id, + SelectorStateMetadata.seen_in_trigger_id == trigger_id, ) - training_set_producer = mock_store_training_set.call_args[0][3] - validate_training_set_producer(training_set_producer, trigger_id) + .limit(11) + ) + strategy._persist_holdout_set(test_query, trigger_id, storage_backend) + mock_store_training_set.assert_called_once_with( + rho_pipeline_id, + trigger_id, + modyn_config, + ANY, + ANY, + ) + producer = mock_store_training_set.call_args[0][3] + + chunks = list(producer()) + # verify the partition size + for chunk_id, (chunk, _) in enumerate(chunks): + # only the last chunk can have less than maximum_keys_in_memory number of samples + if chunk_id == len(chunks) - 1: + assert len(chunk) <= maximum_keys_in_memory + else: + assert len(chunk) == maximum_keys_in_memory + # verify the number of partitions + # expected value: ceil(floor(trigger_dataset_size / holdout_set_ratio) / maximum_keys_in_memory) + assert len(chunks) == 3 + # verify the samples + samples = [sample for (chunk, _) in chunks for sample in chunk] + assert len(samples) == 11 + for sample in samples: + # verify key + assert dataset_range[0] <= sample[0] < dataset_range[1] + # verify weight + assert sample[1] == pytest.approx(1.0) @patch.object(TrainerServerGRPCHandlerMixin, "init_trainer_server", noop_init_trainer_server) @@ -326,26 +311,21 @@ def test_downsampling_params(il_training_config: ILTrainingConfig, data_config: assert strategy.downsampling_params == expected -@pytest.mark.parametrize("use_previous_model", [True, False]) @pytest.mark.parametrize("previous_model_id", [None, 21]) @patch.object(TrainerServerGRPCHandlerMixin, "start_training", return_value=42) @patch.object(TrainerServerGRPCHandlerMixin, "wait_for_training_completion") @patch.object(TrainerServerGRPCHandlerMixin, "store_trained_model", return_value=33) @patch.object(TrainerServerGRPCHandlerMixin, "init_trainer_server", noop_init_trainer_server) -@patch.object(RHOLossDownsamplingStrategy, "_get_latest_il_model_id") def test__train_il_model( - mock_get_latest_il_model_id: MagicMock, mock_store_trained_model: MagicMock, mock_wait_for_training_completion: MagicMock, mock_start_training: MagicMock, il_training_config: ILTrainingConfig, data_config: DataConfig, previous_model_id: Optional[int], - use_previous_model: bool, ): pipeline_id = register_pipeline(None, data_config) - il_training_config.use_previous_model = use_previous_model - mock_get_latest_il_model_id.return_value = previous_model_id + il_training_config.use_previous_model = previous_model_id is not None modyn_config = get_minimal_modyn_config() downsampling_config = RHOLossDownsamplingConfig( ratio=60, @@ -355,25 +335,16 @@ def test__train_il_model( ) maximum_keys_in_memory = 4 - if use_previous_model: - expected_previous_model_id = previous_model_id - else: - # no matter what the previous_model_id is, it should not be used - expected_previous_model_id = None strategy = RHOLossDownsamplingStrategy(downsampling_config, modyn_config, pipeline_id, maximum_keys_in_memory) trigger_id = 1 - model_id = strategy._train_il_model(trigger_id) + model_id = strategy._train_il_model(trigger_id, previous_model_id) mock_start_training.assert_called_once_with( pipeline_id=strategy.rho_pipeline_id, trigger_id=trigger_id, training_config=il_training_config, data_config=data_config, - previous_model_id=expected_previous_model_id, + previous_model_id=previous_model_id, ) - if use_previous_model: - mock_get_latest_il_model_id.assert_called_once_with(strategy.rho_pipeline_id, modyn_config) - else: - mock_get_latest_il_model_id.assert_not_called() mock_wait_for_training_completion.assert_called_once_with(mock_start_training.return_value) mock_store_trained_model.assert_called_once_with(mock_start_training.return_value) assert model_id == mock_store_trained_model.return_value @@ -383,38 +354,213 @@ def test__train_il_model( "modyn.selector.internal.selector_strategies.downsampling_strategies.rho_loss_downsampling_strategy.isinstance", return_value=True, ) +@patch.object(RHOLossDownsamplingStrategy, "_get_latest_rho_state") +@patch.object(RHOLossDownsamplingStrategy, "_train_il_model", return_value=42) +@patch.object(RHOLossDownsamplingStrategy, "_get_sampling_query") +@patch.object(RHOLossDownsamplingStrategy, "_persist_holdout_set") +@patch.object(RHOLossDownsamplingStrategy, "_clean_tmp_version") @patch.object(TrainerServerGRPCHandlerMixin, "init_trainer_server", noop_init_trainer_server) -def test_inform_next_trigger( - mock_is_instance: MagicMock, il_training_config: ILTrainingConfig, data_config: DataConfig +@pytest.mark.parametrize("use_previous_model", [True, False]) +@pytest.mark.parametrize("rho_state", [None, (1, 2)]) +def test_inform_next_trigger_simple( + mock__clean_tmp_version: MagicMock, + mock__persist_holdout_set: MagicMock, + mock__get_sampling_query: MagicMock, + mock__train_il_model: MagicMock, + mock_get_latest_rho_state: MagicMock, + mock_is_instance: MagicMock, + rho_state: Optional[Tuple[int, int]], + use_previous_model: bool, + il_training_config: ILTrainingConfig, + data_config: DataConfig, ): pipeline_id = register_pipeline(None, data_config) - + il_training_config.use_previous_model = use_previous_model + mock_get_latest_rho_state.return_value = rho_state modyn_config = get_minimal_modyn_config() downsampling_config = RHOLossDownsamplingConfig( ratio=60, - holdout_set_ratio=50, + holdout_set_ratio=10, il_training_config=il_training_config, holdout_set_strategy="Simple", ) maximum_keys_in_memory = 4 + mock_query = MagicMock() + mock__get_sampling_query.return_value = mock_query strategy = RHOLossDownsamplingStrategy(downsampling_config, modyn_config, pipeline_id, maximum_keys_in_memory) - strategy._prepare_holdout_set = MagicMock() - strategy._train_il_model = MagicMock(return_value=42) + next_trigger_id = 2 + storage_backend = MockStorageBackend(pipeline_id, modyn_config, maximum_keys_in_memory) + strategy.inform_next_trigger(next_trigger_id, storage_backend) - trigger_id = 1 + mock_get_latest_rho_state.assert_called_once_with(strategy.rho_pipeline_id, modyn_config) + if use_previous_model: + expected_previous_model_id = rho_state[1] if rho_state is not None else None + else: + expected_previous_model_id = None + expected_rho_next_trigger_id = rho_state[0] + 1 if rho_state is not None else 0 + mock__persist_holdout_set.assert_called_once_with(mock_query, expected_rho_next_trigger_id, ANY) + mock__train_il_model.assert_called_once_with(expected_rho_next_trigger_id, expected_previous_model_id) + mock__clean_tmp_version.assert_called_once_with(pipeline_id, next_trigger_id, ANY) + mock__get_sampling_query.assert_called_once_with(pipeline_id, next_trigger_id, pytest.approx(0.1), ANY) + + +@patch( + "modyn.selector.internal.selector_strategies.downsampling_strategies.rho_loss_downsampling_strategy.isinstance", + return_value=True, +) +@patch.object(RHOLossDownsamplingStrategy, "_get_latest_rho_state") +@patch.object(RHOLossDownsamplingStrategy, "_train_il_model", return_value=42) +@patch.object(RHOLossDownsamplingStrategy, "_get_sampling_query") +@patch.object(RHOLossDownsamplingStrategy, "_persist_holdout_set") +@patch.object(RHOLossDownsamplingStrategy, "_clean_tmp_version") +@patch.object(RHOLossDownsamplingStrategy, "_get_rest_data_query") +@patch.object(TrainerServerGRPCHandlerMixin, "init_trainer_server", noop_init_trainer_server) +@pytest.mark.parametrize("rho_state", [None, (1, 2)]) +def test_inform_next_trigger_twin( + mock__get_rest_data_query: MagicMock, + mock__clean_tmp_version: MagicMock, + mock__persist_holdout_set: MagicMock, + mock__get_sampling_query: MagicMock, + mock__train_il_model: MagicMock, + mock_get_latest_rho_state: MagicMock, + mock_is_instance: MagicMock, + rho_state: Optional[Tuple[int, int]], + il_training_config: ILTrainingConfig, + data_config: DataConfig, +): + pipeline_id = register_pipeline(None, data_config) + il_training_config.use_previous_model = False + modyn_config = get_minimal_modyn_config() + downsampling_config = RHOLossDownsamplingConfig( + ratio=60, + holdout_set_ratio=50, + il_training_config=il_training_config, + holdout_set_strategy="Twin", + ) + maximum_keys_in_memory = 4 + mock_query = MagicMock() + mock__get_sampling_query.return_value = mock_query + mock_second_query = MagicMock() + mock__get_rest_data_query.return_value = mock_second_query + mock__train_il_model.side_effect = [1, 2] + mock_get_latest_rho_state.return_value = rho_state + + strategy = RHOLossDownsamplingStrategy(downsampling_config, modyn_config, pipeline_id, maximum_keys_in_memory) + next_trigger_id = 2 storage_backend = MockStorageBackend(pipeline_id, modyn_config, maximum_keys_in_memory) - strategy.inform_next_trigger(trigger_id, storage_backend) + strategy.inform_next_trigger(next_trigger_id, storage_backend) - strategy._prepare_holdout_set.assert_called_once_with(trigger_id, storage_backend) - strategy._train_il_model.assert_called_once_with(trigger_id) + mock__get_sampling_query.assert_called_once_with(pipeline_id, next_trigger_id, pytest.approx(0.5), ANY) + mock__get_rest_data_query.assert_called_once_with(pipeline_id, next_trigger_id) + mock_get_latest_rho_state.assert_called_once_with(strategy.rho_pipeline_id, modyn_config) + rho_trigger_ids = [0, 1] if rho_state is None else [rho_state[0] + 1, rho_state[0] + 2] + mock__train_il_model.assert_has_calls([call(rho_trigger_ids[0], None), call(rho_trigger_ids[1], 1)]) + mock__persist_holdout_set.assert_has_calls( + [call(mock_query, rho_trigger_ids[0], ANY), call(mock_second_query, rho_trigger_ids[1], ANY)] + ) + mock__clean_tmp_version.assert_called_once_with(pipeline_id, next_trigger_id, ANY) -def test__get_latest_il_model_id(): + +def test__get_latest_rho_state(): modyn_config = get_minimal_modyn_config() rho_pipeline_id = 1 - assert RHOLossDownsamplingStrategy._get_latest_il_model_id(rho_pipeline_id, modyn_config) is None + assert RHOLossDownsamplingStrategy._get_latest_rho_state(rho_pipeline_id, modyn_config) is None add_trigger_and_model(rho_pipeline_id, 0) - assert RHOLossDownsamplingStrategy._get_latest_il_model_id(rho_pipeline_id, modyn_config) == 1 + assert RHOLossDownsamplingStrategy._get_latest_rho_state(rho_pipeline_id, modyn_config) == (0, 1) add_trigger_and_model(rho_pipeline_id, 1) - assert RHOLossDownsamplingStrategy._get_latest_il_model_id(rho_pipeline_id, modyn_config) == 2 + assert RHOLossDownsamplingStrategy._get_latest_rho_state(rho_pipeline_id, modyn_config) == (1, 2) + + +@patch( + "modyn.selector.internal.selector_strategies.downsampling_strategies.rho_loss_downsampling_strategy.isinstance", + return_value=True, +) +def test__clean_tmp_version(mock_is_instance, data_config: DataConfig): + modyn_config = get_minimal_modyn_config() + + def mock_storage_backend_execute_on_session_patch(session_callback: Callable) -> Any: + with MetadataDatabaseConnection(modyn_config) as database: + return session_callback(database.session) + + pipeline_id = register_pipeline(None, data_config) + trigger_id = 2 + store_samples(pipeline_id, trigger_id, [(i, i, 0) for i in range(10, 20)], tmp_version=1) + + mock_storage_backend = MockStorageBackend(pipeline_id, modyn_config, 4) + mock_storage_backend._execute_on_session = Mock(wraps=mock_storage_backend_execute_on_session_patch) + RHOLossDownsamplingStrategy._clean_tmp_version(pipeline_id, trigger_id, mock_storage_backend) + mock_storage_backend._execute_on_session.assert_called_once_with(ANY) + with MetadataDatabaseConnection(modyn_config) as database: + + assert ( + database.session.query(SelectorStateMetadata) + .filter( + SelectorStateMetadata.tmp_version == 1, + SelectorStateMetadata.pipeline_id == pipeline_id, + SelectorStateMetadata.seen_in_trigger_id == trigger_id, + ) + .count() + ) == 0 + + +def test__get_rest_data_query(data_config: DataConfig): + pipeline_id = register_pipeline(None, data_config) + trigger_id = 2 + key_ts_label_tmp_version_tuples = [ + (1, 1, 0, 0), + (2, 2, 0, 1), + (3, 3, 0, 0), + (4, 4, 0, 1), + (5, 5, 0, 1), + (6, 6, 0, 0), + (7, 7, 0, 1), + (8, 8, 0, 0), + (9, 9, 0, 0), + (10, 10, 0, 1), + ] + modyn_config = get_minimal_modyn_config() + with MetadataDatabaseConnection(modyn_config) as database: + for key, timestamp, label, version in key_ts_label_tmp_version_tuples: + database.session.add( + SelectorStateMetadata( + pipeline_id=pipeline_id, + sample_key=key, + timestamp=timestamp, + label=label, + seen_in_trigger_id=trigger_id, + tmp_version=version, + ) + ) + database.session.commit() + + query = RHOLossDownsamplingStrategy._get_rest_data_query(pipeline_id, trigger_id) + with MetadataDatabaseConnection(modyn_config) as database: + samples = database.session.execute(query).fetchall() + assert sorted(samples) == [(1,), (3,), (6,), (8,), (9,)] + + +@patch( + "modyn.selector.internal.selector_strategies.downsampling_strategies.rho_loss_downsampling_strategy.isinstance", + return_value=True, +) +def test__get_sampling_query(mock_is_instance, data_config: DataConfig): + modyn_config = get_minimal_modyn_config() + + def mock_storage_backend_execute_on_session_patch(session_callback: Callable) -> Any: + with MetadataDatabaseConnection(modyn_config) as database: + return session_callback(database.session) + + pipeline_id = register_pipeline(None, data_config) + trigger_id = 2 + full_dataset_size = 1000 + store_samples(pipeline_id, trigger_id, [(i, i, 0) for i in range(full_dataset_size)]) + storage_backend = MockStorageBackend(pipeline_id, get_minimal_modyn_config(), 4) + storage_backend._execute_on_session = Mock(wraps=mock_storage_backend_execute_on_session_patch) + query = RHOLossDownsamplingStrategy._get_sampling_query(pipeline_id, trigger_id, 0.5, storage_backend) + with MetadataDatabaseConnection(modyn_config) as database: + samples = database.session.execute(query).fetchall() + # with a dataset size as large as 1000, let's hope this range would make it not so flaky + assert 450 <= len(samples) <= 550 + assert storage_backend._execute_on_session.call_count == 1 diff --git a/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/rho_loss_utils/test_irreducible_loss_producer.py b/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/rho_loss_utils/test_irreducible_loss_producer.py index e1c0ae270..f9058a226 100644 --- a/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/rho_loss_utils/test_irreducible_loss_producer.py +++ b/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/rho_loss_utils/test_irreducible_loss_producer.py @@ -139,7 +139,7 @@ def fake_per_sample_loss(forward_output, target): num_classes = 5 - def fake_forward(self, x: Tensor): + def fake_forward(self, x: Tensor, sample_ids=None): return torch.zeros(x.shape[0], num_classes) with patch.object(DummyModyn, "forward", fake_forward): diff --git a/modyn/trainer_server/internal/trainer/remote_downsamplers/rho_loss_utils/irreducible_loss_producer.py b/modyn/trainer_server/internal/trainer/remote_downsamplers/rho_loss_utils/irreducible_loss_producer.py index 22a1e75cc..95186b17d 100644 --- a/modyn/trainer_server/internal/trainer/remote_downsamplers/rho_loss_utils/irreducible_loss_producer.py +++ b/modyn/trainer_server/internal/trainer/remote_downsamplers/rho_loss_utils/irreducible_loss_producer.py @@ -50,7 +50,7 @@ def get_irreducible_loss( with torch.inference_mode(): # move model to the computing device self.model.model.to(self.device) - forward_output = self.model.model(forward_input) + forward_output = self.model.model(forward_input, sample_ids) irreducible_loss = self.per_sample_loss_fct(forward_output, target).detach() cached_loss = irreducible_loss.cpu() self.model.model.to("cpu")