Twin RHO Model Step 2: split the training set and train the twin model (#552)

Per the title.

The major change is in file
[modyn/selector/internal/selector_strategies/downsampling_strategies/rho_loss_downsampling_strategy.py](https://github.com/eth-easl/modyn/pull/552/files#diff-1628acbf60458002fa86b1b4ad8b71fb21deedc12390e21bc4620d6197cdab04)
XianzheMa authored Jul 1, 2024
1 parent 60d3baa commit cdf1280
Showing 8 changed files with 354 additions and 141 deletions.
5 changes: 0 additions & 5 deletions integrationtests/supervisor/integrationtest_supervisor.py
@@ -139,11 +139,6 @@ def test_rho_loss_pipeline_with_two_triggers() -> None:
model_ids = [model.model_id for model in models]
assert latest_il_model_id == max(model_ids)

- # validate the holdout set is in correct proportion
- # in rho_loss.yaml, each trigger contains 10 images. The holdout set ratio is 30%
- # so the holdout set should contain 3 images
- assert all(trigger.num_keys == 3 for trigger in triggers)


if __name__ == "__main__":
tiny_dataset_helper = TinyDatasetHelper(dataset_size=20)
8 changes: 8 additions & 0 deletions modyn/config/schema/pipeline/sampling/downsampling_config.py
@@ -165,12 +165,20 @@ class RHOLossDownsamplingConfig(BaseDownsamplingConfig):
min=0,
max=100,
)
holdout_set_ratio_max: int = Field(
description="Reference maximum holdout_set_ratio value. Defaults to 100, which implies percent."
" If you set this to 1000, holdout_set_ratio describes promille instead.",
default=100,
min=1,
)
il_training_config: ILTrainingConfig = Field(description="The configuration for the IL training.")

@model_validator(mode="after")
def validate_holdout_set_ratio(self) -> Self:
if self.holdout_set_strategy == "Twin" and self.holdout_set_ratio != 50:
    raise ValueError("holdout_set_ratio should be 50 for the Twin strategy.")
if self.holdout_set_ratio > self.holdout_set_ratio_max:
raise ValueError("holdout_set_ratio cannot be greater than holdout_set_ratio_max.")
return self


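To make the new field's semantics concrete, here is a small illustration (hypothetical values, not taken from this diff): with `holdout_set_ratio_max` left at its default of 100, `holdout_set_ratio` is a percentage; raising the maximum to 1000 turns the same ratio into promille.

```python
# Hypothetical configs illustrating holdout_set_ratio_max; values are examples only.
percent_cfg = {"holdout_set_ratio": 30, "holdout_set_ratio_max": 100}     # 30 percent
promille_cfg = {"holdout_set_ratio": 300, "holdout_set_ratio_max": 1000}  # 300 promille

for cfg in (percent_cfg, promille_cfg):
    # This mirrors how the strategy derives its sampling probability below.
    probability = cfg["holdout_set_ratio"] / cfg["holdout_set_ratio_max"]
    assert probability == 0.3  # both configs express the same 30% holdout share
```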
6 changes: 6 additions & 0 deletions modyn/metadata_database/models/selector_state_metadata.py
@@ -22,6 +22,12 @@ class SelectorStateMetadataMixin:
sample_key = Column("sample_key", BIGINT, primary_key=True)
seen_in_trigger_id = Column("seen_in_trigger_id", Integer, primary_key=True)
used = Column("used", Boolean, default=False)
# This field helps selection strategies track the state of samples while preparing the
# post-presampling training set. It should be reset to 0 once the training set is prepared.
# For example, the RHOLossDownsamplingStrategy uses it to split the training set into two halves:
# it first samples 50% of the rows and sets their `tmp_version` field to 1,
# then selects the remaining 50% by querying samples whose `tmp_version` is still 0.
tmp_version = Column("tmp_version", Integer, default=0)
timestamp = Column("timestamp", BigInteger)
label = Column("label", Integer)
last_used_in_trigger = Column("last_used_in_trigger", Integer, default=-1)
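A minimal sketch of the two-phase split described in the comment, assuming a SQLAlchemy `session` plus a `pipeline_id` and `trigger_id` in scope (illustrative only; the real queries live in rho_loss_downsampling_strategy.py further down in this diff):

```python
from sqlalchemy import func, select

from modyn.metadata_database.models import SelectorStateMetadata  # import path assumed

# Phase 1: probabilistically mark roughly half of the trigger's samples.
session.query(SelectorStateMetadata).filter(
    SelectorStateMetadata.pipeline_id == pipeline_id,
    SelectorStateMetadata.seen_in_trigger_id == trigger_id,
    func.random() < 0.5,
).update({"tmp_version": 1})
session.commit()

# Phase 2: the complementary half is everything still at tmp_version == 0.
rest_stmt = select(SelectorStateMetadata.sample_key).filter(
    SelectorStateMetadata.pipeline_id == pipeline_id,
    SelectorStateMetadata.seen_in_trigger_id == trigger_id,
    SelectorStateMetadata.tmp_version == 0,
)
# After both halves are persisted, tmp_version is reset to 0 for reuse.
```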
9 changes: 4 additions & 5 deletions modyn/models/rho_loss_twin_model/rho_loss_twin_model.py
@@ -19,7 +19,6 @@ class RHOLOSSTwinModelModyn(nn.Module):

def __init__(self, model_configuration: dict[str, Any], device: str, amp: bool) -> None:
super().__init__()
- self.device = device
model_module = dynamic_module_import("modyn.models")
rho_model_class = model_configuration["rho_real_model_class"]
rho_model_config = model_configuration["rho_real_model_config"]
@@ -59,11 +58,11 @@ def _training_forward(self, sample_ids: list[int], data: torch.Tensor) -> torch.Tensor:
return self._models[self._current_model](data)

def _eval_forward(self, sample_ids: list[int], data: torch.Tensor) -> torch.Tensor:
- seen_by_model0 = torch.BoolTensor(
-     [sample_id in self._models_seen_ids[0] for sample_id in sample_ids], device=self.device
seen_by_model0 = torch.tensor(
    [sample_id in self._models_seen_ids[0] for sample_id in sample_ids], dtype=torch.bool, device=data.device
)
- seen_by_model1 = torch.BoolTensor(
-     [sample_id in self._models_seen_ids[1] for sample_id in sample_ids], device=self.device
seen_by_model1 = torch.tensor(
    [sample_id in self._models_seen_ids[1] for sample_id in sample_ids], dtype=torch.bool, device=data.device
)

# if model 0 did not see any sample, we route all samples to model 0
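The hunk is truncated before the routing itself; below is a hypothetical sketch of what the comment implies, not the repository's code (the function name, standalone signature, and 2-D output shape are assumptions):

```python
import torch

def route_eval(data: torch.Tensor, seen_by_model0: torch.Tensor,
               model0: torch.nn.Module, model1: torch.nn.Module) -> torch.Tensor:
    # Score each sample with the twin that did NOT train on it; per the comment
    # above, fall back to routing everything to model 0 if it saw no samples.
    use_model1 = seen_by_model0 if seen_by_model0.any() else torch.zeros_like(seen_by_model0)
    out0 = model0(data[~use_model1])
    output = torch.empty(data.shape[0], out0.shape[1], device=data.device, dtype=out0.dtype)
    output[~use_model1] = out0
    output[use_model1] = model1(data[use_model1])
    return output
```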
modyn/selector/internal/selector_strategies/downsampling_strategies/rho_loss_downsampling_strategy.py
@@ -9,10 +9,10 @@
from modyn.metadata_database.utils import ModelStorageStrategyConfig
from modyn.selector.internal.selector_strategies import AbstractSelectionStrategy
from modyn.selector.internal.selector_strategies.downsampling_strategies import AbstractDownsamplingStrategy
- from modyn.selector.internal.selector_strategies.utils import get_trigger_dataset_size
from modyn.selector.internal.storage_backend import AbstractStorageBackend
from modyn.selector.internal.storage_backend.database import DatabaseStorageBackend
from sqlalchemy import Select, func, select
from sqlalchemy.orm.session import Session

logger = logging.getLogger(__name__)

@@ -34,6 +34,7 @@ def __init__(
maximum_keys_in_memory=maximum_keys_in_memory, tail_triggers=0
)
self.holdout_set_ratio = downsampling_config.holdout_set_ratio
self.holdout_set_ratio_max = downsampling_config.holdout_set_ratio_max
self.holdout_set_strategy = downsampling_config.holdout_set_strategy
self.il_training_config = downsampling_config.il_training_config
self.grpc = TrainerServerGRPCHandlerMixin(modyn_config)
@@ -46,24 +47,80 @@ def __init__(
def inform_next_trigger(self, next_trigger_id: int, selector_storage_backend: AbstractStorageBackend) -> None:
if not isinstance(selector_storage_backend, DatabaseStorageBackend):
raise ValueError("RHOLossDownsamplingStrategy requires a DatabaseStorageBackend")

probability = self.holdout_set_ratio / self.holdout_set_ratio_max

query = self._get_sampling_query(self._pipeline_id, next_trigger_id, probability, selector_storage_backend)
rho_state = self._get_latest_rho_state(self.rho_pipeline_id, self._modyn_config)
if rho_state is not None:
next_rho_trigger_id = rho_state[0] + 1
last_model_id = rho_state[1]
else:
logger.info(f"No previous state found for pipeline {self.rho_pipeline_id}.")
next_rho_trigger_id = 0
last_model_id = None
self._persist_holdout_set(query, next_rho_trigger_id, selector_storage_backend)
if self.il_training_config.use_previous_model:
if self.holdout_set_strategy == "Twin":
raise NotImplementedError("Use previous model currently is not supported for Twin strategy")
previous_model_id = last_model_id
else:
previous_model_id = None

model_id = self._train_il_model(next_rho_trigger_id, previous_model_id)
logger.info(
    f"Trained IL model {model_id} for main trigger {next_trigger_id} in rho pipeline "
    f"{self.rho_pipeline_id} with rho trigger id {next_rho_trigger_id}."
)
if self.holdout_set_strategy == "Twin":
raise NotImplementedError("Twin holdout set strategy is not implemented yet.")
self._prepare_holdout_set(next_trigger_id, selector_storage_backend)
self._train_il_model(next_trigger_id)
second_next_trigger_id = next_rho_trigger_id + 1
second_query = self._get_rest_data_query(self._pipeline_id, next_trigger_id)
self._persist_holdout_set(second_query, second_next_trigger_id, selector_storage_backend)
self._train_il_model(second_next_trigger_id, model_id)
logger.info(
    f"Twin strategy: Trained second IL model for main trigger {next_trigger_id} in rho pipeline "
    f"{self.rho_pipeline_id} with rho trigger id {second_next_trigger_id}."
)
self._clean_tmp_version(self._pipeline_id, next_trigger_id, selector_storage_backend)

@staticmethod
def _clean_tmp_version(
main_pipeline_id: int, trigger_id: int, selector_storage_backend: AbstractStorageBackend
) -> None:
assert isinstance(selector_storage_backend, DatabaseStorageBackend)

def _session_callback(session: Session) -> None:
session.query(SelectorStateMetadata).filter(
SelectorStateMetadata.pipeline_id == main_pipeline_id,
SelectorStateMetadata.seen_in_trigger_id == trigger_id,
).update({"tmp_version": 0})
session.commit()

selector_storage_backend._execute_on_session(_session_callback)

@staticmethod
def _get_rest_data_query(main_pipeline_id: int, trigger_id: int) -> Select:
stmt = select(SelectorStateMetadata.sample_key).filter(
SelectorStateMetadata.pipeline_id == main_pipeline_id,
SelectorStateMetadata.seen_in_trigger_id == trigger_id,
SelectorStateMetadata.tmp_version == 0,
)

return stmt
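For orientation, a self-contained simulation of the Twin split performed by `_get_sampling_query` and `_get_rest_data_query` above (plain Python standing in for the SQL; the sample keys are made up):

```python
import random

random.seed(42)
keys = list(range(10))   # e.g. the 10 images per trigger in the integration test
probability = 50 / 100   # the validator pins holdout_set_ratio to 50 for Twin

# Step 1: the probabilistic first half (what tmp_version = 1 marks in the DB).
first_half = [k for k in keys if random.random() < probability]
# Step 2: the rest query returns the complement (tmp_version still 0).
second_half = [k for k in keys if k not in first_half]

assert sorted(first_half + second_half) == keys  # each sample used exactly once
# Note: with func.random() < probability the halves are equal only in
# expectation; per-trigger sizes vary around 50/50.
```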

@property
def downsampling_params(self) -> dict:
config = super().downsampling_params
config["rho_pipeline_id"] = self.rho_pipeline_id
- il_model_id = self._get_latest_il_model_id(self.rho_pipeline_id, self._modyn_config)
- assert il_model_id is not None
- config["il_model_id"] = il_model_id
state = self._get_latest_rho_state(self.rho_pipeline_id, self._modyn_config)
assert state is not None
config["il_model_id"] = state[1]
return config

@staticmethod
- def _get_latest_il_model_id(rho_pipeline_id: int, modyn_config: dict) -> Optional[int]:
def _get_latest_rho_state(rho_pipeline_id: int, modyn_config: dict) -> Optional[Tuple[int, int]]:
with MetadataDatabaseConnection(modyn_config) as database:
- # find the maximal trigger id. This is the current trigger id.
# find the maximal trigger id
max_trigger_id = (
database.session.query(func.max(Trigger.trigger_id))
.filter(Trigger.pipeline_id == rho_pipeline_id)
@@ -79,13 +136,9 @@ def _get_latest_il_model_id(rho_pipeline_id: int, modyn_config: dict) -> Optional[int]:
.scalar()
)
assert il_model_id is not None
- return il_model_id
return max_trigger_id, il_model_id

- def _train_il_model(self, trigger_id: int) -> int:
-     if self.il_training_config.use_previous_model:
-         previous_model_id = self._get_latest_il_model_id(self.rho_pipeline_id, self._modyn_config)
-     else:
-         previous_model_id = None
def _train_il_model(self, trigger_id: int, previous_model_id: Optional[int]) -> int:
training_id = self.grpc.start_training(
pipeline_id=self.rho_pipeline_id,
trigger_id=trigger_id,
@@ -135,16 +188,10 @@ def _create_rho_pipeline_id(self, database: MetadataDatabaseConnection, data_con
)
return rho_pipeline_id

- def _prepare_holdout_set(self, next_trigger_id: int, selector_storage_backend: AbstractStorageBackend) -> None:
-     current_trigger_dataset_size = get_trigger_dataset_size(
-         selector_storage_backend, self._pipeline_id, next_trigger_id, tail_triggers=0
-     )
-
-     holdout_set_size = max(int(current_trigger_dataset_size * self.holdout_set_ratio / 100), 1)
-
-     stmt = self._get_holdout_sampling_query(self._pipeline_id, next_trigger_id, holdout_set_size).execution_options(
-         yield_per=self.maximum_keys_in_memory
-     )
def _persist_holdout_set(
self, query: Select, target_trigger_id: int, selector_storage_backend: AbstractStorageBackend
) -> None:
stmt = query.execution_options(yield_per=self.maximum_keys_in_memory)

def training_set_producer() -> Iterable[tuple[list[tuple[int, float]], dict[str, Any]]]:
with MetadataDatabaseConnection(self._modyn_config) as database:
@@ -154,24 +201,36 @@ def training_set_producer() -> Iterable[tuple[list[tuple[int, float]], dict[str, Any]]]:

total_keys_in_trigger, *_ = AbstractSelectionStrategy.store_training_set(
self.rho_pipeline_id,
- next_trigger_id,
target_trigger_id,
self._modyn_config,
training_set_producer,
selector_storage_backend.insertion_threads,
)
logger.info(
f"Stored {total_keys_in_trigger} keys in the holdout set for trigger {next_trigger_id} "
f"Stored {total_keys_in_trigger} keys in the holdout set for trigger {target_trigger_id} "
f"in rho pipeline {self.rho_pipeline_id}"
)

@staticmethod
- def _get_holdout_sampling_query(main_pipeline_id: int, trigger_id: int, target_size: int) -> Select:
-     return (
-         select(SelectorStateMetadata.sample_key)
-         .filter(
-             SelectorStateMetadata.pipeline_id == main_pipeline_id,
-             SelectorStateMetadata.seen_in_trigger_id == trigger_id,
-         )
-         .order_by(func.random())  # pylint: disable=E1102
-         .limit(target_size)
-     )
def _get_sampling_query(
    main_pipeline_id: int, trigger_id: int, probability: float, selector_storage_backend: AbstractStorageBackend
) -> Select:
    assert isinstance(selector_storage_backend, DatabaseStorageBackend)

    def _session_callback(session: Session) -> None:
        session.query(SelectorStateMetadata).filter(
            SelectorStateMetadata.pipeline_id == main_pipeline_id,
            SelectorStateMetadata.seen_in_trigger_id == trigger_id,
            func.random() < probability,  # pylint: disable=not-callable
        ).update({"tmp_version": 1})
        session.commit()

    selector_storage_backend._execute_on_session(_session_callback)

    stmt = select(SelectorStateMetadata.sample_key).filter(
        SelectorStateMetadata.pipeline_id == main_pipeline_id,
        SelectorStateMetadata.seen_in_trigger_id == trigger_id,
        SelectorStateMetadata.tmp_version == 1,
    )

    return stmt
