Twin RHO Model Step 1: create the Twin RHO Model (#547)
This PR is the first in a series implementing another way of producing the holdout set, the IL model, and the irreducible loss (typically suitable for small datasets):
> Split the training set into two halves and train two IL models, each on one half. Each model provides the irreducible loss for the samples it was not trained on. The main model is still trained on the original training set (CIFAR10, CIFAR100, CINIC-10).
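
For background (this is the RHO-LOSS idea itself, not something introduced in this PR): the irreducible loss produced by the IL model is subtracted from the main model's current loss to score candidate samples, and the highest-scoring samples are kept for training, roughly:

```math
\mathrm{score}(x, y) = \underbrace{\ell\big(y, f_\theta(x)\big)}_{\text{main model loss}} - \underbrace{\ell\big(y, f_{\mathrm{IL}}(x)\big)}_{\text{irreducible loss}}
```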

Our current architecture only allows one trigger ID to correspond to one model ID. To accommodate two IL models within one trigger, I create a "twin model" which internally consists of two IL models. During training, each IL model memorizes the sample IDs it has seen, so that during evaluation each sample is served by the IL model that has not seen it.
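
As a minimal sketch of the routing idea (purely illustrative; the actual implementation is in `rho_loss_twin_model.py` further down in this diff):

```python
def pick_model(sample_id: int, seen_by_model0: set[int], seen_by_model1: set[int]) -> int:
    """Return the index of the internal IL model that should score this sample."""
    if sample_id in seen_by_model0:
        return 1  # model 0 trained on it, so model 1 provides the irreducible loss
    if sample_id in seen_by_model1:
        return 0  # model 1 trained on it, so model 0 provides the irreducible loss
    return 0  # seen by neither model: fall back to model 0
```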

## How it works

1. At the selector, `RHOLossDownsamplingStrategy` randomly samples half of the training set and marks the `used` column in the `selector_state_metadata` table as `True` for those samples. The strategy issues a request to train a `RHOLOSSTwinModel` on this TSS. (unimplemented)
2. At the trainer server, a `RHOLOSSTwinModel` is instantiated. Only the internal model at index 0 is trained on this data (implemented in this PR).
3. At the selector, `RHOLossDownsamplingStrategy` produces the other half of the training set by selecting the samples with `used==False`. The strategy issues a request to finetune the twin model on this half. (unimplemented)
4. At the trainer server, the `RHOLOSSTwinModel` is instantiated again. Only the internal model at index 1 is trained on this data (implemented in this PR; see the usage sketch after this list).
5. At the selector, (optionally) clear the `used` flags.
6. At the trainer server, when training the main model: nothing needs to change, as the logic is handled internally by the twin model.

Admittedly, this is not the optimal way to train a twin RHO model, but it is a very straightforward one, and we can optimize it depending on how well it performs.

## Current drawbacks
Because it relies on the `used` field, `RHOLoss` is currently incompatible with presampling strategies that also use that field, such as `FreshnessSamplingStrategy`.

## Next PR
Implement steps 1 and 3: preparing the split holdout set.

## How to review
All the main logic is in
[modyn/models/rho_loss_twin_model/rho_loss_twin_model.py](https://github.com/eth-easl/modyn/pull/547/files#diff-0f510b51f60a2c4ee867551fed01763b09ca13d3b88b10aac5aca55d83377fdf)
XianzheMa authored Jul 1, 2024
1 parent 59ea026 commit 60d3baa
Showing 18 changed files with 320 additions and 27 deletions.
1 change: 1 addition & 0 deletions integrationtests/config/rho_loss.yaml
@@ -51,6 +51,7 @@ selection_strategy:
tail_triggers: 0
downsampling_config:
strategy: RHOLoss
holdout_set_strategy: Simple
sample_then_batch: False
holdout_set_ratio: 30
ratio: 60
13 changes: 13 additions & 0 deletions modyn/config/schema/pipeline/sampling/downsampling_config.py
@@ -153,13 +153,26 @@ class RHOLossDownsamplingConfig(BaseDownsamplingConfig):
"""Config for the RHO Loss downsampling strategy."""

strategy: Literal["RHOLoss"] = "RHOLoss"
holdout_set_strategy: Literal["Simple", "Twin"] = Field(
description="Simple: holdout set is a subset randomly sampled from the training set based on the"
"holdout_set_ratio. The holdout set is used to train the il model and the original training set is"
"used to train the main model. Twin: training set is split into two halves. Each half is used to "
"train a separate il model. Each il model provides the irreducible loss for the samples that the"
"model is not trained on. The original training set is used to train the main model."
)
holdout_set_ratio: int = Field(
description="How much of the training set is used as the holdout set.",
min=0,
max=100,
)
il_training_config: ILTrainingConfig = Field(description="The configuration for the IL training.")

@model_validator(mode="after")
def validate_holdout_set_ratio(self) -> Self:
if self.holdout_set_strategy == "Twin" and self.holdout_set_ratio != 100:
raise ValueError("holdout_set_ratio should be 100 for the Twin strategy.")
return self


class RS2DownsamplingConfig(BaseDownsamplingConfig):
"""Config for the RS2 downsampling strategy."""
1 change: 1 addition & 0 deletions modyn/models/__init__.py
@@ -11,6 +11,7 @@
from .resnet18.resnet18 import ResNet18 # noqa: F401
from .resnet50.resnet50 import ResNet50 # noqa: F401
from .resnet152.resnet152 import ResNet152 # noqa: F401
from .rho_loss_twin_model.rho_loss_twin_model import RHOLOSSTwinModel # noqa: F401
from .yearbooknet.yearbooknet import YearbookNet # noqa: F401

files = os.listdir(os.path.dirname(__file__))
5 changes: 3 additions & 2 deletions modyn/models/articlenet/articlenet.py
@@ -1,5 +1,5 @@
# pylint: disable=W0223
from typing import Any
from typing import Any, Optional

import torch
from modyn.models.coreset_methods_support import CoresetSupportingModule
@@ -51,7 +51,8 @@ def __init__(self, num_classes: int) -> None:
self.featurizer = DistilBertFeaturizer.from_pretrained("distilbert-base-uncased")
self.classifier = nn.Linear(self.featurizer.d_out, num_classes)

def forward(self, data: torch.Tensor) -> torch.Tensor:
def forward(self, data: torch.Tensor, sample_ids: Optional[list[int]] = None) -> torch.Tensor:
del sample_ids
embedding = self.featurizer(data)
embedding = self.embedding_recorder(embedding)
return self.classifier(embedding)
3 changes: 2 additions & 1 deletion modyn/models/dlrm/dlrm.py
@@ -111,13 +111,14 @@ def reorder_categorical_input(self, cat_input: torch.Tensor) -> torch.Tensor:
order = self._embedding_ordering.expand(dim0, -1)
return torch.gather(cat_input, 1, order)

def forward(self, data: torch.Tensor) -> torch.Tensor:
def forward(self, data: torch.Tensor, sample_ids: Optional[list[int]] = None) -> torch.Tensor:
"""
Args:
data: a dict containing:
numerical_input (Tensor): with shape [batch_size, num_numerical_features]
categorical_inputs (Tensor): with shape [batch_size, num_categorical_features]
"""
del sample_ids
numerical_input = data["numerical_input"]
categorical_input = data["categorical_input"]

5 changes: 3 additions & 2 deletions modyn/models/dummy/dummy.py
@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, Optional

from torch import Tensor, nn

@@ -16,5 +16,6 @@ def __init__(self, model_configuration: dict[str, Any]) -> None:
super().__init__()
self.output = nn.Linear(2, 2)

def forward(self, x: Tensor) -> Tensor:
def forward(self, x: Tensor, sample_ids: Optional[list[int]] = None) -> Tensor:
del sample_ids
return self.output(x)
5 changes: 3 additions & 2 deletions modyn/models/fmownet/fmownet.py
@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, Optional

import torch
import torch.nn.functional as func
@@ -27,7 +27,8 @@ def __init__(self, num_classes: int) -> None:
self.enc = densenet121(pretrained=True).features
self.classifier = nn.Linear(1024, self.num_classes)

def forward(self, data: torch.Tensor) -> torch.Tensor:
def forward(self, data: torch.Tensor, sample_ids: Optional[list[int]] = None) -> torch.Tensor:
del sample_ids
features = self.enc(data)
out = func.relu(features, inplace=True)
out = func.adaptive_avg_pool2d(out, (1, 1))
6 changes: 5 additions & 1 deletion modyn/models/resnet152/resnet152.py
@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, Optional

import torch
from modyn.models.coreset_methods_support import CoresetSupportingModule
@@ -56,5 +56,9 @@ def _forward_impl(self, x: Tensor) -> Tensor:

return x

def forward(self, x: torch.Tensor, sample_ids: Optional[list[int]] = None) -> torch.Tensor:
del sample_ids
return super().forward(x)

def get_last_layer(self) -> nn.Module:
return self.fc
6 changes: 5 additions & 1 deletion modyn/models/resnet18/resnet18.py
@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, Optional

import torch
from modyn.models.coreset_methods_support import CoresetSupportingModule
@@ -56,5 +56,9 @@ def _forward_impl(self, x: Tensor) -> Tensor:

return x

def forward(self, x: torch.Tensor, sample_ids: Optional[list[int]] = None) -> torch.Tensor:
del sample_ids
return super().forward(x)

def get_last_layer(self) -> nn.Module:
return self.fc
6 changes: 5 additions & 1 deletion modyn/models/resnet50/resnet50.py
@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, Optional

import torch
from modyn.models.coreset_methods_support import CoresetSupportingModule
@@ -56,5 +56,9 @@ def _forward_impl(self, x: Tensor) -> Tensor:

return x

def forward(self, x: torch.Tensor, sample_ids: Optional[list[int]] = None) -> torch.Tensor:
del sample_ids
return super().forward(x)

def get_last_layer(self) -> nn.Module:
return self.fc
Empty file.
78 changes: 78 additions & 0 deletions modyn/models/rho_loss_twin_model/rho_loss_twin_model.py
@@ -0,0 +1,78 @@
import logging
from typing import Any, Optional

import torch
from modyn.utils import dynamic_module_import
from torch import nn

logger = logging.getLogger(__name__)


class RHOLOSSTwinModel:

def __init__(self, model_configuration: dict[str, Any], device: str, amp: bool) -> None:
self.model = RHOLOSSTwinModelModyn(model_configuration, device, amp)
self.model.to(device)


class RHOLOSSTwinModelModyn(nn.Module):

def __init__(self, model_configuration: dict[str, Any], device: str, amp: bool) -> None:
super().__init__()
self.device = device
model_module = dynamic_module_import("modyn.models")
rho_model_class = model_configuration["rho_real_model_class"]
rho_model_config = model_configuration["rho_real_model_config"]
model_handler = getattr(model_module, rho_model_class)
# we only need the inner model, not the wrapper
self._models = nn.ModuleList(
[
model_handler(rho_model_config, device, amp).model,
model_handler(rho_model_config, device, amp).model,
]
)
self._models_seen_ids: list[set[int]] = [set(), set()]
self._current_model = 0

def get_extra_state(self) -> dict:
return {
"models_seen_ids": self._models_seen_ids,
}

def set_extra_state(self, state: dict) -> None:
self._models_seen_ids = state["models_seen_ids"]
# extra state only exists when a previously trained twin model is loaded for finetuning;
# in that second training round we switch to the other internal model
self._current_model = 1

def forward(self, data: torch.Tensor, sample_ids: Optional[list[int]] = None) -> torch.Tensor:
assert sample_ids is not None
# self.training is an internal attribute defined in nn.Module that is updated
# whenever .eval() or .train() is called
if self.training:
output_tensor = self._training_forward(sample_ids, data)
else:
output_tensor = self._eval_forward(sample_ids, data)
return output_tensor

def _training_forward(self, sample_ids: list[int], data: torch.Tensor) -> torch.Tensor:
self._models_seen_ids[self._current_model].update(sample_ids)
return self._models[self._current_model](data)

def _eval_forward(self, sample_ids: list[int], data: torch.Tensor) -> torch.Tensor:
seen_by_model0 = torch.tensor(
[sample_id in self._models_seen_ids[0] for sample_id in sample_ids], dtype=torch.bool, device=self.device
)
seen_by_model1 = torch.tensor(
[sample_id in self._models_seen_ids[1] for sample_id in sample_ids], dtype=torch.bool, device=self.device
)

# if model 0 did not see any sample, we route all samples to model 0
if not seen_by_model0.any():
return self._models[0](data)
# if model 1 did not see any sample, we route all samples to model 1
if not seen_by_model1.any():
return self._models[1](data)

# when a sample is seen by neither model, we route it to model 0
# unsqueeze to make seen_by_model0 broadcastable against the model outputs
return torch.where(seen_by_model0.unsqueeze(1), self._models[1](data), self._models[0](data))
5 changes: 3 additions & 2 deletions modyn/models/yearbooknet/yearbooknet.py
@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, Optional

import torch
from modyn.models.coreset_methods_support import CoresetSupportingModule
@@ -35,7 +35,8 @@ def conv_block(self, in_channels: int, out_channels: int) -> nn.Module:
nn.Conv2d(in_channels, out_channels, 3, padding=1), nn.BatchNorm2d(out_channels), nn.ReLU(), nn.MaxPool2d(2)
)

def forward(self, data: torch.Tensor) -> torch.Tensor:
def forward(self, data: torch.Tensor, sample_ids: Optional[list[int]] = None) -> torch.Tensor:
del sample_ids
data = self.enc(data)
data = torch.mean(data, dim=(2, 3))
data = self.embedding_recorder(data)
@@ -18,7 +18,6 @@


class RHOLossDownsamplingStrategy(AbstractDownsamplingStrategy):

IL_MODEL_STORAGE_STRATEGY = ModelStorageStrategyConfig(name="PyTorchFullModel")

def __init__(
@@ -35,6 +34,7 @@ def __init__(
maximum_keys_in_memory=maximum_keys_in_memory, tail_triggers=0
)
self.holdout_set_ratio = downsampling_config.holdout_set_ratio
self.holdout_set_strategy = downsampling_config.holdout_set_strategy
self.il_training_config = downsampling_config.il_training_config
self.grpc = TrainerServerGRPCHandlerMixin(modyn_config)
self.grpc.init_trainer_server()
@@ -46,7 +46,8 @@ def __init__(
def inform_next_trigger(self, next_trigger_id: int, selector_storage_backend: AbstractStorageBackend) -> None:
if not isinstance(selector_storage_backend, DatabaseStorageBackend):
raise ValueError("RHOLossDownsamplingStrategy requires a DatabaseStorageBackend")

if self.holdout_set_strategy == "Twin":
raise NotImplementedError("Twin holdout set strategy is not implemented yet.")
self._prepare_holdout_set(next_trigger_id, selector_storage_backend)
self._train_il_model(next_trigger_id)

@@ -113,13 +114,20 @@ def _get_or_create_rho_pipeline_id_and_get_data_config(self) -> Tuple[int, DataConfig]:
return rho_pipeline_id, DataConfig.model_validate_json(data_config_str)

def _create_rho_pipeline_id(self, database: MetadataDatabaseConnection, data_config_str: str) -> int:
# Strictly speaking we don't need to store the configs in the database, since we only need the rho pipeline to exist;
# we fetch the configs directly from the object fields.
# For consistency, however, it does no harm to store the correct configs instead of dummy values in the database.
if self.holdout_set_strategy == "Twin":
model_class_name = "RHOLOSSTwinModel"
model_config = {
"rho_real_model_class": self.il_training_config.il_model_id,
"rho_real_model_config": self.il_training_config.il_model_config,
}
else:
model_class_name = self.il_training_config.il_model_id
model_config = self.il_training_config.il_model_config

rho_pipeline_id = database.register_pipeline(
num_workers=self.il_training_config.dataloader_workers,
model_class_name=self.il_training_config.il_model_id,
model_config=json.dumps(self.il_training_config.il_model_config),
model_class_name=model_class_name,
model_config=json.dumps(model_config),
amp=self.il_training_config.amp,
selection_strategy=self.il_model_dummy_selection_strategy.model_dump_json(by_alias=True),
data_config=data_config_str,