Re-add GDumb #120

Open
MaxiBoether opened this issue Jan 27, 2023 · 0 comments

With #111, I removed GDumb because I don't want to be blocked by refactoring something that is not required for Hello World. The old code is still available in earlier commits, but for reference, here is the old implementation. We should re-add GDumb with #114 in mind.

import numpy as np
from modyn.backend.metadata_database.metadata_database_connection import MetadataDatabaseConnection
from modyn.backend.metadata_database.models.metadata import Metadata
from modyn.backend.selector.internal.selector_strategies.abstract_selection_strategy import AbstractSelectionStrategy


class GDumbStrategy(AbstractSelectionStrategy):
    """
    Implements the GDumb selection policy.
    """

    def _on_trigger(self, pipeline_id: int) -> list[tuple[str, float]]:
        """
        For a given pipeline_id, select a class-balanced subset of all samples seen so far, following the GDumb policy.

        Returns:
            List of keys for the samples to be considered, along with a default weight of 1.
        """
        result_samples, result_classes = [], []

        all_samples, all_classes = self._get_all_metadata(pipeline_id)
        classes, counts = np.unique(all_classes, return_counts=True)

        num_classes = classes.shape[0]
        if self.training_set_size_limit > 0:
            training_set_size = min(self.training_set_size_limit, len(all_samples))
        else:
            training_set_size = len(all_samples)
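        # Draw the same number of samples from each class, without replacement, to build a class-balanced training set.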
        for clss in range(num_classes):
            num_class_samples = counts[clss]
            rand_indices = np.random.choice(num_class_samples, size=training_set_size // num_classes, replace=False)
            class_indices = np.where(all_classes == classes[clss])[0][rand_indices]
            result_samples.append(np.array(all_samples)[class_indices])
            result_classes.append(np.array(all_classes)[class_indices])
        result_samples = np.concatenate(result_samples)
        return [(sample, 1.0) for sample in result_samples]

    def _get_all_metadata(self, pipeline_id: int) -> tuple[list[str], list[int]]:
        with MetadataDatabaseConnection(self._modyn_config) as database:
            all_metadata = (
                database.session.query(Metadata.key, Metadata.label).filter(Metadata.training_id == pipeline_id).all()
            )
        return ([metadata.key for metadata in all_metadata], [metadata.label for metadata in all_metadata])

    def inform_data(self, pipeline_id: int, keys: list[str], timestamps: list[int], labels: list[int]) -> None:
        with MetadataDatabaseConnection(self._modyn_config) as database:
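            # Persist keys, timestamps, and labels for the new samples; the remaining metadata fields are placeholders (None/False).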
            database.set_metadata(
                keys,
                timestamps,
                [None] * len(keys),
                [False] * len(keys),
                labels,
                [None] * len(keys),
                pipeline_id,
            )
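
For context, the core of the policy is just a class-balanced random draw over everything seen so far. Below is a minimal, self-contained sketch of that step on toy data (the function name and arguments are illustrative, not part of the Modyn interfaces):

import numpy as np


def gdumb_balanced_draw(all_samples: list[str], all_classes: list[int], limit: int) -> list[tuple[str, float]]:
    """Toy illustration of the per-class draw performed in _on_trigger."""
    classes, counts = np.unique(all_classes, return_counts=True)
    budget = min(limit, len(all_samples)) if limit > 0 else len(all_samples)
    per_class = budget // classes.shape[0]  # equal share per class; any remainder is dropped
    samples_arr, classes_arr = np.array(all_samples), np.array(all_classes)
    chosen = []
    for idx, clss in enumerate(classes):
        # pick per_class indices of this class uniformly at random, without replacement
        rand_indices = np.random.choice(counts[idx], size=per_class, replace=False)
        chosen.append(samples_arr[np.where(classes_arr == clss)[0][rand_indices]])
    return [(sample, 1.0) for sample in np.concatenate(chosen)]


# e.g., gdumb_balanced_draw(["a", "b", "c", "d", "e", "f", "g", "h"], [1, 1, 1, 1, 2, 2, 3, 3], 6)
# yields 6 (key, 1.0) tuples, two per class.

Note that, like the implementation above, the draw without replacement assumes every class has at least budget // num_classes samples; the re-added version may want to handle smaller classes explicitly.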

The old tests:

# pylint: disable=no-value-for-parameter
import os
import pathlib
from collections import Counter
from unittest.mock import patch

from modyn.backend.metadata_database.metadata_database_connection import MetadataDatabaseConnection
from modyn.backend.metadata_database.models.metadata import Metadata
from modyn.backend.metadata_database.models.training import Training
from modyn.backend.selector.internal.selector_strategies.abstract_selection_strategy import AbstractSelectionStrategy
from modyn.backend.selector.internal.selector_strategies.gdumb_strategy import GDumbStrategy

database_path = pathlib.Path(os.path.abspath(__file__)).parent / "test_storage.db"


def get_minimal_modyn_config():
    return {
        "metadata_database": {
            "drivername": "sqlite",
            "username": "",
            "password": "",
            "host": "",
            "port": "0",
            "database": f"{database_path}",
        },
    }


def noop_constructor_mock(self, config=None, opt=None):  # pylint: disable=unused-argument
    self._modyn_config = get_minimal_modyn_config()


def setup():
    with MetadataDatabaseConnection(get_minimal_modyn_config()) as database:
        database.create_tables()

        training = Training(1)
        database.session.add(training)
        database.session.commit()

        metadata = Metadata("test_key", 100, 0.5, False, 1, b"test_data", training.training_id)

        metadata.metadata_id = 1  # SQLite does not support autoincrement for composite primary keys
        database.session.add(metadata)

        metadata2 = Metadata("test_key2", 101, 0.75, True, 2, b"test_data2", training.training_id)

        metadata2.metadata_id = 2  # SQLite does not support autoincrement for composite primary keys
        database.session.add(metadata2)

        database.session.commit()


def teardown():
    os.remove(database_path)


@patch.multiple(AbstractSelectionStrategy, __abstractmethods__=set())
@patch.object(GDumbStrategy, "__init__", noop_constructor_mock)
def test_gdumb_selector_get_metadata():
    strategy = GDumbStrategy(None)
    assert strategy._get_all_metadata(1) == (["test_key", "test_key2"], [1, 2])


@patch.multiple(AbstractSelectionStrategy, __abstractmethods__=set())
@patch.object(GDumbStrategy, "__init__", noop_constructor_mock)
@patch.object(GDumbStrategy, "_get_all_metadata")
def test_gdumb_selector_get_new_training_samples(test__get_all_metadata):
    all_samples = ["a", "b", "c", "d", "e", "f", "g", "h"]
    all_classes = [1, 1, 1, 1, 2, 2, 3, 3]

    test__get_all_metadata.return_value = all_samples, all_classes

    selector = GDumbStrategy(None)  # pylint: disable=abstract-class-instantiated
    selector.training_set_size_limit = 6

    samples = selector._on_trigger(0)
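    # _on_trigger returns (key, weight) tuples, so "classes" below holds the constant weights (1.0), not labels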
    classes = [clss for _, clss in samples]
    samples = [sample for sample, _ in samples]

    assert len(classes) == len(samples) == 6
    assert Counter(classes) == Counter([1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
    for sample in samples:
        assert sample in all_samples