# Twin RHO Model Step 1: create the Twin RHO Model (#547)
This PR is the first step toward implementing another way of producing the holdout set, the IL (irreducible loss) model, and the irreducible losses, typically suitable for small datasets:

> Split the training set into two halves and train two IL models, each on one half. Each model provides the irreducible loss for the samples it was not trained on. The main model is still trained on the original training set (CIFAR-10, CIFAR-100, CINIC-10).

Our current architecture only allows one trigger ID to correspond to one model ID. To accommodate two IL models within one trigger, this PR introduces a "twin model" that internally consists of two IL models. During training, each IL model memorizes the sample IDs it has seen, so that during evaluation each sample is served by the IL model that has not seen it.

## How it works

1. At the selector, `RHOLossDownsamplingStrategy` randomly samples half of the training set and marks the `used` column of those samples in the `selector_state_metadata` table as `True`. The strategy issues a request to train a `RHOLOSSTwinModel` on this half. (unimplemented)
2. At the trainer server, `RHOLOSSTwinModel` is instantiated. Only model 0 is trained on this dataset. (implemented in this PR)
3. At the selector, `RHOLossDownsamplingStrategy` produces the other half of the training set by selecting the samples with `used == False`. The strategy issues a request to finetune the twin model. (unimplemented)
4. At the trainer server, `RHOLOSSTwinModel` is instantiated again. Only model 1 is trained on this dataset. (implemented in this PR)
5. At the selector, (optionally) clear the `used` flags.
6. At the trainer server, when training the main model, nothing needs to change: the routing logic is handled internally by the twin model.

This is admittedly not the optimal way to train a twin RHO model, but it is very straightforward, and we can optimize it depending on how well it performs. A configuration sketch follows this description.

## Current drawbacks

Because it relies on the `used` field, `RHOLoss` is currently incompatible with presampling strategies that also use that field, such as `FreshnessSamplingStrategy`.

## Next PR

Implement steps 1 and 3: preparing the split holdout set.

## How to review

All the main logic is in [modyn/models/rho_loss_twin_model/rho_loss_twin_model.py](https://github.com/eth-easl/modyn/pull/547/files#diff-0f510b51f60a2c4ee867551fed01763b09ca13d3b88b10aac5aca55d83377fdf).
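For concreteness, here is a minimal instantiation sketch. The `rho_real_model_class` and `rho_real_model_config` keys are the ones `RHOLOSSTwinModel` actually reads; `ResNet18` and `{"num_classes": 10}` are illustrative assumptions, not values fixed by this PR:

```python
# Hedged sketch: instantiating the twin model directly.
# "ResNet18" and its config are placeholder assumptions; any model class
# exposed by modyn.models (each wrapping its network in a .model attribute)
# should work, since the twin model resolves the class via getattr.
from modyn.models.rho_loss_twin_model.rho_loss_twin_model import RHOLOSSTwinModel

model_configuration = {
    "rho_real_model_class": "ResNet18",            # resolved via getattr on modyn.models
    "rho_real_model_config": {"num_classes": 10},  # forwarded to that class's constructor
}
twin = RHOLOSSTwinModel(model_configuration, device="cpu", amp=False)
```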
Showing 18 changed files with 320 additions and 27 deletions.
`modyn/models/rho_loss_twin_model/rho_loss_twin_model.py` (new file):

```python
import logging
from typing import Any, Optional

import torch
from modyn.utils import dynamic_module_import
from torch import nn

logger = logging.getLogger(__name__)


class RHOLOSSTwinModel:
    """Wrapper that exposes the twin model under the common Modyn model interface."""

    def __init__(self, model_configuration: dict[str, Any], device: str, amp: bool) -> None:
        self.model = RHOLOSSTwinModelModyn(model_configuration, device, amp)
        self.model.to(device)


class RHOLOSSTwinModelModyn(nn.Module):
    """Twin model that internally consists of two IL models of the same class."""

    def __init__(self, model_configuration: dict[str, Any], device: str, amp: bool) -> None:
        super().__init__()
        self.device = device
        model_module = dynamic_module_import("modyn.models")
        rho_model_class = model_configuration["rho_real_model_class"]
        rho_model_config = model_configuration["rho_real_model_config"]
        model_handler = getattr(model_module, rho_model_class)
        # We only need the inner model, not the wrapper.
        self._models = nn.ModuleList(
            [
                model_handler(rho_model_config, device, amp).model,
                model_handler(rho_model_config, device, amp).model,
            ]
        )
        # For each inner model, the set of sample ids it has been trained on.
        self._models_seen_ids: list[set[int]] = [set(), set()]
        self._current_model = 0

    def get_extra_state(self) -> dict:
        # Included in state_dict() so the seen-id bookkeeping survives the
        # checkpoint round-trip between the two training phases.
        return {
            "models_seen_ids": self._models_seen_ids,
        }

    def set_extra_state(self, state: dict) -> None:
        self._models_seen_ids = state["models_seen_ids"]
        # The second time we train this model (i.e. when it is restored from a
        # checkpoint), we switch to the other inner model.
        self._current_model = 1

    def forward(self, data: torch.Tensor, sample_ids: Optional[list[int]] = None) -> torch.Tensor:
        assert sample_ids is not None
        # self.training is an internal attribute defined in nn.Module that is
        # updated whenever .eval() or .train() is called.
        if self.training:
            output_tensor = self._training_forward(sample_ids, data)
        else:
            output_tensor = self._eval_forward(sample_ids, data)
        return output_tensor

    def _training_forward(self, sample_ids: list[int], data: torch.Tensor) -> torch.Tensor:
        # Record the ids so that evaluation can route each sample to the
        # inner model that has not seen it.
        self._models_seen_ids[self._current_model].update(sample_ids)
        return self._models[self._current_model](data)

    def _eval_forward(self, sample_ids: list[int], data: torch.Tensor) -> torch.Tensor:
        # The legacy torch.BoolTensor constructor does not accept a device
        # argument, so we build the masks with torch.tensor instead.
        seen_by_model0 = torch.tensor(
            [sample_id in self._models_seen_ids[0] for sample_id in sample_ids],
            dtype=torch.bool,
            device=self.device,
        )
        seen_by_model1 = torch.tensor(
            [sample_id in self._models_seen_ids[1] for sample_id in sample_ids],
            dtype=torch.bool,
            device=self.device,
        )

        # If model 0 saw none of the samples, it can serve the whole batch.
        if not seen_by_model0.any():
            return self._models[0](data)
        # Likewise for model 1.
        if not seen_by_model1.any():
            return self._models[1](data)

        # Samples seen by model 0 are served by model 1 and vice versa; samples
        # seen by neither model default to model 0. Unsqueeze to make
        # seen_by_model0 broadcastable against the model outputs.
        return torch.where(seen_by_model0.unsqueeze(1), self._models[1](data), self._models[0](data))
```