feat: Support for RS2 Downsampler (#465)
This implements the random selection scheme from the [RS2 paper](https://openreview.net/pdf?id=JnRStoIuTe), minus the learning-rate scheduling adjustments. Note that reusing the downsampling infrastructure here is a bit suboptimal (#466). We might want to make the selector a bit more dynamic, but for now this suffices to run experiments. #462 should be merged before this is reviewed.
1 parent ac66eaf · commit 9fd2b80
Showing 8 changed files with 330 additions and 2 deletions.
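For context, the core of RS2's per-epoch selection is small enough to sketch. The snippet below is an illustrative distillation of the idea only, not code from this PR; the function name rs2_subset is made up for the sketch, and the actual logic lives in RemoteRS2Downsampling further down.

import random

def rs2_subset(sample_ids: list[int], ratio_percent: int) -> list[int]:
    # Each epoch, reshuffle the full ID list and train on a fresh random prefix.
    target_size = max(int(ratio_percent * len(sample_ids) / 100), 1)
    random.shuffle(sample_ids)
    return sample_ids[:target_size]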
...elector/internal/selector_strategies/downsampling_strategies/rs2_downsampling_strategy.py
19 changes: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
from modyn.config.schema.sampling.downsampling_config import RS2DownsamplingConfig
from modyn.selector.internal.selector_strategies.downsampling_strategies import AbstractDownsamplingStrategy


class RS2DownsamplingStrategy(AbstractDownsamplingStrategy):
    def __init__(
        self,
        downsampling_config: RS2DownsamplingConfig,
        modyn_config: dict,
        pipeline_id: int,
        maximum_keys_in_memory: int,
    ):
        super().__init__(downsampling_config, modyn_config, pipeline_id, maximum_keys_in_memory)
        self.remote_downsampling_strategy_name = "RemoteRS2Downsampling"

    def _build_downsampling_params(self) -> dict:
        config = super()._build_downsampling_params()
        config["replacement"] = self.downsampling_config.with_replacement
        return config
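A minimal usage sketch for this selector-side shim (mirroring the unit test below and assuming the two imports from the file above; the empty dict and the integers are dummy values for modyn_config, pipeline_id, and maximum_keys_in_memory):

strategy = RS2DownsamplingStrategy(
    RS2DownsamplingConfig(ratio=10, with_replacement=True),
    {},  # modyn_config (dummy)
    0,  # pipeline_id
    1000,  # maximum_keys_in_memory
)
assert strategy.remote_downsampling_strategy_name == "RemoteRS2Downsampling"
assert strategy.downsampling_params["replacement"]

The strategy holds no sampling logic itself; it only names its remote counterpart and forwards the with_replacement flag, so all RS2 work happens on the trainer server.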
...or/internal/selector_strategies/downsampling_strategies/test_rs2_downsampling_strategy.py
31 changes: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
import os
import pathlib
import tempfile

from modyn.config.schema.sampling.downsampling_config import RS2DownsamplingConfig
from modyn.selector.internal.selector_strategies.downsampling_strategies import RS2DownsamplingStrategy

database_path = pathlib.Path(os.path.abspath(__file__)).parent / "test_storage.db"
TMP_DIR = tempfile.mkdtemp()


def test_init_rs2():
    # Test init works
    strat = RS2DownsamplingStrategy(
        RS2DownsamplingConfig(ratio=10, with_replacement=True),
        {},
        0,
        1000,
    )

    assert strat.downsampling_ratio == 10
    assert strat.requires_remote_computation
    assert strat.maximum_keys_in_memory == 1000

    name = strat.remote_downsampling_strategy_name
    assert isinstance(name, str)
    assert name == "RemoteRS2Downsampling"

    params = strat.downsampling_params
    assert "replacement" in params
    assert params["replacement"]
...tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_rs2_downsampling.py
171 changes: 171 additions & 0 deletions
@@ -0,0 +1,171 @@
import torch
from modyn.trainer_server.internal.trainer.remote_downsamplers.abstract_remote_downsampling_strategy import (
    get_tensors_subset,
)
from modyn.trainer_server.internal.trainer.remote_downsamplers.remote_rs2_downsampling import RemoteRS2Downsampling


def test_init():
    pipeline_id = 0
    trigger_id = 0
    batch_size = 32
    params_from_selector = {"replacement": True, "downsampling_ratio": 50}
    per_sample_loss = None
    device = "cpu"

    downsampler = RemoteRS2Downsampling(
        pipeline_id, trigger_id, batch_size, params_from_selector, per_sample_loss, device
    )

    assert downsampler.pipeline_id == pipeline_id
    assert downsampler.trigger_id == trigger_id
    assert downsampler.batch_size == batch_size
    assert downsampler.device == device
    assert not downsampler.forward_required
    assert not downsampler.supports_bts
    assert downsampler._all_sample_ids == []
    assert downsampler._subsets == []
    assert downsampler._current_subset == -1
    assert downsampler._with_replacement == params_from_selector["replacement"]
    assert downsampler._first_epoch


def test_inform_samples():
    pipeline_id = 0
    trigger_id = 0
    batch_size = 32
    params_from_selector = {"replacement": True, "downsampling_ratio": 50}
    per_sample_loss = None
    device = "cpu"

    downsampler = RemoteRS2Downsampling(
        pipeline_id, trigger_id, batch_size, params_from_selector, per_sample_loss, device
    )

    sample_ids = [1, 2, 3, 4, 5]
    forward_output = torch.randn(5, 10)
    target = torch.randint(0, 10, (5,))

    downsampler.inform_samples(sample_ids, forward_output, target)

    assert downsampler._all_sample_ids == sample_ids
    downsampler.inform_samples(sample_ids, forward_output, target)
    assert downsampler._all_sample_ids == 2 * sample_ids
    # Now it should not change anymore
    downsampler.select_points()
    downsampler.inform_samples(sample_ids, forward_output, target)
    assert set(downsampler._all_sample_ids) == set(sample_ids)
    assert len(downsampler._all_sample_ids) == 2 * len(sample_ids)
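    # (select_points() set _first_epoch to False, so the third inform_samples call recorded nothing.)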


def test_multiple_epochs_with_replacement():
    pipeline_id = 0
    trigger_id = 0
    batch_size = 32
    params_from_selector = {"replacement": True, "downsampling_ratio": 50}
    per_sample_loss = None
    device = "cpu"

    downsampler = RemoteRS2Downsampling(
        pipeline_id, trigger_id, batch_size, params_from_selector, per_sample_loss, device
    )
    with torch.inference_mode(mode=(not downsampler.requires_grad)):
        sample_ids = list(range(10))
        data = torch.randn(10, 10)
        target = torch.randint(0, 10, (10,))

        for _ in range(3):
            downsampler.inform_samples(sample_ids, data, target)
            selected_ids, weights = downsampler.select_points()
            sampled_data, sampled_target = get_tensors_subset(selected_ids, data, target, sample_ids)

            assert len(set(selected_ids)) == 5
            assert weights.shape == (5,)
            assert all(idx in sample_ids for idx in selected_ids)
            assert sampled_data.shape == (5, 10)
            assert sampled_target.shape == (5,)


def test_multiple_epochs_without_replacement():
    pipeline_id = 0
    trigger_id = 0
    batch_size = 32
    params_from_selector = {"replacement": False, "downsampling_ratio": 50}
    per_sample_loss = None
    device = "cpu"

    downsampler = RemoteRS2Downsampling(
        pipeline_id, trigger_id, batch_size, params_from_selector, per_sample_loss, device
    )
    with torch.inference_mode(mode=(not downsampler.requires_grad)):

        sample_ids = list(range(10))
        data = torch.randn(10, 10)
        target = torch.randint(0, 10, (10,))

        # Epoch 1
        downsampler.inform_samples(sample_ids, data, target)
        epoch1_ids, weights = downsampler.select_points()
        sampled_data, sampled_target = get_tensors_subset(epoch1_ids, data, target, sample_ids)

        assert len(set(epoch1_ids)) == 5
        assert weights.shape == (5,)
        assert all(idx in sample_ids for idx in epoch1_ids)
        assert sampled_data.shape == (5, 10)
        assert sampled_target.shape == (5,)

        # Epoch 2
        downsampler.inform_samples(sample_ids, data, target)
        epoch2_ids, weights = downsampler.select_points()
        sampled_data, sampled_target = get_tensors_subset(epoch2_ids, data, target, sample_ids)

        assert len(set(epoch2_ids)) == 5
        assert weights.shape == (5,)
        assert all(idx in sample_ids for idx in epoch2_ids)
        assert not any(idx in epoch1_ids for idx in epoch2_ids)  # No overlap across epochs
        assert sampled_data.shape == (5, 10)
        assert sampled_target.shape == (5,)

        # Epoch 3
        downsampler.inform_samples(sample_ids, data, target)
        epoch3_ids, weights = downsampler.select_points()
        sampled_data, sampled_target = get_tensors_subset(epoch3_ids, data, target, sample_ids)

        assert len(set(epoch3_ids)) == 5
        assert weights.shape == (5,)
        assert all(idx in sample_ids for idx in epoch3_ids)
        assert all(idx in epoch1_ids or idx in epoch2_ids for idx in epoch3_ids)  # There needs to be overlap now,
        # but with very high probability there is also some difference (this might be flaky, let's see)
        assert any(idx not in epoch1_ids for idx in epoch3_ids)
        assert sampled_data.shape == (5, 10)
        assert sampled_target.shape == (5,)


def test_multiple_epochs_without_replacement_leftover_data():
    pipeline_id = 0
    trigger_id = 0
    batch_size = 32
    params_from_selector = {"replacement": False, "downsampling_ratio": 40}
    per_sample_loss = None
    device = "cpu"

    downsampler = RemoteRS2Downsampling(
        pipeline_id, trigger_id, batch_size, params_from_selector, per_sample_loss, device
    )
    with torch.inference_mode(mode=(not downsampler.requires_grad)):
        sample_ids = list(range(10))
        data = torch.randn(10, 10)
        target = torch.randint(0, 10, (10,))

        for _ in range(3):
            downsampler.inform_samples(sample_ids, data, target)

            selected_ids, weights = downsampler.select_points()
            sampled_data, sampled_target = get_tensors_subset(selected_ids, data, target, sample_ids)
            assert len(set(selected_ids)) == 4
            assert weights.shape == (4,)
            assert sampled_data.shape == (4, 10)
            assert sampled_target.shape == (4,)
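            # (ratio=40 over 10 samples -> target_size=4, max_subset=2; the 2 leftover IDs are dropped until the next reshuffle)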

            assert all(idx in sample_ids for idx in selected_ids)
            assert len(set(selected_ids)) == len(selected_ids)
modyn/trainer_server/internal/trainer/remote_downsamplers/remote_rs2_downsampling.py
80 changes: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
import logging
import random
from typing import Any, Optional

import torch
from modyn.trainer_server.internal.trainer.remote_downsamplers.abstract_remote_downsampling_strategy import (
    AbstractRemoteDownsamplingStrategy,
)

logger = logging.getLogger(__name__)


class RemoteRS2Downsampling(AbstractRemoteDownsamplingStrategy):
    """
    Method adapted from REPEATED RANDOM SAMPLING FOR MINIMIZING THE TIME-TO-ACCURACY OF LEARNING (Okanovic+, 2024)
    https://openreview.net/pdf?id=JnRStoIuTe
    """

    def __init__(
        self,
        pipeline_id: int,
        trigger_id: int,
        batch_size: int,
        params_from_selector: dict,
        per_sample_loss: Any,
        device: str,
    ) -> None:
        super().__init__(pipeline_id, trigger_id, batch_size, params_from_selector, device)
        self.forward_required = False
        self.supports_bts = False
        self._all_sample_ids: list[int] = []
        self._subsets: list[list[int]] = []
        self._current_subset = -1
        self._with_replacement: bool = params_from_selector["replacement"]
        self._first_epoch = True

    def init_downsampler(self) -> None:
        pass  # We take care of that in inform_samples

    def inform_samples(
        self,
        sample_ids: list[int],
        forward_output: torch.Tensor,
        target: torch.Tensor,
        embedding: Optional[torch.Tensor] = None,
    ) -> None:
        # We only need to collect the sample information once
        if self._first_epoch:
            self._all_sample_ids.extend(sample_ids)

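    # With replacement across epochs: reshuffle the full ID list every epoch and
    # take a fresh prefix, so the same sample may be selected in consecutive epochs.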
    def _epoch_step_wr(self, target_size: int) -> None:
        random.shuffle(self._all_sample_ids)
        self._subsets = [self._all_sample_ids[:target_size]]
        self._current_subset = 0

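    # Without replacement across epochs: partition the shuffled ID list into
    # len // target_size disjoint subsets and hand out one subset per epoch;
    # once all subsets are used, reshuffle and re-partition. IDs that do not
    # fill a complete subset are dropped until the next reshuffle.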
    def _epoch_step_no_r(self, target_size: int) -> None:
        max_subset = len(self._all_sample_ids) // target_size
        self._current_subset += 1
        # len(self._subsets) == 0 holds in the very first epoch
        if self._current_subset >= max_subset or len(self._subsets) == 0:
            random.shuffle(self._all_sample_ids)
            self._current_subset = 0
            self._subsets = [self._all_sample_ids[i * target_size : (i + 1) * target_size] for i in range(max_subset)]

    def _epoch_step(self) -> None:
        target_size = max(int(self.downsampling_ratio * len(self._all_sample_ids) / 100), 1)

        if self._with_replacement:
            self._epoch_step_wr(target_size)
        else:
            self._epoch_step_no_r(target_size)

    def select_points(self) -> tuple[list[int], torch.Tensor]:
        self._first_epoch = False
        self._epoch_step()
        return self._subsets[self._current_subset], torch.ones(len(self._subsets[self._current_subset]))

    @property
    def requires_grad(self) -> bool:
        return False
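To make the without-replacement cycle concrete: with 10 samples and a ratio of 50, target_size = max(int(50 * 10 / 100), 1) = 5 and max_subset = 10 // 5 = 2, so two epochs consume the two disjoint halves of the shuffled ID list and the third epoch triggers a reshuffle, which is exactly what the tests above assert. A hypothetical driver loop, reusing the imports from this file with dummy tensors:

downsampler = RemoteRS2Downsampling(
    0, 0, 32, {"replacement": False, "downsampling_ratio": 50}, None, "cpu"
)
sample_ids = list(range(10))
data, target = torch.randn(10, 4), torch.randint(0, 10, (10,))
for epoch in range(3):
    downsampler.inform_samples(sample_ids, data, target)  # only recorded before the first select_points()
    selected_ids, _ = downsampler.select_points()
    print(epoch, sorted(selected_ids))  # epochs 0 and 1 are disjoint; epoch 2 starts a new partition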