refactor: Add BatchedTrigger (#604)
robinholzi authored Aug 27, 2024
1 parent 9f1e643 commit 76e95d4
Showing 20 changed files with 394 additions and 346 deletions.
14 changes: 9 additions & 5 deletions benchmark/sigmod/triggering/run_yb_triggering.py
@@ -5,13 +5,17 @@
import sys
from pathlib import Path

-from benchmark.sigmod.triggering.yearbook_triggering_config import gen_yearbook_triggering_config
+from benchmark.sigmod.triggering.yearbook_triggering_config import (
+gen_yearbook_triggering_config,
+)
from experiments.utils.experiment_runner import run_multiple_pipelines
from modyn.config.schema.pipeline import ModynPipelineConfig
from modyn.config.schema.pipeline.trigger import TriggerConfig
from modyn.config.schema.pipeline.trigger.data_amount import DataAmountTriggerConfig
from modyn.config.schema.pipeline.trigger.drift import DataDriftTriggerConfig
-from modyn.config.schema.pipeline.trigger.drift.alibi_detect import AlibiDetectMmdDriftMetric
+from modyn.config.schema.pipeline.trigger.drift.alibi_detect import (
+AlibiDetectMmdDriftMetric,
+)
from modyn.config.schema.pipeline.trigger.drift.config import TimeWindowingStrategy
from modyn.config.schema.pipeline.trigger.time import TimeTriggerConfig
from modyn.supervisor.internal.pipeline_executor.models import PipelineLogs
@@ -42,18 +46,18 @@ def gen_triggering_strategies() -> list[tuple[str, TriggerConfig]]:
strategies.append((f"amounttrigger_{count}", DataAmountTriggerConfig(num_samples=count)))

# DriftTriggers
-for detection_interval_data_points in [250, 500, 100]:
+for evaluation_interval_data_points in [250, 500, 100]:
for threshold in [0.05, 0.07, 0.09]:
for window_size in ["1d", "2d", "5d"]: # fake timestamps, hence days
conf = DataDriftTriggerConfig(
-detection_interval_data_points=detection_interval_data_points,
+evaluation_interval_data_points=evaluation_interval_data_points,
windowing_strategy=TimeWindowingStrategy(limit=window_size),
reset_current_window_on_trigger=False,
metrics={
"mmd_alibi": AlibiDetectMmdDriftMetric(device="cpu", num_permutations=None, threshold=threshold)
},
)
-name = f"mmdalibi_{detection_interval_data_points}_{threshold}_{window_size}"
+name = f"mmdalibi_{evaluation_interval_data_points}_{threshold}_{window_size}"
strategies.append((name, conf))

return strategies
@@ -41,7 +41,7 @@ data:
tokenizer: DistilBertTokenizerTransform
trigger:
id: DataDriftTrigger
-detection_interval_data_points: 100000
+evaluation_interval_data_points: 100000
metrics:
ev_mmd:
id: EvidentlyModelDriftMetric
@@ -42,7 +42,7 @@ data:

trigger:
id: DataDriftTrigger
-detection_interval_data_points: 5000
+evaluation_interval_data_points: 5000
metrics:
ev_mmd:
id: AlibiDetectMmdDriftMetric
@@ -44,7 +44,7 @@ data:
trigger:
id: DataDriftTrigger
-detection_interval_data_points: 1000
+evaluation_interval_data_points: 1000
metrics:
ev_mmd:
id: EvidentlyModelDriftMetric
2 changes: 1 addition & 1 deletion experiments/arxiv/compare_trigger_policies/run.py
@@ -95,7 +95,7 @@ def construct_pipelines(experiment: Experiment) -> list[ModynPipelineConfig]:
gen_pipeline_config(
name=f"datadrifttrigger_{interval}",
trigger=DataDriftTriggerConfig(
-detection_interval_data_points=interval,
+evaluation_interval_data_points=interval,
metrics=experiment.drift_trigger_metrics,
aggregation_strategy=MajorityVoteDriftAggregationStrategy(),
),
2 changes: 1 addition & 1 deletion experiments/huffpost/compare_trigger_policies/run.py
@@ -92,7 +92,7 @@ def construct_pipelines(experiment: Experiment) -> list[ModynPipelineConfig]:
gen_pipeline_config(
name=f"{experiment.name}_drift_{interval}",
trigger=DataDriftTriggerConfig(
-detection_interval_data_points=interval,
+evaluation_interval_data_points=interval,
metrics=experiment.drift_trigger_metrics,
aggregation_strategy=MajorityVoteDriftAggregationStrategy(),
),
2 changes: 1 addition & 1 deletion experiments/yearbook/compare_trigger_policies/run.py
@@ -109,7 +109,7 @@ def construct_pipelines(experiment: Experiment) -> list[ModynPipelineConfig]:
gen_pipeline_config(
name=f"{experiment.name}_drift_{interval}",
trigger=DataDriftTriggerConfig(
-detection_interval_data_points=interval,
+evaluation_interval_data_points=interval,
metrics=drift_metrics,
aggregation_strategy=MajorityVoteDriftAggregationStrategy(),
),
12 changes: 12 additions & 0 deletions modyn/config/schema/pipeline/trigger/common/batched.py
@@ -0,0 +1,12 @@
from pydantic import Field

from modyn.config.schema.base_model import ModynBaseModel


class BatchedTriggerConfig(ModynBaseModel):
evaluation_interval_data_points: int = Field(
description=(
"Specifies after how many samples another believe update (query density "
"estimation, accuracy evaluation, drift detection, ...) should be performed."
)
)
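
For orientation (not part of the diff): a minimal usage sketch of the new base config. In practice the field is inherited by concrete trigger configs such as DataDriftTriggerConfig below, but the pydantic model can also be instantiated directly; the value here is arbitrary.

from modyn.config.schema.pipeline.trigger.common.batched import BatchedTriggerConfig

# Arbitrary interval; concrete configs (e.g. DataDriftTriggerConfig below) inherit
# evaluation_interval_data_points from this base class instead of redefining it.
cfg = BatchedTriggerConfig(evaluation_interval_data_points=1000)
print(cfg.evaluation_interval_data_points)  # 1000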
14 changes: 2 additions & 12 deletions modyn/config/schema/pipeline/trigger/drift/config.py
@@ -4,7 +4,7 @@

from pydantic import Field, model_validator

-from modyn.config.schema.base_model import ModynBaseModel
+from modyn.config.schema.pipeline.trigger.common.batched import BatchedTriggerConfig
from modyn.config.schema.pipeline.trigger.drift.detection_window import (
AmountWindowingStrategy,
DriftWindowingStrategy,
@@ -23,19 +23,9 @@
]


-class DataDriftTriggerConfig(ModynBaseModel):
+class DataDriftTriggerConfig(BatchedTriggerConfig):
id: Literal["DataDriftTrigger"] = Field("DataDriftTrigger")

-detection_interval: __TriggerConfig | None = Field( # type: ignore[valid-type]
-None,
-description="The Trigger policy to determine the interval at which drift detection is performed.",
-) # currently not used
-detection_interval_data_points: int = Field(
-1000,
-description="The number of samples in the interval after which drift detection is performed.",
-ge=1,
-)

windowing_strategy: DriftWindowingStrategy = Field(
AmountWindowingStrategy(),
description="Which windowing strategy to use for current and reference data",
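For reference (not part of the commit): a hedged construction example mirroring the benchmark script above, using the renamed field; the concrete values are arbitrary.

from modyn.config.schema.pipeline.trigger.drift import DataDriftTriggerConfig
from modyn.config.schema.pipeline.trigger.drift.alibi_detect import (
    AlibiDetectMmdDriftMetric,
)
from modyn.config.schema.pipeline.trigger.drift.config import TimeWindowingStrategy

# evaluation_interval_data_points (inherited from BatchedTriggerConfig) replaces
# the removed detection_interval_data_points field.
drift_config = DataDriftTriggerConfig(
    evaluation_interval_data_points=500,
    windowing_strategy=TimeWindowingStrategy(limit="1d"),
    reset_current_window_on_trigger=False,
    metrics={
        "mmd_alibi": AlibiDetectMmdDriftMetric(device="cpu", num_permutations=None, threshold=0.05)
    },
)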
@@ -4,6 +4,7 @@

from modyn.config.schema.base_model import ModynBaseModel
from modyn.config.schema.pipeline.evaluation.config import EvalDataConfig
+from modyn.config.schema.pipeline.trigger.common.batched import BatchedTriggerConfig
from modyn.config.schema.pipeline.trigger.performance.criterion import (
PerformanceTriggerCriterion,
)
@@ -28,13 +29,7 @@ def validate_metrics(cls, dataset: EvalDataConfig) -> EvalDataConfig:
return dataset


-class _InternalPerformanceTriggerConfig(ModynBaseModel):
-detection_interval_data_points: int = Field(
-description=(
-"Specifies after how many samples another believe update (query density "
-"estimation, accuracy evaluation) should be performed."
-)
-)
+class _InternalPerformanceTriggerConfig(BatchedTriggerConfig):
data_density_window_size: int = Field(
0,
description="The window size for the data density estimation. Only used for lookahead mode.",
7 changes: 6 additions & 1 deletion modyn/supervisor/internal/triggers/amounttrigger.py
@@ -1,5 +1,7 @@
from collections.abc import Generator

+from typing_extensions import override

from modyn.config.schema.pipeline import DataAmountTriggerConfig
from modyn.supervisor.internal.triggers.trigger import Trigger
from modyn.supervisor.internal.triggers.utils.models import TriggerPolicyEvaluationLog
@@ -16,8 +18,11 @@ def __init__(self, config: DataAmountTriggerConfig):

assert self.data_points_for_trigger > 0, "data_points_for_trigger needs to be at least 1"

+@override
def inform(
-self, new_data: list[tuple[int, int, int]], log: TriggerPolicyEvaluationLog | None = None
+self,
+new_data: list[tuple[int, int, int]],
+log: TriggerPolicyEvaluationLog | None = None,
) -> Generator[int, None, None]:
assert self.remaining_data_points < self.data_points_for_trigger, "Inconsistent remaining datapoints"

84 changes: 84 additions & 0 deletions modyn/supervisor/internal/triggers/batchedtrigger.py
@@ -0,0 +1,84 @@
from __future__ import annotations

import logging
from abc import abstractmethod
from collections.abc import Generator

from typing_extensions import override

from modyn.config.schema.pipeline.trigger.common.batched import BatchedTriggerConfig
from modyn.supervisor.internal.triggers.trigger import Trigger
from modyn.supervisor.internal.triggers.utils.models import TriggerPolicyEvaluationLog

logger = logging.getLogger(__name__)


class BatchedTrigger(Trigger):
"""Abstract child of Trigger that implements triggering in discrete
intervals."""

def __init__(self, config: BatchedTriggerConfig) -> None:
self.config = config

# allows to detect drift in a fixed interval
self._sample_left_until_detection = config.evaluation_interval_data_points

self._leftover_data: list[tuple[int, int]] = []
"""Stores data that was not processed in the last inform call because
the detection interval was not filled."""

@override
def inform(
self,
new_data: list[tuple[int, int, int]],
log: TriggerPolicyEvaluationLog | None = None,
) -> Generator[int, None, None]:
new_key_ts = self._leftover_data + [(key, timestamp) for key, timestamp, _ in new_data]
# reappending the leftover data to the new data requires incrementing the sample left until detection
self._sample_left_until_detection += len(self._leftover_data)

# index of the first unprocessed data point in the batch
processing_head_in_batch = 0

# Go through remaining data in new data in batches of `detect_interval`
while True:
if self._sample_left_until_detection - len(new_key_ts) > 0:
# No detection in this trigger because of too few data points to fill detection interval
self._leftover_data = new_key_ts
self._sample_left_until_detection -= len(new_key_ts)
return

# At least one detection, fill up window up to that detection
next_detection_interval = new_key_ts[: self._sample_left_until_detection]

# Update the remaining data
processing_head_in_batch += len(next_detection_interval)
new_key_ts = new_key_ts[len(next_detection_interval) :]

# we need to return an index in the `new_data`. Therefore, we need to subtract number of samples in the
# leftover data from the processing head in batch; -1 is required as the head points to the first
# unprocessed data point
trigger_candidate_idx = min(
max(processing_head_in_batch - len(self._leftover_data) - 1, 0),
len(new_data) - 1,
)

# Reset for next detection
self._sample_left_until_detection = self.config.evaluation_interval_data_points

# ----------------------------------------------- Detection ---------------------------------------------- #

triggered = self._evaluate_batch(next_detection_interval, trigger_candidate_idx)

# ----------------------------------------------- Response ----------------------------------------------- #

if triggered:
yield trigger_candidate_idx

@abstractmethod
def _evaluate_batch(
self,
batch: list[tuple[int, int]],
trigger_candidate_idx: int,
log: TriggerPolicyEvaluationLog | None = None,
) -> bool: ...
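
To illustrate the contract (not part of the commit): a minimal, hypothetical subclass that fires on every filled batch. The class name and behavior are made up for illustration, and the example assumes Trigger defines no further abstract hooks beyond those visible here.

from modyn.config.schema.pipeline.trigger.common.batched import BatchedTriggerConfig
from modyn.supervisor.internal.triggers.batchedtrigger import BatchedTrigger
from modyn.supervisor.internal.triggers.utils.models import TriggerPolicyEvaluationLog


class AlwaysFireTrigger(BatchedTrigger):  # hypothetical example, not part of Modyn
    def _evaluate_batch(
        self,
        batch: list[tuple[int, int]],
        trigger_candidate_idx: int,
        log: TriggerPolicyEvaluationLog | None = None,
    ) -> bool:
        # A real trigger would inspect the (key, timestamp) pairs in `batch`,
        # e.g. run drift detection or a performance evaluation.
        return True


# With an interval of 2, informing with 5 samples fills two batches and keeps one
# leftover sample for the next inform call; inform() yields candidate indices 1 and 3.
trigger = AlwaysFireTrigger(BatchedTriggerConfig(evaluation_interval_data_points=2))
samples = [(key, 1000 + key, 0) for key in range(5)]  # (key, timestamp, label)
print(list(trigger.inform(samples)))  # [1, 3]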