You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
With #111, I remove GDumb because I don't want to be blocked by refactoring something that is not required for Hello World. We have the old code in old commits, but for reference, here is the old implementation. We should re-add GDumb with #114 in mind.
import numpy as np
from modyn.backend.metadata_database.metadata_database_connection import MetadataDatabaseConnection
from modyn.backend.metadata_database.models.metadata import Metadata
from modyn.backend.selector.internal.selector_strategies.abstract_selection_strategy import AbstractSelectionStrategy
class GDumbStrategy(AbstractSelectionStrategy):
    """
    Implements the GDumb selection policy: a class-balanced random subset.

    The training-set budget is split evenly across all observed classes and
    samples are drawn uniformly at random (without replacement) within each
    class.
    """

    def _on_trigger(self, pipeline_id: int) -> list[tuple[str, float]]:
        """
        Select a class-balanced random subset of all samples seen for the pipeline.

        Args:
            pipeline_id: Pipeline whose stored metadata is sampled.

        Returns:
            List of (sample key, weight) tuples; every weight is 1.0.
            When ``training_set_size_limit`` is positive, at most that many
            samples are returned; since the budget is split evenly across
            classes by integer division, the result can be slightly smaller.
        """
        all_samples, all_classes = self._get_all_metadata(pipeline_id)
        samples_arr = np.array(all_samples)
        classes_arr = np.array(all_classes)

        classes, counts = np.unique(classes_arr, return_counts=True)
        num_classes = classes.shape[0]
        if num_classes == 0:
            # No metadata stored yet; np.concatenate below would raise on [].
            return []

        if self.training_set_size_limit > 0:
            training_set_size = min(self.training_set_size_limit, len(all_samples))
        else:
            training_set_size = len(all_samples)
        per_class_budget = training_set_size // num_classes

        result_samples = []
        for class_idx in range(num_classes):
            num_class_samples = counts[class_idx]
            # Cap the draw at the number of available samples so that
            # np.random.choice(..., replace=False) cannot fail for classes
            # with fewer samples than the per-class budget.
            draw = min(per_class_budget, num_class_samples)
            rand_indices = np.random.choice(num_class_samples, size=draw, replace=False)
            class_indices = np.where(classes_arr == classes[class_idx])[0][rand_indices]
            result_samples.append(samples_arr[class_indices])

        return [(sample, 1.0) for sample in np.concatenate(result_samples)]

    def _get_all_metadata(self, pipeline_id: int) -> tuple[list[str], list[int]]:
        """
        Fetch all stored (key, label) metadata pairs for the given pipeline.

        Returns:
            Tuple of (list of sample keys, list of integer labels), aligned by index.
        """
        with MetadataDatabaseConnection(self._modyn_config) as database:
            all_metadata = (
                database.session.query(Metadata.key, Metadata.label).filter(Metadata.training_id == pipeline_id).all()
            )
        return ([metadata.key for metadata in all_metadata], [metadata.label for metadata in all_metadata])

    def inform_data(self, pipeline_id: int, keys: list[str], timestamps: list[int], labels: list[int]) -> None:
        """
        Persist new sample metadata (keys, timestamps, labels) for the pipeline.

        Score, seen-flag, and data payload are not tracked by GDumb, so they
        are stored as None/False placeholders.
        """
        with MetadataDatabaseConnection(self._modyn_config) as database:
            database.set_metadata(
                keys,
                timestamps,
                [None] * len(keys),  # scores: unused by GDumb
                [False] * len(keys),  # seen flags: unused by GDumb
                labels,
                [None] * len(keys),  # data payloads: unused by GDumb
                pipeline_id,
            )
The old tests:
# pylint: disable=no-value-for-parameter
import os
import pathlib
from collections import Counter
from unittest.mock import patch
from modyn.backend.metadata_database.metadata_database_connection import MetadataDatabaseConnection
from modyn.backend.metadata_database.models.metadata import Metadata
from modyn.backend.metadata_database.models.training import Training
from modyn.backend.selector.internal.selector_strategies.abstract_selection_strategy import AbstractSelectionStrategy
from modyn.backend.selector.internal.selector_strategies.gdumb_strategy import GDumbStrategy
# On-disk SQLite file next to this test module; created by setup(), removed by teardown().
database_path = pathlib.Path(os.path.abspath(__file__)).with_name("test_storage.db")
def get_minimal_modyn_config():
    """Return the smallest modyn config that points at the on-disk test SQLite database."""
    metadata_database_config = {
        "drivername": "sqlite",
        "username": "",
        "password": "",
        "host": "",
        "port": "0",
        "database": str(database_path),
    }
    return {"metadata_database": metadata_database_config}
def noop_constructor_mock(self, config=None, opt=None):  # pylint: disable=unused-argument
    """Constructor stub: skip the real __init__ and only attach the minimal test config."""
    self._modyn_config = get_minimal_modyn_config()
def setup():
    """Create the schema and seed one training plus two metadata rows for the tests."""
    with MetadataDatabaseConnection(get_minimal_modyn_config()) as database:
        database.create_tables()

        train = Training(1)
        database.session.add(train)
        database.session.commit()

        first = Metadata("test_key", 100, 0.5, False, 1, b"test_data", train.training_id)
        second = Metadata("test_key2", 101, 0.75, True, 2, b"test_data2", train.training_id)
        # SQLite does not support autoincrement for composite primary keys,
        # so the metadata ids must be assigned manually.
        first.metadata_id = 1
        second.metadata_id = 2
        database.session.add(first)
        database.session.add(second)
        database.session.commit()
def teardown():
    """Delete the on-disk SQLite test database created by setup()."""
    database_path.unlink()
@patch.multiple(AbstractSelectionStrategy, __abstractmethods__=set())
@patch.object(GDumbStrategy, "__init__", noop_constructor_mock)
def test_gdumb_selector_get_metadata():
    """The strategy must return every stored key and label for the pipeline."""
    strategy = GDumbStrategy(None)
    keys, labels = strategy._get_all_metadata(1)
    assert keys == ["test_key", "test_key2"]
    assert labels == [1, 2]
@patch.multiple(AbstractSelectionStrategy, __abstractmethods__=set())
@patch.object(GDumbStrategy, "__init__", noop_constructor_mock)
@patch.object(GDumbStrategy, "_get_all_metadata")
def test_gdumb_selector_get_new_training_samples(test__get_all_metadata):
    """With a limit of 6 over 3 classes, GDumb must return 6 known keys, each weighted 1.0."""
    all_samples = ["a", "b", "c", "d", "e", "f", "g", "h"]
    all_classes = [1, 1, 1, 1, 2, 2, 3, 3]
    test__get_all_metadata.return_value = all_samples, all_classes

    selector = GDumbStrategy(None)  # pylint: disable=abstract-class-instantiated
    selector.training_set_size_limit = 6

    selected = selector._on_trigger(0)
    keys = [key for key, _ in selected]
    weights = [weight for _, weight in selected]

    assert len(keys) == len(weights) == 6
    assert Counter(weights) == Counter([1.0] * 6)
    assert all(key in all_samples for key in keys)
The text was updated successfully, but these errors were encountered:
With #111, I remove GDumb because I don't want to be blocked by refactoring something that is not required for Hello World. We have the old code in old commits, but for reference, here is the old implementation. We should re-add GDumb with #114 in mind.
The old tests:
The text was updated successfully, but these errors were encountered: