feat: Support for RS2 Downsampler (#465)
This implements the random selection from the [RS2
paper](https://openreview.net/pdf?id=JnRStoIuTe) minus the learning rate
scheduling adjustments.

Note that it is somewhat suboptimal to use the downsampling infrastructure
here (#466). We might want to make the selector more dynamic, but for now
this suffices to run experiments. #462 should be merged before this is
reviewed.
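
A minimal sketch of how the new config might be constructed (values and the `rs2_config` name are illustrative; `ratio` and `with_replacement` are the fields exercised by the new unit test below):

```python
from modyn.config.schema.sampling.downsampling_config import RS2DownsamplingConfig

# Illustrative: keep 50% of the samples per epoch. with_replacement=True
# re-samples from the full set every epoch; False cycles through disjoint
# subsets so that all data is seen across epochs.
rs2_config = RS2DownsamplingConfig(ratio=50, with_replacement=True)
```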
MaxiBoether authored Jun 4, 2024
1 parent ac66eaf commit 9fd2b80
Showing 8 changed files with 330 additions and 2 deletions.
14 changes: 14 additions & 0 deletions modyn/config/schema/sampling/downsampling_config.py
@@ -121,6 +121,20 @@ class RHOLossDownsamplingConfig(BaseDownsamplingConfig):
il_training_config: ILTrainingConfig = Field(description="The configuration for the IL training.")


class RS2DownsamplingConfig(BaseDownsamplingConfig):
"""Config for the RS2 downsampling strategy."""

strategy: Literal["RS2"] = "RS2"
period: Literal[1] = 1 # RS2 needs to sample every epoch
sample_then_batch: Literal[True] = True # RS2 only supports StB
with_replacement: bool = Field(
description=(
"Whether we resample from the full TTS each epoch (= True) or train "
"on all the data with a different subset each epoch (= False)."
)
)


SingleDownsamplingConfig = Annotated[
Union[
UncertaintyDownsamplingConfig,
@@ -8,6 +8,7 @@
from .kcentergreedy_downsampling_strategy import KcenterGreedyDownsamplingStrategy # noqa: F401
from .loss_downsampling_strategy import LossDownsamplingStrategy # noqa: F401
from .no_downsampling_strategy import NoDownsamplingStrategy # noqa: F401
from .rs2_downsampling_strategy import RS2DownsamplingStrategy # noqa: F401
from .submodular_downsampling_strategy import SubmodularDownsamplingStrategy # noqa: F401
from .uncertainty_downsampling_strategy import UncertaintyDownsamplingStrategy # noqa: F401
from .utils import instantiate_downsampler # noqa: F401
@@ -0,0 +1,19 @@
from modyn.config.schema.sampling.downsampling_config import RS2DownsamplingConfig
from modyn.selector.internal.selector_strategies.downsampling_strategies import AbstractDownsamplingStrategy


class RS2DownsamplingStrategy(AbstractDownsamplingStrategy):
def __init__(
self,
downsampling_config: RS2DownsamplingConfig,
modyn_config: dict,
pipeline_id: int,
maximum_keys_in_memory: int,
):
super().__init__(downsampling_config, modyn_config, pipeline_id, maximum_keys_in_memory)
self.remote_downsampling_strategy_name = "RemoteRS2Downsampling"

def _build_downsampling_params(self) -> dict:
config = super()._build_downsampling_params()
config["replacement"] = self.downsampling_config.with_replacement
return config
@@ -0,0 +1,31 @@
import os
import pathlib
import tempfile

from modyn.config.schema.sampling.downsampling_config import RS2DownsamplingConfig
from modyn.selector.internal.selector_strategies.downsampling_strategies import RS2DownsamplingStrategy

database_path = pathlib.Path(os.path.abspath(__file__)).parent / "test_storage.db"
TMP_DIR = tempfile.mkdtemp()


def test_init_rs2():
# Test init works
strat = RS2DownsamplingStrategy(
RS2DownsamplingConfig(ratio=10, with_replacement=True),
{},
0,
1000,
)

assert strat.downsampling_ratio == 10
assert strat.requires_remote_computation
assert strat.maximum_keys_in_memory == 1000

name = strat.remote_downsampling_strategy_name
assert isinstance(name, str)
assert name == "RemoteRS2Downsampling"

params = strat.downsampling_params
assert "replacement" in params
assert params["replacement"]
@@ -0,0 +1,171 @@
import torch
from modyn.trainer_server.internal.trainer.remote_downsamplers.abstract_remote_downsampling_strategy import (
get_tensors_subset,
)
from modyn.trainer_server.internal.trainer.remote_downsamplers.remote_rs2_downsampling import RemoteRS2Downsampling


def test_init():
pipeline_id = 0
trigger_id = 0
batch_size = 32
params_from_selector = {"replacement": True, "downsampling_ratio": 50}
per_sample_loss = None
device = "cpu"

downsampler = RemoteRS2Downsampling(
pipeline_id, trigger_id, batch_size, params_from_selector, per_sample_loss, device
)

assert downsampler.pipeline_id == pipeline_id
assert downsampler.trigger_id == trigger_id
assert downsampler.batch_size == batch_size
assert downsampler.device == device
assert not downsampler.forward_required
assert not downsampler.supports_bts
assert downsampler._all_sample_ids == []
assert downsampler._subsets == []
assert downsampler._current_subset == -1
assert downsampler._with_replacement == params_from_selector["replacement"]
assert downsampler._first_epoch


def test_inform_samples():
pipeline_id = 0
trigger_id = 0
batch_size = 32
params_from_selector = {"replacement": True, "downsampling_ratio": 50}
per_sample_loss = None
device = "cpu"

downsampler = RemoteRS2Downsampling(
pipeline_id, trigger_id, batch_size, params_from_selector, per_sample_loss, device
)

sample_ids = [1, 2, 3, 4, 5]
forward_output = torch.randn(5, 10)
target = torch.randint(0, 10, (5,))

downsampler.inform_samples(sample_ids, forward_output, target)

assert downsampler._all_sample_ids == sample_ids
downsampler.inform_samples(sample_ids, forward_output, target)
assert downsampler._all_sample_ids == 2 * sample_ids
# Now it should not change anymore
downsampler.select_points()
downsampler.inform_samples(sample_ids, forward_output, target)
assert set(downsampler._all_sample_ids) == set(sample_ids)
assert len(downsampler._all_sample_ids) == 2 * len(sample_ids)


def test_multiple_epochs_with_replacement():
pipeline_id = 0
trigger_id = 0
batch_size = 32
params_from_selector = {"replacement": True, "downsampling_ratio": 50}
per_sample_loss = None
device = "cpu"

downsampler = RemoteRS2Downsampling(
pipeline_id, trigger_id, batch_size, params_from_selector, per_sample_loss, device
)
with torch.inference_mode(mode=(not downsampler.requires_grad)):
sample_ids = list(range(10))
data = torch.randn(10, 10)
target = torch.randint(0, 10, (10,))

for _ in range(3):
downsampler.inform_samples(sample_ids, data, target)
selected_ids, weights = downsampler.select_points()
sampled_data, sampled_target = get_tensors_subset(selected_ids, data, target, sample_ids)

assert len(set(selected_ids)) == 5
assert weights.shape == (5,)
assert all(idx in sample_ids for idx in selected_ids)
assert sampled_data.shape == (5, 10)
assert sampled_target.shape == (5,)


def test_multiple_epochs_without_replacement():
pipeline_id = 0
trigger_id = 0
batch_size = 32
params_from_selector = {"replacement": False, "downsampling_ratio": 50}
per_sample_loss = None
device = "cpu"

downsampler = RemoteRS2Downsampling(
pipeline_id, trigger_id, batch_size, params_from_selector, per_sample_loss, device
)
with torch.inference_mode(mode=(not downsampler.requires_grad)):

sample_ids = list(range(10))
data = torch.randn(10, 10)
target = torch.randint(0, 10, (10,))

# Epoch 1
downsampler.inform_samples(sample_ids, data, target)
epoch1_ids, weights = downsampler.select_points()
sampled_data, sampled_target = get_tensors_subset(epoch1_ids, data, target, sample_ids)

assert len(set(epoch1_ids)) == 5
assert weights.shape == (5,)
assert all(idx in sample_ids for idx in epoch1_ids)
assert sampled_data.shape == (5, 10)
assert sampled_target.shape == (5,)

# Epoch 2
downsampler.inform_samples(sample_ids, data, target)
epoch2_ids, weights = downsampler.select_points()
sampled_data, sampled_target = get_tensors_subset(epoch2_ids, data, target, sample_ids)

assert len(set(epoch2_ids)) == 5
assert weights.shape == (5,)
assert all(idx in sample_ids for idx in epoch2_ids)
assert not any(idx in epoch1_ids for idx in epoch2_ids) # No overlap across epochs
assert sampled_data.shape == (5, 10)
assert sampled_target.shape == (5,)

# Epoch 3
downsampler.inform_samples(sample_ids, data, target)
epoch3_ids, weights = downsampler.select_points()
sampled_data, sampled_target = get_tensors_subset(epoch3_ids, data, target, sample_ids)

assert len(set(epoch3_ids)) == 5
assert weights.shape == (5,)
assert all(idx in sample_ids for idx in epoch3_ids)
assert all(idx in epoch1_ids or idx in epoch2_ids for idx in epoch3_ids) # There needs to be overlap now
# but (with very high probability; this might be flaky, let's see) there is some difference
assert any(idx not in epoch1_ids for idx in epoch3_ids)
assert sampled_data.shape == (5, 10)
assert sampled_target.shape == (5,)


def test_multiple_epochs_without_replacement_leftover_data():
pipeline_id = 0
trigger_id = 0
batch_size = 32
params_from_selector = {"replacement": False, "downsampling_ratio": 40}
per_sample_loss = None
device = "cpu"

downsampler = RemoteRS2Downsampling(
pipeline_id, trigger_id, batch_size, params_from_selector, per_sample_loss, device
)
with torch.inference_mode(mode=(not downsampler.requires_grad)):
sample_ids = list(range(10))
data = torch.randn(10, 10)
target = torch.randint(0, 10, (10,))

for _ in range(3):
downsampler.inform_samples(sample_ids, data, target)

selected_ids, weights = downsampler.select_points()
sampled_data, sampled_target = get_tensors_subset(selected_ids, data, target, sample_ids)
assert len(set(selected_ids)) == 4
assert weights.shape == (4,)
assert sampled_data.shape == (4, 10)
assert sampled_target.shape == (4,)

assert all(idx in sample_ids for idx in selected_ids)
assert len(set(selected_ids)) == len(selected_ids)
6 changes: 4 additions & 2 deletions modyn/trainer_server/internal/trainer/pytorch_trainer.py
@@ -374,6 +374,8 @@ def train(self) -> None: # pylint: disable=too-many-locals, too-many-branches
self._info(f"Training will stop when the number of samples to pass reaches {self.num_samples_to_pass}.")

if self._downsampling_mode == DownsamplingMode.BATCH_THEN_SAMPLE:
# Assertion only; pydantic model validation should already catch this.
assert self._downsampler.supports_bts, "The downsampler does not support batch then sample"
# We cannot pass the target size from the trainer server since that depends on StB vs BtS.
post_downsampling_size = max(int(self._downsampler.downsampling_ratio * self._batch_size / 100), 1)
assert post_downsampling_size < self._batch_size
@@ -692,7 +694,7 @@ def downsample_batch(
self.start_embedding_recording_if_needed()

with torch.inference_mode(mode=(not self._downsampler.requires_grad)):
big_batch_output = self._model.model(data)
big_batch_output = self._model.model(data) if self._downsampler.forward_required else torch.Tensor()
embeddings = self.get_embeddings_if_recorded()
self._downsampler.inform_samples(sample_ids, big_batch_output, target, embeddings)

@@ -831,7 +833,7 @@ def _iterate_dataloader_and_compute_scores(
with torch.inference_mode(mode=(not self._downsampler.requires_grad)):
with torch.autocast(self._device_type, enabled=self._amp):
# compute the scores and accumulate them
model_output = self._model.model(data)
model_output = self._model.model(data) if self._downsampler.forward_required else torch.Tensor()
embeddings = self.get_embeddings_if_recorded()
self._downsampler.inform_samples(sample_ids, model_output, target, embeddings)

@@ -62,6 +62,16 @@ def __init__(
# CoresetSupportingModule for model implementations.
self.requires_coreset_supporting_module = False

# Some methods might not need information from the forward pass (e.g., completely random selection).
# Most do (by definition), hence we default to True.
# We might want to refactor those downsamplers into presamplers and support some
# adaptivity at the selector, but for now we allow random downsamplers, mostly
# to support RS2.
self.forward_required = True

# Some methods might only support StB, not BtS.
self.supports_bts = True

@abstractmethod
def init_downsampler(self) -> None:
raise NotImplementedError
@@ -0,0 +1,80 @@
import logging
import random
from typing import Any, Optional

import torch
from modyn.trainer_server.internal.trainer.remote_downsamplers.abstract_remote_downsampling_strategy import (
AbstractRemoteDownsamplingStrategy,
)

logger = logging.getLogger(__name__)


class RemoteRS2Downsampling(AbstractRemoteDownsamplingStrategy):
"""
Method adapted from REPEATED RANDOM SAMPLING FOR MINIMIZING THE TIME-TO-ACCURACY OF LEARNING (Okanovic+, 2024)
https://openreview.net/pdf?id=JnRStoIuTe
"""

def __init__(
self,
pipeline_id: int,
trigger_id: int,
batch_size: int,
params_from_selector: dict,
per_sample_loss: Any,
device: str,
) -> None:
super().__init__(pipeline_id, trigger_id, batch_size, params_from_selector, device)
self.forward_required = False
self.supports_bts = False
self._all_sample_ids: list[int] = []
self._subsets: list[list[int]] = []
self._current_subset = -1
self._with_replacement: bool = params_from_selector["replacement"]
self._first_epoch = True

def init_downsampler(self) -> None:
pass # We take care of that in inform_samples

def inform_samples(
self,
sample_ids: list[int],
forward_output: torch.Tensor,
target: torch.Tensor,
embedding: Optional[torch.Tensor] = None,
) -> None:
# We only need to collect the sample information once
if self._first_epoch:
self._all_sample_ids.extend(sample_ids)

def _epoch_step_wr(self, target_size: int) -> None:
random.shuffle(self._all_sample_ids)
self._subsets = [self._all_sample_ids[:target_size]]
self._current_subset = 0

def _epoch_step_no_r(self, target_size: int) -> None:
max_subset = len(self._all_sample_ids) // target_size
self._current_subset += 1
# len(self._subsets) == 0 holds in the very first epoch
if self._current_subset >= max_subset or len(self._subsets) == 0:
random.shuffle(self._all_sample_ids)
self._current_subset = 0
self._subsets = [self._all_sample_ids[i * target_size : (i + 1) * target_size] for i in range(max_subset)]

def _epoch_step(self) -> None:
target_size = max(int(self.downsampling_ratio * len(self._all_sample_ids) / 100), 1)

if self._with_replacement:
self._epoch_step_wr(target_size)
else:
self._epoch_step_no_r(target_size)

def select_points(self) -> tuple[list[int], torch.Tensor]:
self._first_epoch = False
self._epoch_step()
return self._subsets[self._current_subset], torch.ones(len(self._subsets[self._current_subset]))

@property
def requires_grad(self) -> bool:
return False
