From ec161fa28d45cc78cae2e593d49f0e2886f66efa Mon Sep 17 00:00:00 2001 From: Robin Holzinger Date: Wed, 14 Aug 2024 13:44:42 +0200 Subject: [PATCH] feat: Add more powerful drift windowing strategies, warmup and dynamic thresholds (#564) --- analytics/app/pages/plots/eval_over_time.py | 2 +- benchmark/huffpost_kaggle/README.md | 1 - .../data_drift_trigger/arxiv_datadrift.yaml | 4 +- .../huffpost_datadrift.yaml | 4 + .../yearbook_datadrift.yaml | 4 +- docs/pipeline/TRIGGERING.md | 18 +- docs/pipeline/triggering/DRIFT_TRIGGERS.md | 247 +++++++++++++ .../schema/pipeline/trigger/__init__.py | 8 +- .../schema/pipeline/trigger/drift/__init__.py | 1 + .../pipeline/trigger/drift/alibi_detect.py | 19 +- .../schema/pipeline/trigger/drift/config.py | 77 ++-- .../drift/detection_window/__init__.py | 11 + .../trigger/drift/detection_window/amount.py | 17 + .../trigger/drift/detection_window/time_.py | 35 ++ .../trigger/drift/detection_window/window.py | 15 + .../pipeline/trigger/drift/evidently.py | 14 +- .../schema/pipeline/trigger/drift/metric.py | 31 ++ .../schema/pipeline/trigger/drift/result.py | 4 +- .../pipeline/trigger/simple/__init__.py | 11 + .../trigger/{ => simple}/data_amount.py | 0 .../pipeline/trigger/{ => simple}/time.py | 0 .../pipeline_executor/pipeline_executor.py | 193 ++++++++-- .../internal/triggers/datadrifttrigger.py | 342 +++++++++++++----- .../triggers/drift/decision_policy.py | 67 ++++ .../drift/detection_window/__init__.py | 0 .../triggers/drift/detection_window/amount.py | 68 ++++ .../triggers/drift/detection_window/time_.py | 73 ++++ .../triggers/drift/detection_window/window.py | 31 ++ .../triggers/drift/detector/__init__.py | 0 .../{alibi_detector.py => detector/alibi.py} | 46 ++- .../{drift_detector.py => detector/drift.py} | 8 +- .../evidently.py} | 33 +- modyn/supervisor/internal/triggers/models.py | 5 +- .../internal/triggers/utils/__init__.py | 0 .../internal/triggers/utils/factory.py | 10 + .../internal/triggers/{ => utils}/utils.py | 26 +- .../triggers/drift/detection_window/amount.py | 294 +++++++++++++++ .../triggers/drift/detection_window/time_.py | 291 +++++++++++++++ .../triggers/drift/test_alibi_detector.py | 13 +- .../triggers/drift/test_decision_policy.py | 63 ++++ .../triggers/drift/test_evidently_detector.py | 21 +- .../triggers/test_datadrifttrigger.py | 210 +++++++---- 42 files changed, 2007 insertions(+), 310 deletions(-) create mode 100644 docs/pipeline/triggering/DRIFT_TRIGGERS.md create mode 100644 modyn/config/schema/pipeline/trigger/drift/detection_window/__init__.py create mode 100644 modyn/config/schema/pipeline/trigger/drift/detection_window/amount.py create mode 100644 modyn/config/schema/pipeline/trigger/drift/detection_window/time_.py create mode 100644 modyn/config/schema/pipeline/trigger/drift/detection_window/window.py create mode 100644 modyn/config/schema/pipeline/trigger/drift/metric.py create mode 100644 modyn/config/schema/pipeline/trigger/simple/__init__.py rename modyn/config/schema/pipeline/trigger/{ => simple}/data_amount.py (100%) rename modyn/config/schema/pipeline/trigger/{ => simple}/time.py (100%) create mode 100644 modyn/supervisor/internal/triggers/drift/decision_policy.py create mode 100644 modyn/supervisor/internal/triggers/drift/detection_window/__init__.py create mode 100644 modyn/supervisor/internal/triggers/drift/detection_window/amount.py create mode 100644 modyn/supervisor/internal/triggers/drift/detection_window/time_.py create mode 100644 modyn/supervisor/internal/triggers/drift/detection_window/window.py create 
mode 100644 modyn/supervisor/internal/triggers/drift/detector/__init__.py rename modyn/supervisor/internal/triggers/drift/{alibi_detector.py => detector/alibi.py} (81%) rename modyn/supervisor/internal/triggers/drift/{drift_detector.py => detector/drift.py} (74%) rename modyn/supervisor/internal/triggers/drift/{evidently_detector.py => detector/evidently.py} (78%) create mode 100644 modyn/supervisor/internal/triggers/utils/__init__.py create mode 100644 modyn/supervisor/internal/triggers/utils/factory.py rename modyn/supervisor/internal/triggers/{ => utils}/utils.py (82%) create mode 100644 modyn/tests/supervisor/internal/triggers/drift/detection_window/amount.py create mode 100644 modyn/tests/supervisor/internal/triggers/drift/detection_window/time_.py create mode 100644 modyn/tests/supervisor/internal/triggers/drift/test_decision_policy.py diff --git a/analytics/app/pages/plots/eval_over_time.py b/analytics/app/pages/plots/eval_over_time.py index 1562dbec2..a0d422b7a 100644 --- a/analytics/app/pages/plots/eval_over_time.py +++ b/analytics/app/pages/plots/eval_over_time.py @@ -58,7 +58,7 @@ def gen_figure( # we only want the pipeline performance (composed of the models active periods stitched together) df_adjusted = df_adjusted[df_adjusted[composite_model_variant]] else: - assert df_adjusted["pipeline_ref"].nunique() == 1 + assert df_adjusted["pipeline_ref"].nunique() <= 1 # add the pipeline time series which is the performance of different models stitched together dep. # w.r.t which model was active pipeline_composite_model = df_adjusted[df_adjusted[composite_model_variant]] diff --git a/benchmark/huffpost_kaggle/README.md b/benchmark/huffpost_kaggle/README.md index b0ab04c45..8d80cfd06 100644 --- a/benchmark/huffpost_kaggle/README.md +++ b/benchmark/huffpost_kaggle/README.md @@ -7,7 +7,6 @@ In this directory, you can find the files necessary to run experiments using the The goal is to predict the tag of news given headlines. The dataset contains more than 60k samples collected from 2012 to 2018. Titles belonging to the same year are grouped into the same CSV file and stored together. -Each year is mapped to a year starting from 1/1/1970. There is a total of 42 categories/classes. > Note: The wild-time variant of the huffpost dataset has only 11 classes. 
This is due to the fact that
diff --git a/benchmark/wildtime_benchmarks/example_pipelines/data_drift_trigger/arxiv_datadrift.yaml b/benchmark/wildtime_benchmarks/example_pipelines/data_drift_trigger/arxiv_datadrift.yaml
index 2762dc690..aa5f3e9db 100644
--- a/benchmark/wildtime_benchmarks/example_pipelines/data_drift_trigger/arxiv_datadrift.yaml
+++ b/benchmark/wildtime_benchmarks/example_pipelines/data_drift_trigger/arxiv_datadrift.yaml
@@ -45,7 +45,9 @@ trigger:
   metrics:
     ev_mmd:
       id: EvidentlyModelDriftMetric
-      threshold: 0.7
+      decision_criterion:
+        id: ThresholdDecisionCriterion
+        threshold: 0.7
   aggregation_strategy:
     id: MajorityVote
   selection_strategy:
diff --git a/benchmark/wildtime_benchmarks/example_pipelines/data_drift_trigger/huffpost_datadrift.yaml b/benchmark/wildtime_benchmarks/example_pipelines/data_drift_trigger/huffpost_datadrift.yaml
index 34a6b68e0..6d6bc675c 100644
--- a/benchmark/wildtime_benchmarks/example_pipelines/data_drift_trigger/huffpost_datadrift.yaml
+++ b/benchmark/wildtime_benchmarks/example_pipelines/data_drift_trigger/huffpost_datadrift.yaml
@@ -47,6 +47,10 @@ trigger:
     ev_mmd:
       id: AlibiDetectMmdDriftMetric
       num_permutations: 1000
+      decision_criterion:
+        id: ThresholdDecisionCriterion
+        threshold: 0.7
+
   aggregation_strategy:
     id: MajorityVote
   selection_strategy:
diff --git a/benchmark/wildtime_benchmarks/example_pipelines/data_drift_trigger/yearbook_datadrift.yaml b/benchmark/wildtime_benchmarks/example_pipelines/data_drift_trigger/yearbook_datadrift.yaml
index e15d3022f..266094f97 100644
--- a/benchmark/wildtime_benchmarks/example_pipelines/data_drift_trigger/yearbook_datadrift.yaml
+++ b/benchmark/wildtime_benchmarks/example_pipelines/data_drift_trigger/yearbook_datadrift.yaml
@@ -48,7 +48,9 @@ trigger:
   metrics:
     ev_mmd:
      id: EvidentlyModelDriftMetric
-      threshold: 0.7
+      decision_criterion:
+        id: ThresholdDecisionCriterion
+        threshold: 0.7
   aggregation_strategy:
     id: MajorityVote
   selection_strategy:
diff --git a/docs/pipeline/TRIGGERING.md b/docs/pipeline/TRIGGERING.md
index 389d14ea2..e231c3276 100644
--- a/docs/pipeline/TRIGGERING.md
+++ b/docs/pipeline/TRIGGERING.md
@@ -25,36 +25,36 @@ classDiagram
     class TimeTrigger {
     }

-    class DataAmount {
+    class DataAmountTrigger {
     }
   }

   namespace complex_triggers {
-    class DataDrift {
+    class DataDriftTrigger {
     }

-    class CostBased {
+    class CostBasedTrigger {
     }

     class _BatchedTrigger {
       <<abstract>>
     }

-    class EnsemblePolicy {
+    class EnsembleTrigger {
     }
   }

   Trigger <|-- _BatchedTrigger
-  Trigger <|-- EnsemblePolicy
+  Trigger <|-- EnsembleTrigger
   Trigger <|-- TimeTrigger
-  Trigger <|-- DataAmount
+  Trigger <|-- DataAmountTrigger

-  _BatchedTrigger <|-- DataDrift
-  _BatchedTrigger <|-- CostBased
+  _BatchedTrigger <|-- DataDriftTrigger
+  _BatchedTrigger <|-- CostBasedTrigger

-  EnsemblePolicy *-- "n" Trigger
+  EnsembleTrigger *-- "n" Trigger
 ```
diff --git a/docs/pipeline/triggering/DRIFT_TRIGGERS.md b/docs/pipeline/triggering/DRIFT_TRIGGERS.md
new file mode 100644
index 000000000..70d46f423
--- /dev/null
+++ b/docs/pipeline/triggering/DRIFT_TRIGGERS.md
@@ -0,0 +1,247 @@
+# Drift Triggering
+
+## Overview
+
+As can be seen in [TRIGGERING.md](../TRIGGERING.md), the `DataDriftTrigger` follows the same interface as the simple triggers. The main difference is that it is a complex trigger that requires additional configuration.
+
+It utilizes `DetectionWindows` to select samples for drift detection and a `DriftDetector` to measure the distance between the current and reference data distributions.
The `DriftDecisionPolicy` then derives a binary decision from the distance metric using a specific criterion such as a fixed or dynamic threshold. Hypothesis testing can also be used to reach a decision; however, we found it to be oversensitive and not very useful in practice.
+
+### Main Architecture
+
+```mermaid
+classDiagram
+    class DetectionWindows {
+        <<abstract>>
+    }
+
+    class Trigger {
+        <<abstract>>
+        +void init_trigger(TriggerContext context)
+        +Generator[Triggers] inform(new_data)
+        +void inform_previous_model(int previous_model_id)
+    }
+
+    class TimeTrigger {
+    }
+
+    class DataAmountTrigger {
+    }
+
+    class DriftDetector {
+        <<abstract>>
+    }
+
+    class DataDriftTrigger {
+        +DataDriftTriggerConfig config
+        +DetectionWindows _windows
+        +dict[MetricId, DriftDecisionPolicy] decision_policies
+        +dict[MetricId, DriftDetector] distance_detectors
+
+        +void init_trigger(TriggerContext context)
+        +Generator[triggers] inform(new_data)
+        +void inform_previous_model(int previous_model_id)
+    }
+
+    class DriftDecisionPolicy {
+        <<abstract>>
+    }
+
+    DataDriftTrigger "warmup_trigger" *-- "1" Trigger
+    DataDriftTrigger *-- "1" DataDriftTriggerConfig
+    DataDriftTrigger *-- "1" DetectionWindows
+    DataDriftTrigger *-- "|metrics|" DriftDetector
+    DataDriftTrigger *-- "|metrics|" DriftDecisionPolicy
+
+    Trigger <|-- DataDriftTrigger
+    Trigger <|-- TimeTrigger
+    Trigger <|-- DataAmountTrigger
+```
+
+### DetectionWindows Hierarchy
+
+The `DetectionWindows` class serves as the abstract base for specific windowing strategies like `AmountDetectionWindows` and `TimeDetectionWindows`. These classes are responsible both for storing the actual reference and current data windows and for defining the strategy for updating and managing them.
+
+```mermaid
+classDiagram
+    class DetectionWindows {
+        <<abstract>>
+        +Deque current
+        +Deque current_reservoir
+        +Deque reference
+        +void inform_data(list[tuple[int, int]])
+        +void inform_trigger()
+    }
+
+    class AmountDetectionWindows {
+        +AmountWindowingStrategy config
+        +void inform_data(list[tuple[int, int]])
+        +void inform_trigger()
+    }
+
+    class TimeDetectionWindows {
+        +TimeWindowingStrategy config
+        +void inform_data(list[tuple[int, int]])
+        +void inform_trigger()
+    }
+
+    DetectionWindows <|-- AmountDetectionWindows
+    DetectionWindows <|-- TimeDetectionWindows
+```
+
+### DriftDetector Hierarchy
+
+The `DriftDetector` class is an abstract base class for detectors like `AlibiDriftDetector` and `EvidentlyDriftDetector`, which use different metrics to measure the distance between the current and reference data distributions.
+Both underlying drift detection packages generate their own binary drift decision through hypothesis testing or thresholds. In the `DriftDetector` we only use the distance metric and later derive a binary decision from it with our own threshold-based decision policies; the binary decision produced by the underlying packages is therefore ignored.
+
+The `BaseMetric` class hierarchy is a series of Pydantic configuration classes, while the `Detectors` are the actual business-logic classes that implement the distance calculation.
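+
+To make this concrete, here is a minimal configuration sketch (illustrative only, with made-up values; it assumes the schema defaults defined in the files below):
+
+```python
+# Sketch: two MMD metrics, one with a fixed and one with a dynamic threshold.
+# Class and field names follow the Pydantic schemas in this change.
+from modyn.config.schema.pipeline.trigger.drift.alibi_detect import AlibiDetectMmdDriftMetric
+from modyn.config.schema.pipeline.trigger.drift.metric import (
+    DynamicThresholdCriterion,
+    ThresholdDecisionCriterion,
+)
+
+metrics = {
+    # drift iff the measured MMD distance exceeds the fixed threshold 0.7
+    "mmd_fixed": AlibiDetectMmdDriftMetric(
+        decision_criterion=ThresholdDecisionCriterion(threshold=0.7),
+    ),
+    # drift iff the distance is among the most extreme 5% of the last
+    # 10 observed distances (requires warmup calibration)
+    "mmd_dynamic": AlibiDetectMmdDriftMetric(
+        decision_criterion=DynamicThresholdCriterion(window_size=10, percentile=0.05),
+    ),
+}
+```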
+
+```mermaid
+classDiagram
+    class DriftDetector {
+        <<abstract>>
+        +dict[MetricId, DriftMetric] metrics_config
+        +void init_detector()
+        +dict[MetricId, MetricResult] detect_drift(embeddings_ref, embeddings_cur, bool is_warmup)
+    }
+
+    class AlibiDriftDetector
+    class EvidentlyDriftDetector
+    class BaseMetric {
+        decision_criterion: DecisionCriterion
+    }
+
+    DriftDetector <|-- AlibiDriftDetector
+    DriftDetector <|-- EvidentlyDriftDetector
+
+    AlibiDriftDetector *-- "|metrics|" AlibiDetectDriftMetric
+    EvidentlyDriftDetector *-- "|metrics|" EvidentlyDriftMetric
+
+    class AlibiDetectDriftMetric {
+        <<abstract>>
+    }
+
+    class AlibiDetectMmdDriftMetric {
+    }
+
+    class AlibiDetectCVMDriftMetric {
+    }
+
+    class AlibiDetectKSDriftMetric {
+    }
+
+    BaseMetric <|-- AlibiDetectDriftMetric
+    AlibiDetectDriftMetric <|-- AlibiDetectMmdDriftMetric
+    AlibiDetectDriftMetric <|-- AlibiDetectCVMDriftMetric
+    AlibiDetectDriftMetric <|-- AlibiDetectKSDriftMetric
+
+    class EvidentlyDriftMetric {
+        <<abstract>>
+        int num_pca_component
+    }
+
+    class EvidentlyModelDriftMetric {
+        bool bootstrap = False
+        float quantile_probability = 0.95
+        float threshold = 0.55
+    }
+
+    class EvidentlyRatioDriftMetric {
+        string component_stattest = "wasserstein"
+        float component_stattest_threshold = 0.1
+        float threshold = 0.2
+    }
+
+    class EvidentlySimpleDistanceDriftMetric {
+        string distance_metric = "euclidean"
+        bool bootstrap = False
+        float quantile_probability = 0.95
+        float threshold = 0.2
+    }
+
+    BaseMetric <|-- EvidentlyDriftMetric
+    EvidentlyDriftMetric <|-- EvidentlyModelDriftMetric
+    EvidentlyDriftMetric <|-- EvidentlyRatioDriftMetric
+    EvidentlyDriftMetric <|-- EvidentlySimpleDistanceDriftMetric
+
+```
+
+### DecisionCriterion Hierarchy
+
+The `DecisionCriterion` class is an abstract configuration base class for criteria like `ThresholdDecisionCriterion` and `DynamicThresholdCriterion`, which define how decisions are made based on the drift metrics' distance values.
+
+```mermaid
+classDiagram
+    class DecisionCriterion {
+        <<abstract>>
+    }
+
+    class ThresholdDecisionCriterion {
+        float threshold
+    }
+
+    class DynamicThresholdCriterion {
+        int window_size = 10
+        float percentile = 0.05
+    }
+
+    DecisionCriterion <|-- ThresholdDecisionCriterion
+    DecisionCriterion <|-- DynamicThresholdCriterion
+```
+
+### DriftDecisionPolicy Hierarchy
+
+The `DriftDecisionPolicy` class is an abstract base class for policies like `ThresholdDecisionPolicy`, `DynamicDecisionPolicy`, and `HypothesisTestDecisionPolicy`.
+
+Each decision policy wraps one `DriftMetric` (e.g. MMD, CVM, KS) and one `DecisionCriterion` (e.g. Threshold, DynamicThreshold, HypothesisTest) to make a decision based on the distance metric. For example, it observes the series of distance measurements from its `DriftMetric` and makes a decision after having calibrated on the seen distances.
+
+If a `DecisionPolicy` needs to be calibrated before it can make decisions, we have to run the `DriftTrigger` with a warm-up period. This warm-up period is a fixed number of intervals during which another, simple drift policy makes the triggering decisions while the `DecisionPolicy` is evaluated and calibrated.
+
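+As a sketch (hypothetical values; `DataAmountTriggerConfig` and its `num_samples` field are assumed from the existing simple-trigger schema, which this change only moves to `trigger/simple/`), a calibrated drift trigger could be configured like this:
+
+```python
+# Sketch: drift trigger with a 5-interval warmup delegating to a simple
+# data-amount policy; window sizes use the new reference/current split.
+from modyn.config.schema.pipeline.trigger.drift.config import DataDriftTriggerConfig
+from modyn.config.schema.pipeline.trigger.drift.detection_window import TimeWindowingStrategy
+from modyn.config.schema.pipeline.trigger.simple import DataAmountTriggerConfig
+
+trigger_config = DataDriftTriggerConfig(
+    detection_interval_data_points=1000,
+    windowing_strategy=TimeWindowingStrategy(limit_ref="4w", limit_cur="1w"),
+    warmup_intervals=5,  # the first 5 detections delegate to the warmup policy
+    warmup_policy=DataAmountTriggerConfig(num_samples=5000),  # assumed field name
+    metrics=metrics,  # e.g. the metrics dict from the sketch above
+)
+```
+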
+#### Dynamic Threshold Calibration
+
+Warmup intervals are used to calibrate our drift decision policy. While the drift decision is delegated to a simple substitute policy, we use the data windows from these calibration intervals to generate a sequence of drift distances. After the warmup has finished, we can calibrate a dynamic threshold policy on this series.
+
+To derive these warm-up distances, we don't simply use the reference/current window pairs from every warm-up interval, as one might expect. That approach would correspond to calibrating on the diagonal elements of an offline drift-distance matrix. The diagonal elements have distance values close to zero, as they contain data from the same time frames and, depending on the windowing setup, even the exact same data.
+
+Hence, we need to calibrate on distance-matrix elements other than the diagonal. We chose to generate the distance values at the end of the warmup period. By then, the full lower-left submatrix is potentially computable. We then compute the submatrix column of the warmup-end diagonal element: for that, we need to memorize the first |warmup_intervals| reference windows and compute their distances to the fixed latest current window.
+
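+The percentile logic itself is small. The following is a minimal sketch of the semantics described above (an illustration with a hypothetical class name, not the actual `DynamicDecisionPolicy`, which keeps the same state in `score_observations`):
+
+```python
+from collections import deque
+
+
+class DynamicPercentileSketch:
+    """Trigger when fewer than a `percentile` fraction of the recently
+    observed distances are at least as large as the new distance."""
+
+    def __init__(self, window_size: int = 10, percentile: float = 0.05) -> None:
+        self.percentile = percentile
+        # after warmup, this deque would be pre-filled with the warmup distances
+        self.observations: deque[float] = deque(maxlen=window_size)
+
+    def evaluate_decision(self, distance: float) -> bool:
+        # count recent distances that are at least as extreme as the new one
+        num_more_extreme = sum(d >= distance for d in self.observations)
+        triggered = (
+            len(self.observations) > 0
+            and num_more_extreme / len(self.observations) < self.percentile
+        )
+        self.observations.append(distance)
+        return triggered
+```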
+
+Within one `DataDriftTrigger`, the results of the different `DriftMetric`s' `DriftDecisionPolicy` instances can be aggregated into a final decision using a voting mechanism (see `DataDriftTriggerConfig.aggregation_strategy`).
+
+```mermaid
+classDiagram
+    class DriftDecisionPolicy {
+        <<abstract>>
+        +bool evaluate_decision(float distance)
+    }
+
+    class ThresholdDecisionPolicy {
+        +ThresholdDecisionCriterion config
+        +bool evaluate_decision(float distance)
+    }
+
+    class DynamicDecisionPolicy {
+        +DynamicThresholdCriterion config
+        +Deque~float~ score_observations
+        +bool evaluate_decision(float distance)
+    }
+
+    class HypothesisTestDecisionPolicy {
+        +HypothesisTestCriterion config
+        +bool evaluate_decision(float distance)
+    }
+
+    DriftDecisionPolicy <|-- ThresholdDecisionPolicy
+    DriftDecisionPolicy <|-- DynamicDecisionPolicy
+    DriftDecisionPolicy <|-- HypothesisTestDecisionPolicy
+
+    style HypothesisTestDecisionPolicy fill:#DDDDDD,stroke:#A9A9A9,stroke-width:2px
+```
+
+This architecture provides a flexible framework for implementing various data drift detection mechanisms on top of different detection libraries, each with its own configuration, detection strategy, and decision-making criteria.
diff --git a/modyn/config/schema/pipeline/trigger/__init__.py b/modyn/config/schema/pipeline/trigger/__init__.py
index ff2f6335d..78c0ae572 100644
--- a/modyn/config/schema/pipeline/trigger/__init__.py
+++ b/modyn/config/schema/pipeline/trigger/__init__.py
@@ -4,16 +4,14 @@

 from pydantic import Field

-from .data_amount import *  # noqa
-from .data_amount import DataAmountTriggerConfig
 from .drift import *  # noqa
 from .drift import DataDriftTriggerConfig
 from .ensemble import *  # noqa
 from .ensemble import EnsembleTriggerConfig
-from .time import *  # noqa
-from .time import TimeTriggerConfig
+from .simple import *  # noqa
+from .simple import SimpleTriggerConfig

 TriggerConfig = Annotated[
-    TimeTriggerConfig | DataAmountTriggerConfig | DataDriftTriggerConfig | EnsembleTriggerConfig,
+    SimpleTriggerConfig | DataDriftTriggerConfig | EnsembleTriggerConfig,
     Field(discriminator="id"),
 ]
diff --git a/modyn/config/schema/pipeline/trigger/drift/__init__.py b/modyn/config/schema/pipeline/trigger/drift/__init__.py
index 0671ba7f5..0bbc3fb7a 100644
--- a/modyn/config/schema/pipeline/trigger/drift/__init__.py
+++ b/modyn/config/schema/pipeline/trigger/drift/__init__.py
@@ -1,5 +1,6 @@
 from .aggregation import *  # noqa
 from .alibi_detect import *  # noqa
 from .config import *  # noqa
+from .detection_window import *  # noqa
 from .evidently import *  # noqa
 from .result import *  # noqa
diff --git a/modyn/config/schema/pipeline/trigger/drift/alibi_detect.py b/modyn/config/schema/pipeline/trigger/drift/alibi_detect.py
index 8dcc2829d..882fec679 100644
--- a/modyn/config/schema/pipeline/trigger/drift/alibi_detect.py
+++ b/modyn/config/schema/pipeline/trigger/drift/alibi_detect.py
@@ -1,11 +1,16 @@
+# Note: we don't use the hypothesis testing of the alibi-detect metrics. However, we keep
+# support for it in this wrapper configuration so that offline experiments can still
+# use the hypothesis testing.
+ from typing import Annotated, Literal from pydantic import Field, model_validator from modyn.config.schema.base_model import ModynBaseModel +from modyn.config.schema.pipeline.trigger.drift.metric import BaseMetric -class _AlibiDetectBaseDriftMetric(ModynBaseModel): +class _AlibiDetectBaseDriftMetric(BaseMetric): p_val: float = Field(0.05, description="The p-value threshold for the drift detection.") x_ref_preprocessed: bool = Field(False) @@ -36,7 +41,8 @@ class AlibiDetectMmdDriftMetric(_AlibiDetectBaseDriftMetric, AlibiDetectDeviceMi ), ) kernel: str = Field( - "GaussianRBF", description="The kernel used for distance calculation imported from alibi_detect.utils.pytorch" + "GaussianRBF", + description="The kernel used for distance calculation imported from alibi_detect.utils.pytorch", ) configure_kernel_from_x_ref: bool = Field(True) threshold: float | None = Field( @@ -56,13 +62,14 @@ def validate_threshold_permutations(self) -> "AlibiDetectMmdDriftMetric": + "or threshold comparison for making drift decisions." ) - if self.threshold is None and self.num_permutations is None: - raise ValueError("Please specify either threshold or num_permutations") - return self -class AlibiDetectKSDriftMetric(_AlibiDetectBaseDriftMetric, _AlibiDetectAlternativeMixin, _AlibiDetectCorrectionMixin): +class AlibiDetectKSDriftMetric( + _AlibiDetectBaseDriftMetric, + _AlibiDetectAlternativeMixin, + _AlibiDetectCorrectionMixin, +): id: Literal["AlibiDetectKSDriftMetric"] = Field("AlibiDetectKSDriftMetric") diff --git a/modyn/config/schema/pipeline/trigger/drift/config.py b/modyn/config/schema/pipeline/trigger/drift/config.py index 7dec4847f..fe6d99856 100644 --- a/modyn/config/schema/pipeline/trigger/drift/config.py +++ b/modyn/config/schema/pipeline/trigger/drift/config.py @@ -1,13 +1,15 @@ from __future__ import annotations -from functools import cached_property -from typing import Annotated, ForwardRef, Literal +from typing import Annotated, ForwardRef, Literal, Self -from pydantic import Field +from pydantic import Field, model_validator from modyn.config.schema.base_model import ModynBaseModel -from modyn.const.regex import REGEX_TIME_UNIT -from modyn.utils.utils import SECONDS_PER_UNIT +from modyn.config.schema.pipeline.trigger.drift.detection_window import ( + AmountWindowingStrategy, + DriftWindowingStrategy, +) +from modyn.config.schema.pipeline.trigger.simple import SimpleTriggerConfig from .aggregation import DriftAggregationStrategy, MajorityVoteDriftAggregationStrategy from .alibi_detect import AlibiDetectDriftMetric @@ -21,53 +23,54 @@ ] -class AmountWindowingStrategy(ModynBaseModel): - id: Literal["AmountWindowingStrategy"] = Field("AmountWindowingStrategy") - amount: int = Field(1000, description="How many data points should fit in the window") - - -class TimeWindowingStrategy(ModynBaseModel): - id: Literal["TimeWindowingStrategy"] = Field("TimeWindowingStrategy") - limit: str = Field( - description="Window size as an integer followed by a time unit: s, m, h, d, w, y", - pattern=rf"^\d+{REGEX_TIME_UNIT}$", - ) - - @cached_property - def limit_seconds(self) -> int: - unit = str(self.limit)[-1:] - num = int(str(self.limit)[:-1]) - return num * SECONDS_PER_UNIT[unit] - - -DriftWindowingStrategy = Annotated[ - AmountWindowingStrategy | TimeWindowingStrategy, - Field(discriminator="id"), -] - - class DataDriftTriggerConfig(ModynBaseModel): id: Literal["DataDriftTrigger"] = Field("DataDriftTrigger") detection_interval: __TriggerConfig | None = Field( # type: ignore[valid-type] - None, 
description="The Trigger policy to determine the interval at which drift detection is performed." + None, + description="The Trigger policy to determine the interval at which drift detection is performed.", ) # currently not used - detection_interval_data_points: int = Field( - 1000, description="The number of samples in the interval after which drift detection is performed.", ge=1 + 1000, + description="The number of samples in the interval after which drift detection is performed.", + ge=1, ) + windowing_strategy: DriftWindowingStrategy = Field( - AmountWindowingStrategy(), description="Which windowing strategy to use for current and reference data" + AmountWindowingStrategy(), + description="Which windowing strategy to use for current and reference data", ) - reset_current_window_on_trigger: bool = Field( - False, description="Whether the current window should be reset on trigger or rather be extended." + warmup_intervals: int | None = Field( + None, + description=( + "The number of intervals before starting to use the drift detection. Some " + "`DecisionCriteria` use this to calibrate the threshold. During the warmup, a simpler `warmup_policy` " + "is consulted for the triggering decision." + ), + ) + warmup_policy: SimpleTriggerConfig | None = Field( + None, + description=( + "The policy to use for triggering during the warmup phase of the drift policy. " + "Metrics that don't need calibration can ignore this." + ), ) metrics: dict[str, DriftMetric] = Field( - min_length=1, description="The metrics used for drift detection keyed by a reference." + min_length=1, + description="The metrics used for drift detection keyed by a reference.", ) aggregation_strategy: DriftAggregationStrategy = Field( MajorityVoteDriftAggregationStrategy(), description="The strategy to aggregate the decisions of the individual drift metrics.", ) + + @model_validator(mode="after") + def warmup_policy_requirement(self) -> Self: + """Assert whether the warmup policy is set when a metric needs + calibration.""" + for metric in self.metrics.values(): + if metric.decision_criterion.needs_calibration and self.warmup_policy is None: + raise ValueError("A warmup policy is required for metrics that need calibration.") + return self diff --git a/modyn/config/schema/pipeline/trigger/drift/detection_window/__init__.py b/modyn/config/schema/pipeline/trigger/drift/detection_window/__init__.py new file mode 100644 index 000000000..3b4562fd8 --- /dev/null +++ b/modyn/config/schema/pipeline/trigger/drift/detection_window/__init__.py @@ -0,0 +1,11 @@ +from typing import Annotated + +from pydantic import Field + +from .amount import AmountWindowingStrategy +from .time_ import TimeWindowingStrategy + +DriftWindowingStrategy = Annotated[ + AmountWindowingStrategy | TimeWindowingStrategy, + Field(discriminator="id"), +] diff --git a/modyn/config/schema/pipeline/trigger/drift/detection_window/amount.py b/modyn/config/schema/pipeline/trigger/drift/detection_window/amount.py new file mode 100644 index 000000000..a08f4a33a --- /dev/null +++ b/modyn/config/schema/pipeline/trigger/drift/detection_window/amount.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +from typing import Literal + +from pydantic import Field + +from .window import _BaseWindowingStrategy + + +class AmountWindowingStrategy(_BaseWindowingStrategy): + id: Literal["AmountWindowingStrategy"] = Field("AmountWindowingStrategy") + amount_ref: int = Field( + 1000, + description="How many data points should fit in the reference window", + ge=1, + ) + amount_cur: int 
= Field(1000, description="How many data points should fit in the current window", ge=1) diff --git a/modyn/config/schema/pipeline/trigger/drift/detection_window/time_.py b/modyn/config/schema/pipeline/trigger/drift/detection_window/time_.py new file mode 100644 index 000000000..d61b738e6 --- /dev/null +++ b/modyn/config/schema/pipeline/trigger/drift/detection_window/time_.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +from functools import cached_property +from typing import Literal + +from pydantic import Field + +from modyn.const.regex import REGEX_TIME_UNIT +from modyn.utils.utils import SECONDS_PER_UNIT + +from .window import _BaseWindowingStrategy + + +class TimeWindowingStrategy(_BaseWindowingStrategy): + id: Literal["TimeWindowingStrategy"] = Field("TimeWindowingStrategy") + limit_ref: str = Field( + description="Window size as an integer followed by a time unit: s, m, h, d, w, y", + pattern=rf"^\d+{REGEX_TIME_UNIT}$", + ) + limit_cur: str = Field( + description="Window size as an integer followed by a time unit: s, m, h, d, w, y", + pattern=rf"^\d+{REGEX_TIME_UNIT}$", + ) + + @cached_property + def limit_seconds_ref(self) -> int: + unit = str(self.limit_ref)[-1:] + num = int(str(self.limit_ref)[:-1]) + return num * SECONDS_PER_UNIT[unit] + + @cached_property + def limit_seconds_cur(self) -> int: + unit = str(self.limit_cur)[-1:] + num = int(str(self.limit_cur)[:-1]) + return num * SECONDS_PER_UNIT[unit] diff --git a/modyn/config/schema/pipeline/trigger/drift/detection_window/window.py b/modyn/config/schema/pipeline/trigger/drift/detection_window/window.py new file mode 100644 index 000000000..2b294c1b0 --- /dev/null +++ b/modyn/config/schema/pipeline/trigger/drift/detection_window/window.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +from pydantic import Field + +from modyn.config.schema.base_model import ModynBaseModel + + +class _BaseWindowingStrategy(ModynBaseModel): + allow_overlap: bool = Field( + False, + description=( + "Whether the windows are allowed to overlap." + "If set to False, the current window will be reset after each trigger." 
+ ), + ) diff --git a/modyn/config/schema/pipeline/trigger/drift/evidently.py b/modyn/config/schema/pipeline/trigger/drift/evidently.py index ff8ee579f..1a39e6da1 100644 --- a/modyn/config/schema/pipeline/trigger/drift/evidently.py +++ b/modyn/config/schema/pipeline/trigger/drift/evidently.py @@ -2,10 +2,10 @@ from pydantic import Field -from modyn.config.schema.base_model import ModynBaseModel +from modyn.config.schema.pipeline.trigger.drift.metric import BaseMetric -class _EvidentlyBaseDriftMetric(ModynBaseModel): +class _EvidentlyBaseDriftMetric(BaseMetric): num_pca_component: int | None = Field(None) @@ -18,14 +18,20 @@ class EvidentlyModelDriftMetric(_EvidentlyBaseDriftMetric): class EvidentlyRatioDriftMetric(_EvidentlyBaseDriftMetric): id: Literal["EvidentlyRatioDriftMetric"] = Field("EvidentlyRatioDriftMetric") - component_stattest: str = Field("wasserstein", description="The statistical test used to compare the components.") + component_stattest: str = Field( + "wasserstein", + description="The statistical test used to compare the components.", + ) component_stattest_threshold: float = Field(0.1) threshold: float = Field(0.2) class EvidentlySimpleDistanceDriftMetric(_EvidentlyBaseDriftMetric): id: Literal["EvidentlySimpleDistanceDriftMetric"] = Field("EvidentlySimpleDistanceDriftMetric") - distance_metric: str = Field("euclidean", description="The distance metric used for the distance calculation.") + distance_metric: str = Field( + "euclidean", + description="The distance metric used for the distance calculation.", + ) bootstrap: bool = Field(False) quantile_probability: float = 0.95 threshold: float = Field(0.2) diff --git a/modyn/config/schema/pipeline/trigger/drift/metric.py b/modyn/config/schema/pipeline/trigger/drift/metric.py new file mode 100644 index 000000000..edf211093 --- /dev/null +++ b/modyn/config/schema/pipeline/trigger/drift/metric.py @@ -0,0 +1,31 @@ +from typing import Annotated, Literal + +from pydantic import Field + +from modyn.config.schema.base_model import ModynBaseModel + + +class ThresholdDecisionCriterion(ModynBaseModel): + id: Literal["ThresholdDecisionCriterion"] = "ThresholdDecisionCriterion" + threshold: float + needs_calibration: Literal[False] = Field(False) + + +class DynamicThresholdCriterion(ModynBaseModel): + id: Literal["DynamicThresholdCriterion"] = "DynamicThresholdCriterion" + window_size: int = Field(10) + percentile: float = Field( + 0.05, + description="The percentile that a threshold has to be in to trigger a drift event.", + ) + needs_calibration: Literal[True] = Field(True) + + +DecisionCriterion = Annotated[ + ThresholdDecisionCriterion | DynamicThresholdCriterion, + Field(discriminator="id"), +] + + +class BaseMetric(ModynBaseModel): + decision_criterion: DecisionCriterion diff --git a/modyn/config/schema/pipeline/trigger/drift/result.py b/modyn/config/schema/pipeline/trigger/drift/result.py index c628baa33..e91a5e374 100644 --- a/modyn/config/schema/pipeline/trigger/drift/result.py +++ b/modyn/config/schema/pipeline/trigger/drift/result.py @@ -6,6 +6,6 @@ class MetricResult(ModynBaseModel): metric_id: str = Field(description="The id of the metric used for drift detection.") is_drift: bool - p_val: float | list[float] | None = None - distance: float | list[float] + p_val: float | None = None + distance: float threshold: float | None = None diff --git a/modyn/config/schema/pipeline/trigger/simple/__init__.py b/modyn/config/schema/pipeline/trigger/simple/__init__.py new file mode 100644 index 000000000..98faa21e2 --- /dev/null +++ 
b/modyn/config/schema/pipeline/trigger/simple/__init__.py @@ -0,0 +1,11 @@ +from typing import Annotated + +from pydantic import Field + +from .data_amount import DataAmountTriggerConfig # noqa +from .time import TimeTriggerConfig # noqa + +SimpleTriggerConfig = Annotated[ + TimeTriggerConfig | DataAmountTriggerConfig, + Field(discriminator="id"), +] diff --git a/modyn/config/schema/pipeline/trigger/data_amount.py b/modyn/config/schema/pipeline/trigger/simple/data_amount.py similarity index 100% rename from modyn/config/schema/pipeline/trigger/data_amount.py rename to modyn/config/schema/pipeline/trigger/simple/data_amount.py diff --git a/modyn/config/schema/pipeline/trigger/time.py b/modyn/config/schema/pipeline/trigger/simple/time.py similarity index 100% rename from modyn/config/schema/pipeline/trigger/time.py rename to modyn/config/schema/pipeline/trigger/simple/time.py diff --git a/modyn/supervisor/internal/pipeline_executor/pipeline_executor.py b/modyn/supervisor/internal/pipeline_executor/pipeline_executor.py index 629c9e0e8..163bc6be8 100644 --- a/modyn/supervisor/internal/pipeline_executor/pipeline_executor.py +++ b/modyn/supervisor/internal/pipeline_executor/pipeline_executor.py @@ -13,10 +13,22 @@ import pandas as pd from typing_extensions import ParamSpec -from modyn.supervisor.internal.grpc.enums import CounterAction, IdType, MsgType, PipelineStage -from modyn.supervisor.internal.grpc.template_msg import counter_submsg, dataset_submsg, id_submsg, pipeline_stage_msg +from modyn.supervisor.internal.grpc.enums import ( + CounterAction, + IdType, + MsgType, + PipelineStage, +) +from modyn.supervisor.internal.grpc.template_msg import ( + counter_submsg, + dataset_submsg, + id_submsg, + pipeline_stage_msg, +) from modyn.supervisor.internal.grpc_handler import GRPCHandler -from modyn.supervisor.internal.pipeline_executor.evaluation_executor import EvaluationExecutor +from modyn.supervisor.internal.pipeline_executor.evaluation_executor import ( + EvaluationExecutor, +) from modyn.supervisor.internal.pipeline_executor.models import ( ConfigLogs, EvaluateTriggerInfo, @@ -34,8 +46,8 @@ ) from modyn.supervisor.internal.triggers import Trigger from modyn.supervisor.internal.triggers.trigger import TriggerContext +from modyn.supervisor.internal.triggers.utils.factory import instantiate_trigger from modyn.supervisor.internal.utils import TrainingStatusReporter -from modyn.utils import dynamic_module_import from modyn.utils.timer import timed_generator from modyn.utils.utils import current_time_micros @@ -81,7 +93,11 @@ def wrapper_outer( # type: ignore[no-untyped-def] func: Callable[Concatenate[PipelineExecutor, ExecutionState, StageLog, P], R], ): def wrapper( - self: PipelineExecutor, state: ExecutionState, logs: PipelineLogs, *args: P.args, **kwargs: P.kwargs + self: PipelineExecutor, + state: ExecutionState, + logs: PipelineLogs, + *args: P.args, + **kwargs: P.kwargs, ) -> R: """Measures the time for each stage and logs the pipeline state.""" @@ -183,7 +199,11 @@ def __init__(self, options: PipelineExecutionParams) -> None: self.trigger = self._setup_trigger() self.grpc = GRPCHandler(self.state.modyn_config.model_dump(by_alias=True)) self.eval_executor = EvaluationExecutor( - options.pipeline_id, options.pipeline_logdir, options.modyn_config, options.pipeline_config, self.grpc + options.pipeline_id, + options.pipeline_logdir, + options.modyn_config, + options.pipeline_config, + self.grpc, ) def run(self) -> None: @@ -298,7 +318,11 @@ def _fetch_new_data(self, s: ExecutionState, log: 
StageLog) -> int: The number of triggers that occurred during the processing of the new data. """ s.pipeline_status_queue.put( - pipeline_stage_msg(PipelineStage.FETCH_NEW_DATA, MsgType.DATASET, dataset_submsg(s.dataset_id)) + pipeline_stage_msg( + PipelineStage.FETCH_NEW_DATA, + MsgType.DATASET, + dataset_submsg(s.dataset_id), + ) ) num_samples = 0 @@ -349,10 +373,16 @@ def _wait_for_new_data(self, s: ExecutionState, log: StageLog) -> None: # Process new data @pipeline_stage( - PipelineStage.PROCESS_NEW_DATA, parent=[PipelineStage.REPLAY_DATA, PipelineStage.FETCH_NEW_DATA], track=True + PipelineStage.PROCESS_NEW_DATA, + parent=[PipelineStage.REPLAY_DATA, PipelineStage.FETCH_NEW_DATA], + track=True, ) def _process_new_data( - self, s: ExecutionState, log: StageLog, new_data: list[tuple[int, int, int]], fetch_time: int + self, + s: ExecutionState, + log: StageLog, + new_data: list[tuple[int, int, int]], + fetch_time: int, ) -> list[int]: """Handle new data during experiments or online pipeline serving. @@ -378,7 +408,10 @@ def _process_new_data( pipeline_stage_msg( PipelineStage.PROCESS_NEW_DATA, MsgType.COUNTER, - counter_submsg(CounterAction.CREATE, {"title": "Processing New Samples", "new_data_len": new_data_len}), + counter_submsg( + CounterAction.CREATE, + {"title": "Processing New Samples", "new_data_len": new_data_len}, + ), ) ) @@ -404,18 +437,30 @@ def _process_new_data( break s.pipeline_status_queue.put( - pipeline_stage_msg(PipelineStage.NEW_DATA_HANDLED, MsgType.COUNTER, counter_submsg(CounterAction.CLOSE)) + pipeline_stage_msg( + PipelineStage.NEW_DATA_HANDLED, + MsgType.COUNTER, + counter_submsg(CounterAction.CLOSE), + ) ) # log extra information - log.info = ProcessNewDataInfo(fetch_time=fetch_time, num_samples=new_data_len, trigger_indexes=trigger_indexes) + log.info = ProcessNewDataInfo( + fetch_time=fetch_time, + num_samples=new_data_len, + trigger_indexes=trigger_indexes, + ) self.logs.materialize(s.log_directory, mode="increment") return trigger_indexes # Process new data BATCH - @pipeline_stage(PipelineStage.PROCESS_NEW_DATA_BATCH, parent=PipelineStage.PROCESS_NEW_DATA, track=True) + @pipeline_stage( + PipelineStage.PROCESS_NEW_DATA_BATCH, + parent=PipelineStage.PROCESS_NEW_DATA, + track=True, + ) def _process_new_data_batch(self, s: ExecutionState, log: StageLog, batch: list[tuple[int, int, int]]) -> list[int]: """Process new data in batches and evaluate trigger policies in batches. 
@@ -440,7 +485,11 @@ def _process_new_data_batch(self, s: ExecutionState, log: StageLog, batch: list[ return handled_triggers - @pipeline_stage(PipelineStage.EVALUATE_TRIGGER_POLICY, parent=PipelineStage.PROCESS_NEW_DATA_BATCH, track=True) + @pipeline_stage( + PipelineStage.EVALUATE_TRIGGER_POLICY, + parent=PipelineStage.PROCESS_NEW_DATA_BATCH, + track=True, + ) def _evaluate_trigger_policy( self, s: ExecutionState, log: StageLog, batch: list[tuple[int, int, int]] ) -> Generator[int, None, None]: @@ -463,7 +512,11 @@ def _evaluate_trigger_policy( assert log.info logger.info(f"There were {len(log.info.trigger_indexes)} triggers in this batch.") - @pipeline_stage(PipelineStage.HANDLE_TRIGGERS, parent=PipelineStage.PROCESS_NEW_DATA_BATCH, track=True) + @pipeline_stage( + PipelineStage.HANDLE_TRIGGERS, + parent=PipelineStage.PROCESS_NEW_DATA_BATCH, + track=True, + ) def _handle_triggers( self, s: ExecutionState, @@ -506,10 +559,16 @@ def _handle_triggers( return trigger_indexes @pipeline_stage( - PipelineStage.INFORM_SELECTOR_REMAINING_DATA, parent=PipelineStage.PROCESS_NEW_DATA_BATCH, track=True + PipelineStage.INFORM_SELECTOR_REMAINING_DATA, + parent=PipelineStage.PROCESS_NEW_DATA_BATCH, + track=True, ) def _inform_selector_remaining_data( - self, s: ExecutionState, log: StageLog, batch: list[tuple[int, int, int]], trigger_indexes: list[int] + self, + s: ExecutionState, + log: StageLog, + batch: list[tuple[int, int, int]], + trigger_indexes: list[int], ) -> None: """Inform selector about remaining data.""" @@ -528,21 +587,33 @@ def _inform_selector_remaining_data( selector_log = self.grpc.inform_selector(s.pipeline_id, s.remaining_data) if s.remaining_data_range is not None: # extend the range from last time - s.remaining_data_range = (s.remaining_data_range[0], s.remaining_data[-1][1]) + s.remaining_data_range = ( + s.remaining_data_range[0], + s.remaining_data[-1][1], + ) else: - s.remaining_data_range = (s.remaining_data[0][1], s.remaining_data[-1][1]) + s.remaining_data_range = ( + s.remaining_data[0][1], + s.remaining_data[-1][1], + ) else: selector_log = None s.remaining_data_range = None # add log data log.info = SelectorInformInfo( - selector_log=selector_log, remaining_data=len(s.remaining_data) > 0, trigger_indexes=trigger_indexes + selector_log=selector_log, + remaining_data=len(s.remaining_data) > 0, + trigger_indexes=trigger_indexes, ) # Handle trigger within batch - @pipeline_stage(PipelineStage.HANDLE_SINGLE_TRIGGER, parent=PipelineStage.HANDLE_TRIGGERS, track=True) + @pipeline_stage( + PipelineStage.HANDLE_SINGLE_TRIGGER, + parent=PipelineStage.HANDLE_TRIGGERS, + track=True, + ) def _handle_single_trigger( self, s: ExecutionState, @@ -574,7 +645,13 @@ def _handle_single_trigger( if s.pipeline_config.evaluation: self._evaluate_and_store_results( - s, self.logs, trigger_id, training_id, model_id, first_timestamp, last_timestamp + s, + self.logs, + trigger_id, + training_id, + model_id, + first_timestamp, + last_timestamp, ) else: @@ -589,7 +666,11 @@ def _handle_single_trigger( last_timestamp=last_timestamp, ) - @pipeline_stage(PipelineStage.INFORM_SELECTOR_ABOUT_TRIGGER, parent=PipelineStage.HANDLE_SINGLE_TRIGGER, track=True) + @pipeline_stage( + PipelineStage.INFORM_SELECTOR_ABOUT_TRIGGER, + parent=PipelineStage.HANDLE_SINGLE_TRIGGER, + track=True, + ) def _inform_selector_about_trigger( self, s: ExecutionState, @@ -623,7 +704,11 @@ def _inform_selector_about_trigger( # Training - @pipeline_stage(PipelineStage.TRAIN_AND_STORE_MODEL, parent=PipelineStage.HANDLE_SINGLE_TRIGGER, 
track=True) + @pipeline_stage( + PipelineStage.TRAIN_AND_STORE_MODEL, + parent=PipelineStage.HANDLE_SINGLE_TRIGGER, + track=True, + ) def _train_and_store_model(self, s: ExecutionState, log: StageLog, trigger_id: int) -> tuple[int, int]: """Train a new model on batch data and store it.""" @@ -634,7 +719,11 @@ def _train_and_store_model(self, s: ExecutionState, log: StageLog, trigger_id: i s.trained_models.append(model_id) s.pipeline_status_queue.put( - pipeline_stage_msg(PipelineStage.HANDLE_TRIGGERS, MsgType.ID, id_submsg(IdType.TRIGGER, trigger_id)) + pipeline_stage_msg( + PipelineStage.HANDLE_TRIGGERS, + MsgType.ID, + id_submsg(IdType.TRIGGER, trigger_id), + ) ) return training_id, model_id @@ -665,7 +754,11 @@ def _train(self, s: ExecutionState, log: StageLog, trigger_id: int) -> int: total_samples = self.grpc.get_number_of_samples(s.pipeline_id, trigger_id) status_bar_scale = self.grpc.get_status_bar_scale(s.pipeline_id) training_reporter = TrainingStatusReporter( - self.state.training_status_queue, trigger_id, s.current_training_id, total_samples, status_bar_scale + self.state.training_status_queue, + trigger_id, + s.current_training_id, + total_samples, + status_bar_scale, ) trainer_log = self.grpc.wait_for_training_completion(s.current_training_id, training_reporter) @@ -679,20 +772,35 @@ def _train(self, s: ExecutionState, log: StageLog, trigger_id: int) -> int: return s.current_training_id - @pipeline_stage(PipelineStage.TRAINING_COMPLETED, parent=PipelineStage.TRAIN_AND_STORE_MODEL, track=False) + @pipeline_stage( + PipelineStage.TRAINING_COMPLETED, + parent=PipelineStage.TRAIN_AND_STORE_MODEL, + track=False, + ) def _training_completed(self, s: ExecutionState, log: StageLog, training_id: int) -> None: s.pipeline_status_queue.put( pipeline_stage_msg( - PipelineStage.TRAINING_COMPLETED, MsgType.ID, id_submsg(IdType.TRAINING, training_id), True + PipelineStage.TRAINING_COMPLETED, + MsgType.ID, + id_submsg(IdType.TRAINING, training_id), + True, ) ) logger.info(f"Training {training_id} completed") - @pipeline_stage(PipelineStage.STORE_TRAINED_MODEL, parent=PipelineStage.TRAIN_AND_STORE_MODEL, track=True) + @pipeline_stage( + PipelineStage.STORE_TRAINED_MODEL, + parent=PipelineStage.TRAIN_AND_STORE_MODEL, + track=True, + ) def _store_trained_model(self, s: ExecutionState, log: StageLog, trigger_id: int, training_id: int) -> int: """Stores a trained model and returns the model id.""" s.pipeline_status_queue.put( - pipeline_stage_msg(PipelineStage.STORE_TRAINED_MODEL, MsgType.ID, id_submsg(IdType.TRIGGER, trigger_id)) + pipeline_stage_msg( + PipelineStage.STORE_TRAINED_MODEL, + MsgType.ID, + id_submsg(IdType.TRIGGER, trigger_id), + ) ) model_id = self.grpc.store_trained_model(training_id) @@ -721,7 +829,11 @@ def _evaluate_and_store_results( ) -> None: """Evaluate the trained model and store the results.""" s.pipeline_status_queue.put( - pipeline_stage_msg(PipelineStage.EVALUATE, MsgType.ID, id_submsg(IdType.TRIGGER, trigger_id)) + pipeline_stage_msg( + PipelineStage.EVALUATE, + MsgType.ID, + id_submsg(IdType.TRIGGER, trigger_id), + ) ) logs = self.eval_executor.run_pipeline_evaluations( log, @@ -740,7 +852,12 @@ def _done(self, s: ExecutionState, log: StageLog) -> None: s.pipeline_status_queue.put(pipeline_stage_msg(PipelineStage.DONE, MsgType.GENERAL)) self.logs.pipeline_stages = _pipeline_stage_parents # now includes chronology info - @pipeline_stage(PipelineStage.POST_EVALUATION_CHECKPOINT, parent=PipelineStage.MAIN, log=False, track=False) + @pipeline_stage( + 
PipelineStage.POST_EVALUATION_CHECKPOINT, + parent=PipelineStage.MAIN, + log=False, + track=False, + ) def _post_pipeline_evaluation_checkpoint(self, s: ExecutionState, log: StageLog) -> None: """Stores evaluation relevant information so that the evaluator can be started on this pipeline run again.""" @@ -770,13 +887,7 @@ def _exit(self, s: ExecutionState, log: StageLog) -> None: # setup def _setup_trigger(self) -> Trigger: - trigger_id = self.state.pipeline_config.trigger.id - trigger_config = self.state.pipeline_config.trigger - - trigger_module = dynamic_module_import("modyn.supervisor.internal.triggers") - trigger: Trigger = getattr(trigger_module, trigger_id)(trigger_config) - assert trigger is not None, "Error during trigger initialization" - + trigger = instantiate_trigger(self.state.pipeline_config.trigger.id, self.state.pipeline_config.trigger) trigger.init_trigger( TriggerContext( pipeline_id=self.state.pipeline_id, @@ -794,7 +905,9 @@ def _setup_trigger(self) -> Trigger: @staticmethod def _get_trigger_timespan( - s: ExecutionState, is_first_trigger_data: bool, trigger_data: list[tuple[int, int, int]] + s: ExecutionState, + is_first_trigger_data: bool, + trigger_data: list[tuple[int, int, int]], ) -> tuple[int, int]: if is_first_trigger_data: # now it is the first trigger in this batch. Triggering_data can be empty. diff --git a/modyn/supervisor/internal/triggers/datadrifttrigger.py b/modyn/supervisor/internal/triggers/datadrifttrigger.py index 8ba88fea7..7675c23bb 100644 --- a/modyn/supervisor/internal/triggers/datadrifttrigger.py +++ b/modyn/supervisor/internal/triggers/datadrifttrigger.py @@ -1,19 +1,47 @@ from __future__ import annotations +import gc import logging from collections.abc import Generator from modyn.config.schema.pipeline import DataDriftTriggerConfig +from modyn.config.schema.pipeline.trigger.drift.detection_window import ( + AmountWindowingStrategy, + DriftWindowingStrategy, + TimeWindowingStrategy, +) +from modyn.config.schema.pipeline.trigger.drift.metric import ThresholdDecisionCriterion from modyn.config.schema.pipeline.trigger.drift.result import MetricResult -from modyn.supervisor.internal.triggers.drift.alibi_detector import AlibiDriftDetector -from modyn.supervisor.internal.triggers.drift.evidently_detector import EvidentlyDriftDetector -from modyn.supervisor.internal.triggers.embedding_encoder_utils import EmbeddingEncoder, EmbeddingEncoderDownloader - -# pylint: disable-next=no-name-in-module -from modyn.supervisor.internal.triggers.models import DriftTriggerEvalLog, TriggerPolicyEvaluationLog +from modyn.supervisor.internal.triggers.drift.decision_policy import ( + DriftDecisionPolicy, + DynamicDecisionPolicy, + ThresholdDecisionPolicy, +) +from modyn.supervisor.internal.triggers.drift.detection_window.amount import ( + AmountDetectionWindows, +) +from modyn.supervisor.internal.triggers.drift.detection_window.time_ import ( + TimeDetectionWindows, +) +from modyn.supervisor.internal.triggers.drift.detection_window.window import ( + DetectionWindows, +) +from modyn.supervisor.internal.triggers.drift.detector.alibi import AlibiDriftDetector +from modyn.supervisor.internal.triggers.drift.detector.evidently import ( + EvidentlyDriftDetector, +) +from modyn.supervisor.internal.triggers.embedding_encoder_utils import ( + EmbeddingEncoder, + EmbeddingEncoderDownloader, +) +from modyn.supervisor.internal.triggers.models import ( + DriftTriggerEvalLog, + TriggerPolicyEvaluationLog, +) from modyn.supervisor.internal.triggers.trigger import Trigger, 
TriggerContext from modyn.supervisor.internal.triggers.trigger_datasets import DataLoaderInfo -from modyn.supervisor.internal.triggers.utils import ( +from modyn.supervisor.internal.triggers.utils.factory import instantiate_trigger +from modyn.supervisor.internal.triggers.utils.utils import ( convert_tensor_to_df, get_embeddings, prepare_trigger_dataloader_fixed_keys, @@ -36,63 +64,39 @@ def __init__(self, config: DataDriftTriggerConfig): self.encoder_downloader: EmbeddingEncoderDownloader | None = None self.embedding_encoder: EmbeddingEncoder | None = None - self._reference_window: list[tuple[int, int]] = [] - self._current_window: list[tuple[int, int]] = [] - self._total_items_in_current_window = 0 + self._sample_left_until_detection = ( + config.detection_interval_data_points + ) # allows to detect drift in a fixed interval + self._windows = _setup_detection_windows(config.windowing_strategy) self._triggered_once = False self.evidently_detector = EvidentlyDriftDetector(config.metrics) self.alibi_detector = AlibiDriftDetector(config.metrics) - def init_trigger(self, context: TriggerContext) -> None: - self.context = context - self._init_dataloader_info() - self._init_encoder_downloader() + # Every decision policy wraps one metric and is responsible for making decisions based on the metric's results + # and the metric's range of distance values + self.decision_policies = _setup_decision_policies(config) - def _update_curr_window(self, new_data: list[tuple[int, int]]) -> None: - self._current_window.extend(new_data) - self._total_items_in_current_window += len(new_data) - - if self.config.windowing_strategy.id == "AmountWindowingStrategy": - if len(self._current_window) > self.config.windowing_strategy.amount: - items_to_remove = len(self._current_window) - self.config.windowing_strategy.amount - self._current_window = self._current_window[items_to_remove:] - elif self.config.windowing_strategy.id == "TimeWindowingStrategy": - highest_timestamp = new_data[-1][1] - cutoff = highest_timestamp - self.config.windowing_strategy.limit_seconds - self._current_window = [(key, timestamp) for key, timestamp in self._current_window if timestamp >= cutoff] - else: - raise NotImplementedError(f"{self.config.windowing_strategy.id} is not implemented!") + # [WARMUP CONFIGURATION] + self.warmup_completed = config.warmup_policy is None - def _handle_drift_result( - self, - triggered: bool, - trigger_idx: int, - drift_results: dict[str, MetricResult], - log: TriggerPolicyEvaluationLog | None = None, - ) -> Generator[int, None, None]: - drift_eval_log = DriftTriggerEvalLog( - detection_idx_start=self._current_window[0][1], - detection_idx_end=self._current_window[-1][1], - triggered=triggered, - trigger_index=-1, - drift_results=drift_results, + # warmup policy (used as drop in replacement for the yet uncalibrated drift policy) + self.warmup_trigger = ( + instantiate_trigger(config.warmup_policy.id, config.warmup_policy) if config.warmup_policy else None ) - if triggered: - self._reference_window = self._current_window # Current assumption: same windowing strategy on both - self._current_window = [] if self.config.reset_current_window_on_trigger else self._current_window - self._total_items_in_current_window = ( - 0 if self.config.reset_current_window_on_trigger else self._total_items_in_current_window - ) + # list of reference windows for each warmup interval + self.warmup_intervals: list[list[tuple[int, int]]] = [] - if log: - log.evaluations.append(drift_eval_log) - - yield trigger_idx + def 
init_trigger(self, context: TriggerContext) -> None: + self.context = context + self._init_dataloader_info() + self._init_encoder_downloader() def inform( - self, new_data: list[tuple[int, int, int]], log: TriggerPolicyEvaluationLog | None = None + self, + new_data: list[tuple[int, int, int]], + log: TriggerPolicyEvaluationLog | None = None, ) -> Generator[int, None, None]: """Analyzes a batch of new data to determine if data drift has occurred and triggers retraining if necessary. @@ -106,7 +110,7 @@ def inform( The method works as follows: 1. Extract keys and timestamps from the incoming data points. - 2. Determine the offset, which is the number of data points in the current window that have not yet contributed + 2. Use the offset, which is the number of data points in the current window that have not yet contributed to a drift detection. 3. If the sum of the offset and the length of the new data is less than the detection interval, update the current window with the new data and return without performing drift detection. @@ -127,50 +131,159 @@ def inform( The index of the data point that triggered the drift detection. This is used to identify the point in the data stream where the model's performance may have started to degrade due to drift. """ + # pylint: disable=too-many-nested-blocks new_key_ts = [(key, timestamp) for key, timestamp, _ in new_data] - detect_interval = self.config.detection_interval_data_points - offset = self._total_items_in_current_window % detect_interval - - if offset + len(new_key_ts) < detect_interval: - # No detection in this trigger - self._update_curr_window(new_key_ts) - return - - # At least one detection, fill up window up to that detection - self._update_curr_window(new_key_ts[: detect_interval - offset]) - new_key_ts = new_key_ts[detect_interval - offset :] - trigger_idx = detect_interval - offset - 1 # If we trigger, it will be on this index - - if not self._triggered_once: - # If we've never triggered before, always trigger - self._triggered_once = True - triggered = True - drift_results: dict[str, MetricResult] = {} - else: - # Run the detection - triggered, drift_results = self._run_detection() - - yield from self._handle_drift_result(triggered, trigger_idx, drift_results, log=log) - # Go through remaining data in new data in batches of `detect_interval` - for i in range(0, len(new_key_ts), detect_interval): - batch = new_key_ts[i : i + detect_interval] - trigger_idx += detect_interval - self._update_curr_window(batch) + # index of the first unprocessed data point in the batch + processing_head_in_batch = 0 - if len(batch) == detect_interval: - # Regular batch, in this case run detection - triggered, drift_results = self._run_detection() - yield from self._handle_drift_result(triggered, trigger_idx, drift_results, log=log) + # Go through remaining data in new data in batches of `detect_interval` + while True: + if self._sample_left_until_detection - len(new_key_ts) > 0: + # No detection in this trigger because of too few data points to fill detection interval + self._windows.inform_data(new_key_ts) # update current window + self._sample_left_until_detection -= len(new_key_ts) + return + + # At least one detection, fill up window up to that detection + next_detection_interval = new_key_ts[: self._sample_left_until_detection] + self._windows.inform_data(next_detection_interval) + + # Update the remaining data + processing_head_in_batch += len(next_detection_interval) + new_key_ts = new_key_ts[len(next_detection_interval) :] + + # Reset for next 
detection + self._sample_left_until_detection = self.config.detection_interval_data_points + + if (not self._triggered_once) or ( + not self.warmup_completed and len(self.warmup_intervals) < (self.config.warmup_intervals or 0) + ): + # Warmup trigger, evaluate warmup policy while storing the reference window + # for later calibration with the drift policy. + + # delegate to the warmup policy + delegated_trigger_results = ( + next( + self.warmup_trigger.inform([(idx, time, 0) for (idx, time) in next_detection_interval]), + None, + ) + if self.warmup_trigger + else None + ) + triggered = ( + delegated_trigger_results is not None + if self._triggered_once + else True # first candidate always triggers + ) + + drift_results: dict[str, MetricResult] = {} + if len(self.warmup_intervals) < (self.config.warmup_intervals or 0): + # for the first detection the reference window is empty, therefore adding the current window + self.warmup_intervals.append( + list(self._windows.reference if self._triggered_once else self._windows.current) + ) + + self._triggered_once = True + + else: + # Run the detection + + # if this is the first non warmup detection, we inform the metrics that use decision criteria + # with calibration requirements about the warmup intervals so they can calibrate their thresholds + if len(self.warmup_intervals) > 0 or not self.warmup_completed: + # we can ignore the results as the decision criteria will keep track of the warmup results + # internally + if self._any_metric_needs_calibration(): + for warmup_interval in self.warmup_intervals: + # we generate the calibration with different reference windows, the latest model and + # the current window + _warmup_triggered, _warmup_results = self._run_detection( + warmup_interval, + list(self._windows.current), + is_warmup=True, + ) + if log: + warmup_log = DriftTriggerEvalLog( + detection_interval=( + self._windows.current[0][1], + self._windows.current[-1][1], + ), + reference_interval=( + self._windows.reference[0][1], + self._windows.reference[-1][1], + ), + triggered=_warmup_triggered, + trigger_index=-1, + drift_results=_warmup_results, + ) + log.evaluations.append(warmup_log) + + # free the memory, but keep filled + self.warmup_completed = True + self.warmup_intervals = [] + gc.collect() + + triggered, drift_results = self._run_detection( + list(self._windows.reference), + list(self._windows.current), + is_warmup=False, + ) + + trigger_idx = processing_head_in_batch - 1 + yield from self._handle_drift_result( + triggered, + trigger_idx, + drift_results, + warmup=not self.warmup_completed, + log=log, + ) def inform_previous_model(self, previous_model_id: int) -> None: self.previous_model_id = previous_model_id self.model_updated = True - # --------------------------------------------------- INTERNAL --------------------------------------------------- # + # ---------------------------------------------------------------------------------------------------------------- # + # INTERNAL # + # ---------------------------------------------------------------------------------------------------------------- # + + def _handle_drift_result( + self, + triggered: bool, + trigger_idx: int, + drift_results: dict[str, MetricResult], + warmup: bool = False, + log: TriggerPolicyEvaluationLog | None = None, + ) -> Generator[int, None, None]: + drift_eval_log = DriftTriggerEvalLog( + detection_interval=( + self._windows.current[0][1], + self._windows.current[-1][1], + ), + reference_interval=( + (self._windows.reference[0][1], 
self._windows.reference[-1][1]) if self._windows.reference else (-1, -1) + ), + triggered=triggered, + trigger_index=-1, + drift_results=drift_results, + ) + if log: + log.evaluations.append(drift_eval_log) + + if triggered or warmup: + # during the warmup phase we always want to reset the windows as if we detected drift + self._windows.inform_trigger() + + if triggered: + yield trigger_idx - def _run_detection(self) -> tuple[bool, dict[str, MetricResult]]: + def _run_detection( + self, + reference: list[tuple[int, int]], + current: list[tuple[int, int]], + is_warmup: bool, + ) -> tuple[bool, dict[str, MetricResult]]: """Compare current data against reference data. current data: all untriggered samples in the sliding window in inform(). @@ -182,16 +295,14 @@ def _run_detection(self) -> tuple[bool, dict[str, MetricResult]]: assert self.dataloader_info is not None assert self.encoder_downloader is not None assert self.context and self.context.pipeline_config is not None - assert len(self._reference_window) > 0 - assert len(self._current_window) > 0 + assert len(reference) > 0 + assert len(current) > 0 reference_dataloader = prepare_trigger_dataloader_fixed_keys( - self.dataloader_info, [key for key, _ in self._reference_window] + self.dataloader_info, [key for key, _ in reference] ) - current_dataloader = prepare_trigger_dataloader_fixed_keys( - self.dataloader_info, [key for key, _ in self._current_window] - ) + current_dataloader = prepare_trigger_dataloader_fixed_keys(self.dataloader_info, [key for key, _ in current]) # Download previous model as embedding encoder # TODO(417) Support custom model as embedding encoder @@ -203,16 +314,28 @@ def _run_detection(self) -> tuple[bool, dict[str, MetricResult]]: # Compute embeddings assert self.embedding_encoder is not None + + # TODO(@robinholzi): reuse the embeddings as long as the reference window is not updated reference_embeddings = get_embeddings(self.embedding_encoder, reference_dataloader) current_embeddings = get_embeddings(self.embedding_encoder, current_dataloader) reference_embeddings_df = convert_tensor_to_df(reference_embeddings, "col_") current_embeddings_df = convert_tensor_to_df(current_embeddings, "col_") drift_results = { - **self.evidently_detector.detect_drift(reference_embeddings_df, current_embeddings_df), - **self.alibi_detector.detect_drift(reference_embeddings, current_embeddings), + **self.evidently_detector.detect_drift(reference_embeddings_df, current_embeddings_df, is_warmup), + **self.alibi_detector.detect_drift(reference_embeddings, current_embeddings, is_warmup), } + + # make the final decisions with the decision policies + for metric_name, metric_result in drift_results.items(): + # overwrite the raw decision from the metric that is not of interest to us. 
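+                # For illustration (hypothetical values): a ThresholdDecisionPolicy with threshold=0.5
+                # maps distance=0.7 -> True and distance=0.3 -> False, while a DynamicDecisionPolicy
+                # compares the new distance against its rolling window of previous distances instead
+                # of a fixed constant.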
+            drift_results[metric_name].is_drift = self.decision_policies[metric_name].evaluate_decision(
+                metric_result.distance
+            )
+
         logger.info(f"[DataDriftDetector][Dataset {self.dataloader_info.dataset_id}]" f"[Result] {drift_results}")
+        if is_warmup:
+            return False, {}
 
         # aggregate the different drift detection results
         drift_detected = self.config.aggregation_strategy.aggregate_decision_func(drift_results)
@@ -249,3 +372,32 @@ def _init_encoder_downloader(self) -> None:
             self.context.base_dir,
             f"{self.context.modyn_config.modyn_model_storage.address}",
         )
+
+    def _any_metric_needs_calibration(self) -> bool:
+        return any(metric.decision_criterion.needs_calibration for metric in self.config.metrics.values())
+
+
+def _setup_detection_windows(
+    windowing_strategy: DriftWindowingStrategy,
+) -> DetectionWindows:
+    if isinstance(windowing_strategy, AmountWindowingStrategy):
+        return AmountDetectionWindows(windowing_strategy)
+    if isinstance(windowing_strategy, TimeWindowingStrategy):
+        return TimeDetectionWindows(windowing_strategy)
+    raise ValueError(f"Unsupported windowing strategy: {windowing_strategy}")
+
+
+def _setup_decision_policies(
+    config: DataDriftTriggerConfig,
+) -> dict[str, DriftDecisionPolicy]:
+    policies: dict[str, DriftDecisionPolicy] = {}
+    for metric_name, metric_config in config.metrics.items():
+        criterion = metric_config.decision_criterion
+        assert (
+            metric_config.num_permutations is None
+        ), "Modyn doesn't allow hypothesis testing; it doesn't work in our context"
+        # dispatch on the criterion config and hand the criterion (not the whole trigger config)
+        # to the matching policy
+        if isinstance(criterion, ThresholdDecisionCriterion):
+            policies[metric_name] = ThresholdDecisionPolicy(criterion)
+        elif isinstance(criterion, DynamicThresholdCriterion):
+            policies[metric_name] = DynamicDecisionPolicy(criterion)
+    return policies
diff --git a/modyn/supervisor/internal/triggers/drift/decision_policy.py b/modyn/supervisor/internal/triggers/drift/decision_policy.py
new file mode 100644
index 000000000..959804068
--- /dev/null
+++ b/modyn/supervisor/internal/triggers/drift/decision_policy.py
@@ -0,0 +1,67 @@
+from abc import ABC, abstractmethod
+from collections import deque
+
+from modyn.config.schema.pipeline.trigger.drift.metric import (
+    DynamicThresholdCriterion,
+    ThresholdDecisionCriterion,
+)
+
+
+class DriftDecisionPolicy(ABC):
+    """Decision policy that makes the binary is_drift decisions based on
+    the similarity/distance metrics.
+
+    Each decision policy wraps one DriftMetric and observes its time
+    series of distance values.
+    """
+
+    @abstractmethod
+    def evaluate_decision(self, distance: float) -> bool:
+        """Evaluate the decision based on the distance value or the raw
+        is_drift decision.
+
+        Args:
+            distance: The distance value of the metric.
+
+        Returns:
+            The final is_drift decision.
+        """
+
+
+class ThresholdDecisionPolicy(DriftDecisionPolicy):
+    """Decision policy that makes the binary is_drift decisions based on a
+    fixed threshold."""
+
+    def __init__(self, config: ThresholdDecisionCriterion):
+        self.config = config
+
+    def evaluate_decision(self, distance: float) -> bool:
+        return distance >= self.config.threshold
+
+
+class DynamicDecisionPolicy(DriftDecisionPolicy):
+    """Decision policy that makes the binary is_drift decisions based on a
+    dynamic threshold.
+
+    We compare a new distance value with the series of previous distance values
+    and decide if it is more extreme than a certain percentile of the series. Therefore we count the
+    `num_more_extreme` values that are greater than the new distance and compare it with the
+    `percentile` threshold.
+
+    TODO: we might want to also support some rolling average policy that triggers if a distance deviates
+    from the average by a certain amount.
+    """
+
+    def __init__(self, config: DynamicThresholdCriterion):
+        self.config = config
+        self.score_observations: deque = deque(maxlen=self.config.window_size)
+
+    def evaluate_decision(self, distance: float) -> bool:
+        num_more_extreme = sum(1 for score in self.score_observations if score >= distance)
+        trigger = True
+        if len(self.score_observations) > 0:
+            perc = num_more_extreme / len(self.score_observations)
+            trigger = perc < self.config.percentile
+
+        self.score_observations.append(distance)
+        return trigger
diff --git a/modyn/supervisor/internal/triggers/drift/detection_window/__init__.py b/modyn/supervisor/internal/triggers/drift/detection_window/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/modyn/supervisor/internal/triggers/drift/detection_window/amount.py b/modyn/supervisor/internal/triggers/drift/detection_window/amount.py
new file mode 100644
index 000000000..d4bf8d8fb
--- /dev/null
+++ b/modyn/supervisor/internal/triggers/drift/detection_window/amount.py
@@ -0,0 +1,68 @@
+from collections import deque
+
+from modyn.config.schema.pipeline.trigger.drift.detection_window import (
+    AmountWindowingStrategy,
+)
+
+from .window import DetectionWindows
+
+
+class AmountDetectionWindows(DetectionWindows):
+    def __init__(self, config: AmountWindowingStrategy):
+        super().__init__()
+        self.config = config
+
+        # with maxlen set, the deque automatically evicts the oldest elements once the buffer is full
+        self.current: deque[tuple[int, int]] = deque(maxlen=config.amount_cur)
+
+        # If the reference window is bigger than the current window, we need a reservoir to store the
+        # elements pushed out of the current window, as we might still need them for the reference window. If a
+        # trigger happens, the whole reservoir and the current window are copied/moved to the reference window.
+        # Therefore the reservoir has the size of the difference between the reference and current window.
+        self.current_reservoir: deque[tuple[int, int]] = deque(maxlen=max(0, config.amount_ref - config.amount_cur))
+
+        self.reference: deque[tuple[int, int]] = deque(maxlen=config.amount_ref)
+
+        # In overlapping mode, we need a dedicated buffer to track new samples that are not yet in the reference buffer.
+        # `current` and `current_reservoir` are insufficient because, after a trigger, they will contain the same
+        # elements as before, which prevents us from copying the current elements to the reference buffer without
+        # creating duplicates. `exclusive_current` contains exactly the new elements that are not yet in the reference buffer.
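+        # Worked example (values mirror the unit tests): with amount_cur=3, amount_ref=5 and
+        # allow_overlap=True, streaming samples 1..6 leaves `current` as [4, 5, 6] while
+        # `exclusive_current` holds [2, 3, 4, 5, 6]; on a trigger, exactly these five samples
+        # are moved to `reference` without duplicating anything already there.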
+        self.exclusive_current: deque[tuple[int, int]] = deque(
+            maxlen=config.amount_ref if self.config.allow_overlap else 0
+        )
+
+    def inform_data(self, data: list[tuple[int, int]]) -> None:
+        assert self.config.amount_cur
+
+        if self.config.allow_overlap:
+            # use the dedicated buffer that tracks the new elements to be copied to reference on trigger
+            self.exclusive_current.extend(data)
+        else:
+            # use the existing buffers
+            remaining_pushes = len(self.current) + len(data) - self.config.amount_cur
+
+            # move data from the current window into the reservoir by first copying the oldest elements
+            # into the reservoir and then extending the current window with the new data, which
+            # automatically evicts the oldest elements
+            for pushed_out in self.current:
+                if remaining_pushes == 0:
+                    break
+                self.current_reservoir.append(pushed_out)
+                remaining_pushes -= 1
+
+        self.current.extend(data)
+
+    def inform_trigger(self) -> None:
+        if self.config.allow_overlap:
+            # move all new elements to the reference buffer
+            self.reference.extend(self.exclusive_current)
+            self.exclusive_current.clear()
+
+        else:
+            # First, move data from the reservoir window to the reference window
+            self.reference.extend(self.current_reservoir)
+            self.current_reservoir.clear()
+
+            # Move data from the current to the reference window
+            self.reference.extend(self.current)
+
+            self.current.clear()
diff --git a/modyn/supervisor/internal/triggers/drift/detection_window/time_.py b/modyn/supervisor/internal/triggers/drift/detection_window/time_.py
new file mode 100644
index 000000000..767d8971c
--- /dev/null
+++ b/modyn/supervisor/internal/triggers/drift/detection_window/time_.py
@@ -0,0 +1,73 @@
+from collections import deque
+
+from modyn.config.schema.pipeline.trigger.drift.detection_window import (
+    TimeWindowingStrategy,
+)
+
+from .window import DetectionWindows
+
+
+class TimeDetectionWindows(DetectionWindows):
+    def __init__(self, config: TimeWindowingStrategy):
+        super().__init__()
+        self.config = config
+
+        # In overlapping mode, we need a dedicated buffer to keep track of the new samples that are
+        # not yet in the reference buffer. `current` and `current_reservoir` are not enough, as after
+        # a trigger they contain the same elements as before, so copying the current elements to the
+        # reference buffer would create duplicates.
+        self.exclusive_current: deque[tuple[int, int]] = deque()
+
+    def inform_data(self, data: list[tuple[int, int]]) -> None:
+        if not data:
+            return
+
+        last_time = data[-1][1]
+
+        # First, add the data to the current window; nothing is evicted automatically since there is no buffer limit
+        self.current.extend(data)
+
+        if self.config.allow_overlap:
+            # now, pop the data that is too old from the current window.
+            # This assumes that the data is sorted by timestamp.
+            while self.current and self.current[0][1] < last_time - self.config.limit_seconds_cur:
+                self.current.popleft()
+
+            self.exclusive_current.extend(data)
+
+            # pop the data that is not in the reference scope anymore
+            while self.exclusive_current and self.exclusive_current[0][1] < last_time - self.config.limit_seconds_ref:
+                self.exclusive_current.popleft()
+
+        else:
+            # now, pop the data that is too old from the current window and move it to the reservoir.
+            # This assumes that the data is sorted by timestamp.
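+            # e.g. with limit_cur="50s" and newest timestamp 120 (illustrative values), samples with
+            # timestamp < 70 leave the current window for the reservoir, and the reservoir in turn
+            # forgets everything with timestamp < 120 - limit_seconds_ref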
+            while self.current and self.current[0][1] < last_time - self.config.limit_seconds_cur:
+                self.current_reservoir.append(self.current.popleft())
+
+            # next, we drop the data from the reservoir that is too old (forget it completely)
+            while self.current_reservoir and self.current_reservoir[0][1] < last_time - self.config.limit_seconds_ref:
+                self.current_reservoir.popleft()
+
+    def inform_trigger(self) -> None:
+        if self.config.allow_overlap:
+            # move all new elements to the reference buffer
+            self.reference.extend(self.exclusive_current)
+            self.exclusive_current.clear()
+
+        else:
+            # First, move data from the reservoir window to the reference window
+            self.reference.extend(self.current_reservoir)
+            self.current_reservoir.clear()
+
+            # Move data from the current to the reference window
+            self.reference.extend(self.current)
+            self.current.clear()
+
+        # now, we drop the data from the reference window that is too old (forget it completely)
+        if not self.reference:
+            return
+
+        last_time = self.reference[-1][1]
+        while self.reference and self.reference[0][1] < last_time - self.config.limit_seconds_ref:
+            self.reference.popleft()
diff --git a/modyn/supervisor/internal/triggers/drift/detection_window/window.py b/modyn/supervisor/internal/triggers/drift/detection_window/window.py
new file mode 100644
index 000000000..4b883eba2
--- /dev/null
+++ b/modyn/supervisor/internal/triggers/drift/detection_window/window.py
@@ -0,0 +1,31 @@
+from abc import ABC, abstractmethod
+from collections import deque
+
+
+class DetectionWindows(ABC):
+    """Wrapper and manager for the drift detection windows, including the reference,
+    current and reservoir windows.
+
+    All windows contain tuples with (sample_id, timestamp).
+
+    This class is responsible for the following tasks:
+    - Keep track of the current and reference window
+    - Update the current window with new data
+    - Move data from the current and reservoir window to the reference window
+
+    If the reference window is bigger than the current window, we still want to fill up the whole reference window
+    after a trigger by taking |reference| elements from the current window.
+    We therefore need to keep track of the elements that exceed the current window but should still be transferred
+    to the reference window. If something is dropped from the reservoir, it won't ever be used again in any window.
+    """
+
+    def __init__(self) -> None:
+        self.current: deque[tuple[int, int]] = deque()
+        self.current_reservoir: deque[tuple[int, int]] = deque()
+        self.reference: deque[tuple[int, int]] = deque()
+
+    @abstractmethod
+    def inform_data(self, data: list[tuple[int, int]]) -> None: ...
+
+    @abstractmethod
+    def inform_trigger(self) -> None: ...
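For orientation between the two new files, here is a minimal usage sketch of the windowing API (not part of the patch; the values mirror the unit tests further below): `inform_data` streams `(sample_id, timestamp)` tuples into the current window, and `inform_trigger` promotes them to the reference window.

```python
from modyn.config.schema.pipeline.trigger.drift.detection_window import AmountWindowingStrategy
from modyn.supervisor.internal.triggers.drift.detection_window.amount import AmountDetectionWindows

# current window of 3 samples, reference window of 5, no overlap
windows = AmountDetectionWindows(AmountWindowingStrategy(amount_cur=3, amount_ref=5, allow_overlap=False))

windows.inform_data([(1, 100), (2, 200)])
windows.inform_data([(3, 300), (4, 400)])
# the current window only keeps the newest 3 samples; the oldest spills into the reservoir
assert list(windows.current) == [(2, 200), (3, 300), (4, 400)]
assert list(windows.current_reservoir) == [(1, 100)]

# a drift trigger moves reservoir + current into the reference window and resets the current window
windows.inform_trigger()
assert list(windows.reference) == [(1, 100), (2, 200), (3, 300), (4, 400)]
assert len(windows.current) == 0
```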
diff --git a/modyn/supervisor/internal/triggers/drift/detector/__init__.py b/modyn/supervisor/internal/triggers/drift/detector/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/modyn/supervisor/internal/triggers/drift/alibi_detector.py b/modyn/supervisor/internal/triggers/drift/detector/alibi.py similarity index 81% rename from modyn/supervisor/internal/triggers/drift/alibi_detector.py rename to modyn/supervisor/internal/triggers/drift/detector/alibi.py index 531afed98..d03880838 100644 --- a/modyn/supervisor/internal/triggers/drift/alibi_detector.py +++ b/modyn/supervisor/internal/triggers/drift/detector/alibi.py @@ -4,12 +4,26 @@ import numpy as np import pandas as pd import torch -from alibi_detect.cd import ChiSquareDrift, CVMDrift, FETDrift, KSDrift, LSDDDrift, MMDDrift - -from modyn.config.schema.pipeline import AlibiDetectDriftMetric, AlibiDetectMmdDriftMetric, MetricResult -from modyn.config.schema.pipeline.trigger.drift.alibi_detect import AlibiDetectCVMDriftMetric, AlibiDetectKSDriftMetric - -from .drift_detector import DriftDetector +from alibi_detect.cd import ( + ChiSquareDrift, + CVMDrift, + FETDrift, + KSDrift, + LSDDDrift, + MMDDrift, +) + +from modyn.config.schema.pipeline import ( + AlibiDetectDriftMetric, + AlibiDetectMmdDriftMetric, + MetricResult, +) +from modyn.config.schema.pipeline.trigger.drift.alibi_detect import ( + AlibiDetectCVMDriftMetric, + AlibiDetectKSDriftMetric, +) + +from .drift import DriftDetector _AlibiMetrics = MMDDrift | ChiSquareDrift | KSDrift | CVMDrift | FETDrift | LSDDDrift @@ -30,6 +44,7 @@ def detect_drift( self, embeddings_ref: pd.DataFrame | np.ndarray | torch.Tensor, embeddings_cur: pd.DataFrame | np.ndarray | torch.Tensor, + is_warmup: bool, ) -> dict[str, MetricResult]: assert isinstance(embeddings_ref, (np.ndarray | torch.Tensor)) assert isinstance(embeddings_cur, (np.ndarray | torch.Tensor)) @@ -43,27 +58,29 @@ def detect_drift( results: dict[str, MetricResult] = {} for metric_ref, config in self.metrics_config.items(): + if is_warmup and not config.decision_criterion.needs_calibration: + continue + metric = _alibi_detect_metric_factory(config, embeddings_ref) result = metric.predict(embeddings_cur, return_p_val=True, return_distance=True) # type: ignore + + # some metrics return a list of distances (for every sample) instead of a single distance + # we take the mean of the distances to get a scalar distance value _dist = ( - list(result["data"]["distance"]) + float(result["data"]["distance"].mean()) if isinstance(result["data"]["distance"], np.ndarray) else result["data"]["distance"] ) _p_val = ( - list(result["data"]["p_val"]) + float(result["data"]["p_val"].mean()) if isinstance(result["data"]["p_val"], np.ndarray) else result["data"]["p_val"] ) - is_drift = result["data"]["is_drift"] - - if isinstance(config, AlibiDetectMmdDriftMetric) and config.threshold is not None: - is_drift = _dist > config.threshold - results[metric_ref] = MetricResult( metric_id=metric_ref, - is_drift=is_drift, + # will be overwritten by DecisionPolicy inside the DataDriftTrigger + is_drift=result["data"]["is_drift"], distance=_dist, p_val=_p_val, threshold=result["data"].get("threshold"), @@ -99,7 +116,6 @@ def _alibi_detect_metric_factory(config: AlibiDetectDriftMetric, embeddings_ref: return KSDrift( x_ref=embeddings_ref, p_val=config.p_val, - alternative=config.alternative_hypothesis, correction=config.correction, x_ref_preprocessed=config.x_ref_preprocessed, ) diff --git 
a/modyn/supervisor/internal/triggers/drift/drift_detector.py b/modyn/supervisor/internal/triggers/drift/detector/drift.py similarity index 74% rename from modyn/supervisor/internal/triggers/drift/drift_detector.py rename to modyn/supervisor/internal/triggers/drift/detector/drift.py index b34f419cc..5a0d2bdd7 100644 --- a/modyn/supervisor/internal/triggers/drift/drift_detector.py +++ b/modyn/supervisor/internal/triggers/drift/detector/drift.py @@ -9,7 +9,12 @@ class DriftDetector(ABC): - # tbd.: multiple strategies to select reference data (windowing) + """Base class establishing an abstraction for multiple third party drift + detection libraries. + + Used to create drift distance measurements for different distance + metrics. + """ def __init__(self, metrics_config: dict[str, DriftMetric]): self.metrics_config = metrics_config @@ -21,5 +26,6 @@ def detect_drift( self, embeddings_ref: pd.DataFrame | np.ndarray | torch.Tensor, embeddings_cur: pd.DataFrame | np.ndarray | torch.Tensor, + is_warmup: bool, ) -> dict[str, MetricResult]: raise NotImplementedError() diff --git a/modyn/supervisor/internal/triggers/drift/evidently_detector.py b/modyn/supervisor/internal/triggers/drift/detector/evidently.py similarity index 78% rename from modyn/supervisor/internal/triggers/drift/evidently_detector.py rename to modyn/supervisor/internal/triggers/drift/detector/evidently.py index d6d82f882..8935628ae 100644 --- a/modyn/supervisor/internal/triggers/drift/evidently_detector.py +++ b/modyn/supervisor/internal/triggers/drift/detector/evidently.py @@ -12,7 +12,7 @@ from modyn.config.schema.pipeline import EvidentlyDriftMetric, MetricResult -from .drift_detector import DriftDetector +from .drift import DriftDetector logger = logging.getLogger(__name__) @@ -34,6 +34,7 @@ def detect_drift( self, embeddings_ref: pd.DataFrame | np.ndarray | torch.Tensor, embeddings_cur: pd.DataFrame | np.ndarray | torch.Tensor, + is_warmup: bool, ) -> dict[str, MetricResult]: assert isinstance(embeddings_ref, pd.DataFrame) assert isinstance(embeddings_cur, pd.DataFrame) @@ -45,14 +46,25 @@ def detect_drift( column_mapping = ColumnMapping(embeddings={EVIDENTLY_COLUMN_MAPPING_NAME: embeddings_ref.columns}) # https://docs.evidentlyai.com/user-guide/customization/embeddings-drift-parameters - report = Report(metrics=[self.evidently_metrics[name] for name in self.evidently_metrics]) - report.run(reference_data=embeddings_ref, current_data=embeddings_cur, column_mapping=column_mapping) + report = Report( + metrics=[ + self.evidently_metrics[name][1] + for name in self.evidently_metrics + if not is_warmup or self.evidently_metrics[name][0].decision_criterion.needs_calibration + ] + ) + report.run( + reference_data=embeddings_ref, + current_data=embeddings_cur, + column_mapping=column_mapping, + ) results_raw = report.as_dict() metric_names = list(self.metrics_config) results = { metric_names[metric_idx]: MetricResult( metric_id=metric_result["metric"], + # will be overwritten by DecisionPolicy inside the DataDriftTrigger is_drift=metric_result["result"]["drift_detected"], distance=metric_result["result"]["drift_score"], ) @@ -66,7 +78,9 @@ def detect_drift( # -------------------------------------------------------------------------------------------------------------------- # -def _get_evidently_metrics(metrics_config: dict[str, EvidentlyDriftMetric]) -> dict[str, EmbeddingsDriftMetric]: +def _get_evidently_metrics( + metrics_config: dict[str, EvidentlyDriftMetric], +) -> dict[str, tuple[EvidentlyDriftMetric, EmbeddingsDriftMetric]]: 
"""This function instantiates an Evidently metric given metric configuration. If we want to support multiple metrics in the future, we can loop through the configurations. @@ -77,7 +91,10 @@ def _get_evidently_metrics(metrics_config: dict[str, EvidentlyDriftMetric]) -> d Otherwise, we use the metric given by metric_name, with optional metric configuration specific to the metric. """ metrics = { - metric_ref: EmbeddingsDriftMetric(EVIDENTLY_COLUMN_MAPPING_NAME, _evidently_metric_factory(config)) + metric_ref: ( + config, + EmbeddingsDriftMetric(EVIDENTLY_COLUMN_MAPPING_NAME, _evidently_metric_factory(config)), + ) for metric_ref, config in metrics_config.items() } return metrics @@ -85,9 +102,10 @@ def _get_evidently_metrics(metrics_config: dict[str, EvidentlyDriftMetric]) -> d def _evidently_metric_factory(config: EvidentlyDriftMetric) -> EmbeddingsDriftMetric: if config.id == "EvidentlyModelDriftMetric": + assert config.bootstrap is False, "Bootstrap is not supported in EvidentlyModelDriftMetric." return embedding_drift_methods.model( threshold=config.threshold, - bootstrap=config.bootstrap, + bootstrap=False, quantile_probability=config.quantile_probability, pca_components=config.num_pca_component, ) @@ -99,10 +117,11 @@ def _evidently_metric_factory(config: EvidentlyDriftMetric) -> EmbeddingsDriftMe pca_components=config.num_pca_component, ) if config.id == "EvidentlySimpleDistanceDriftMetric": + assert config.bootstrap is False, "Bootstrap is not supported in EvidentlySimpleDistanceDriftMetric." return embedding_drift_methods.distance( dist=config.distance_metric, threshold=config.threshold, - bootstrap=config.bootstrap, + bootstrap=False, pca_components=config.num_pca_component, quantile_probability=config.quantile_probability, ) diff --git a/modyn/supervisor/internal/triggers/models.py b/modyn/supervisor/internal/triggers/models.py index 5c7a813b2..9f21b486d 100644 --- a/modyn/supervisor/internal/triggers/models.py +++ b/modyn/supervisor/internal/triggers/models.py @@ -8,12 +8,13 @@ class DriftTriggerEvalLog(ModynBaseModel): type: Literal["drift"] = Field("drift") - detection_idx_start: int - detection_idx_end: int + detection_interval: tuple[int, int] # timestamps of the current detection interval + reference_interval: tuple[int | None, int | None] = Field((None, None)) # timestamps of the reference interval triggered: bool trigger_index: int | None = Field(None) data_points: int = Field(0) drift_results: dict[str, MetricResult] = Field(default_factory=dict) + is_warmup: bool = Field(False) TriggerEvalLog = Annotated[DriftTriggerEvalLog, Field(discriminator="type")] diff --git a/modyn/supervisor/internal/triggers/utils/__init__.py b/modyn/supervisor/internal/triggers/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/modyn/supervisor/internal/triggers/utils/factory.py b/modyn/supervisor/internal/triggers/utils/factory.py new file mode 100644 index 000000000..1cb14c8ab --- /dev/null +++ b/modyn/supervisor/internal/triggers/utils/factory.py @@ -0,0 +1,10 @@ +from modyn.config.schema.base_model import ModynBaseModel +from modyn.supervisor.internal.triggers.trigger import Trigger +from modyn.utils import dynamic_module_import + + +def instantiate_trigger(trigger_id: str, trigger_config: ModynBaseModel) -> Trigger: + trigger_module = dynamic_module_import("modyn.supervisor.internal.triggers") + trigger: Trigger = getattr(trigger_module, trigger_id)(trigger_config) + assert trigger is not None, "Error during trigger initialization" + return trigger diff --git 
a/modyn/supervisor/internal/triggers/utils.py b/modyn/supervisor/internal/triggers/utils/utils.py similarity index 82% rename from modyn/supervisor/internal/triggers/utils.py rename to modyn/supervisor/internal/triggers/utils/utils.py index e2805b3ca..bf701fd9d 100644 --- a/modyn/supervisor/internal/triggers/utils.py +++ b/modyn/supervisor/internal/triggers/utils/utils.py @@ -5,8 +5,18 @@ import torch from torch.utils.data import DataLoader -from modyn.supervisor.internal.triggers.embedding_encoder_utils import EmbeddingEncoder -from modyn.supervisor.internal.triggers.trigger_datasets import DataLoaderInfo, FixedKeysDataset, OnlineTriggerDataset +from modyn.supervisor.internal.triggers.embedding_encoder_utils.embedding_encoder import ( + EmbeddingEncoder, +) +from modyn.supervisor.internal.triggers.trigger_datasets.dataloader_info import ( + DataLoaderInfo, +) +from modyn.supervisor.internal.triggers.trigger_datasets.fixed_keys_dataset import ( + FixedKeysDataset, +) +from modyn.supervisor.internal.triggers.trigger_datasets.online_trigger_dataset import ( + OnlineTriggerDataset, +) logger = logging.getLogger(__name__) @@ -38,7 +48,11 @@ def prepare_trigger_dataloader_by_trigger( ) logger.debug("Creating online trigger DataLoader.") - return DataLoader(train_set, batch_size=dataloader_info.batch_size, num_workers=dataloader_info.num_dataloaders) + return DataLoader( + train_set, + batch_size=dataloader_info.batch_size, + num_workers=dataloader_info.num_dataloaders, + ) def prepare_trigger_dataloader_fixed_keys( @@ -59,7 +73,11 @@ def prepare_trigger_dataloader_fixed_keys( ) logger.debug("Creating fixed keys DataLoader.") - return DataLoader(train_set, batch_size=dataloader_info.batch_size, num_workers=dataloader_info.num_dataloaders) + return DataLoader( + train_set, + batch_size=dataloader_info.batch_size, + num_workers=dataloader_info.num_dataloaders, + ) def get_embeddings(embedding_encoder: EmbeddingEncoder, dataloader: DataLoader) -> torch.Tensor: diff --git a/modyn/tests/supervisor/internal/triggers/drift/detection_window/amount.py b/modyn/tests/supervisor/internal/triggers/drift/detection_window/amount.py new file mode 100644 index 000000000..11ba6274f --- /dev/null +++ b/modyn/tests/supervisor/internal/triggers/drift/detection_window/amount.py @@ -0,0 +1,294 @@ +from modyn.config.schema.pipeline.trigger.drift.detection_window import AmountWindowingStrategy +from modyn.supervisor.internal.triggers.drift.detection_window.amount import AmountDetectionWindows + + +def test_amount_detection_window_manager_no_overlap() -> None: + config = AmountWindowingStrategy(amount_cur=3, amount_ref=5, allow_overlap=False) + assert config.amount_cur == 3 + assert config.amount_ref == 5 + assert not config.allow_overlap + + windows = AmountDetectionWindows(config) + assert len(windows.current) == 0 + assert len(windows.current_reservoir) == 0 + assert len(windows.reference) == 0 + assert len(windows.exclusive_current) == 0 + + # partial fill current_ + windows.inform_data([(1, 100), (2, 200)]) + assert len(windows.current) == 2 + assert len(windows.current_reservoir) == 0 + assert len(windows.reference) == 0 + assert len(windows.exclusive_current) == 0 + assert list(windows.current) == [(1, 100), (2, 200)] + + # current_ overflow -> fill current_reservoir_ + windows.inform_data([(3, 300), (4, 400)]) + assert len(windows.current) == 3 + assert len(windows.current_reservoir) == 1 + assert len(windows.reference) == 0 + assert len(windows.exclusive_current) == 0 + assert list(windows.current) == [(2, 200), 
(3, 300), (4, 400)] + assert list(windows.current_reservoir) == [(1, 100)] + + # overflow current_ and current_reservoir_ + windows.inform_data([(5, 500), (6, 600)]) + assert len(windows.current) == 3 + assert len(windows.current_reservoir) == 2 + assert len(windows.reference) == 0 + assert len(windows.exclusive_current) == 0 + assert list(windows.current) == [(4, 400), (5, 500), (6, 600)] + assert list(windows.current_reservoir) == [(2, 200), (3, 300)] + + # trigger: reset current_ and move data to reference_ + windows.inform_trigger() + assert len(windows.current) == 0 + assert len(windows.current_reservoir) == 0 + assert len(windows.reference) == 5 + assert len(windows.exclusive_current) == 0 + assert list(windows.reference) == [ + (2, 200), + (3, 300), + (4, 400), + (5, 500), + (6, 600), + ] + + windows.inform_data([(7, 700), (8, 800)]) + assert len(windows.current) == 2 + assert len(windows.current_reservoir) == 0 + assert len(windows.reference) == 5 + assert len(windows.exclusive_current) == 0 + assert list(windows.current) == [(7, 700), (8, 800)] + assert list(windows.reference) == [ + (2, 200), + (3, 300), + (4, 400), + (5, 500), + (6, 600), + ] + + # test ref overflow + windows.inform_trigger() + assert len(windows.current) == 0 + assert len(windows.current_reservoir) == 0 + assert len(windows.reference) == 5 + assert len(windows.exclusive_current) == 0 + assert list(windows.reference) == [ + (4, 400), + (5, 500), + (6, 600), + (7, 700), + (8, 800), + ] + + +def test_amount_detection_window_manager_no_overlap_ref_smaller_cur() -> None: + config = AmountWindowingStrategy(amount_cur=5, amount_ref=3, allow_overlap=False) + assert config.amount_cur == 5 + assert config.amount_ref == 3 + assert not config.allow_overlap + + windows = AmountDetectionWindows(config) + assert len(windows.current) == 0 + assert len(windows.current_reservoir) == 0 + assert len(windows.reference) == 0 + assert len(windows.exclusive_current) == 0 + + # partial fill current_ + windows.inform_data([(1, 100), (2, 200)]) + assert len(windows.current) == 2 + assert len(windows.current_reservoir) == 0 + assert len(windows.reference) == 0 + assert len(windows.exclusive_current) == 0 + assert list(windows.current) == [(1, 100), (2, 200)] + + # current_ overflow + windows.inform_data([(3, 300), (4, 400), (5, 500), (6, 600)]) + assert len(windows.current) == 5 + assert len(windows.current_reservoir) == 0 + assert len(windows.reference) == 0 + assert len(windows.exclusive_current) == 0 + assert list(windows.current) == [(2, 200), (3, 300), (4, 400), (5, 500), (6, 600)] + + # trigger: reset current_ and move data to reference_ + windows.inform_trigger() + assert len(windows.current) == 0 + assert len(windows.current_reservoir) == 0 + assert len(windows.reference) == 3 + assert len(windows.exclusive_current) == 0 + assert list(windows.reference) == [(4, 400), (5, 500), (6, 600)] + + windows.inform_data([(7, 700), (8, 800)]) + assert len(windows.current) == 2 + assert len(windows.current_reservoir) == 0 + assert len(windows.reference) == 3 + assert len(windows.exclusive_current) == 0 + assert list(windows.current) == [(7, 700), (8, 800)] + assert list(windows.reference) == [(4, 400), (5, 500), (6, 600)] + + # test ref overflow + windows.inform_trigger() + assert len(windows.current) == 0 + assert len(windows.current_reservoir) == 0 + assert len(windows.reference) == 3 + assert len(windows.exclusive_current) == 0 + assert list(windows.reference) == [(6, 600), (7, 700), (8, 800)] + + +def 
test_amount_detection_window_manager_with_overlap() -> None: + config = AmountWindowingStrategy(amount_cur=3, amount_ref=5, allow_overlap=True) + assert config.amount_cur == 3 + assert config.amount_ref == 5 + assert config.allow_overlap + + windows = AmountDetectionWindows(config) + assert len(windows.current) == 0 + assert len(windows.current_reservoir) == 0 + assert len(windows.reference) == 0 + assert len(windows.exclusive_current) == 0 + + # partial fill current_ + windows.inform_data([(1, 100), (2, 200)]) + assert len(windows.current) == 2 + assert len(windows.current_reservoir) == 0 + assert len(windows.reference) == 0 + assert len(windows.exclusive_current) == 2 + assert list(windows.current) == [(1, 100), (2, 200)] + assert list(windows.exclusive_current) == [(1, 100), (2, 200)] + + # current_ overflow + windows.inform_data([(3, 300), (4, 400)]) + assert len(windows.current) == 3 + assert len(windows.current_reservoir) == 0 + assert len(windows.reference) == 0 + assert len(windows.exclusive_current) == 4 + assert list(windows.current) == [(2, 200), (3, 300), (4, 400)] + assert list(windows.exclusive_current) == [(1, 100), (2, 200), (3, 300), (4, 400)] + + # overflow current_ and exclusive_current + windows.inform_data([(5, 500), (6, 600)]) + assert len(windows.current) == 3 + assert len(windows.current_reservoir) == 0 + assert len(windows.reference) == 0 + assert len(windows.exclusive_current) == 5 + assert list(windows.current) == [(4, 400), (5, 500), (6, 600)] + assert list(windows.exclusive_current) == [ + (2, 200), + (3, 300), + (4, 400), + (5, 500), + (6, 600), + ] + + # trigger: DONT reset current_ but copy data to reference_ + windows.inform_trigger() + assert len(windows.current) == 3 + assert len(windows.current_reservoir) == 0 + assert len(windows.reference) == 5 + assert len(windows.exclusive_current) == 0 + assert list(windows.current) == [(4, 400), (5, 500), (6, 600)] + assert list(windows.reference) == [ + (2, 200), + (3, 300), + (4, 400), + (5, 500), + (6, 600), + ] + + windows.inform_data([(7, 700), (8, 800)]) + assert len(windows.current) == 3 + assert len(windows.current_reservoir) == 0 + assert len(windows.reference) == 5 + assert len(windows.exclusive_current) == 2 + assert list(windows.current) == [(6, 600), (7, 700), (8, 800)] + assert list(windows.reference) == [ + (2, 200), + (3, 300), + (4, 400), + (5, 500), + (6, 600), + ] + assert list(windows.exclusive_current) == [(7, 700), (8, 800)] + + # test ref overflow + windows.inform_trigger() + assert len(windows.current) == 3 + assert len(windows.current_reservoir) == 0 + assert len(windows.reference) == 5 + assert len(windows.exclusive_current) == 0 + assert list(windows.current) == [(6, 600), (7, 700), (8, 800)] + assert list(windows.reference) == [ + (4, 400), + (5, 500), + (6, 600), + (7, 700), + (8, 800), + ] + + +def test_amount_detection_window_manager_with_overlap_ref_smaller_cur() -> None: + config = AmountWindowingStrategy(amount_cur=5, amount_ref=3, allow_overlap=True) + assert config.amount_cur == 5 + assert config.amount_ref == 3 + assert config.allow_overlap + + windows = AmountDetectionWindows(config) + assert len(windows.current) == 0 + assert len(windows.current_reservoir) == 0 + assert len(windows.reference) == 0 + assert len(windows.exclusive_current) == 0 + + # partial fill current_ + windows.inform_data([(1, 100), (2, 200)]) + assert len(windows.current) == 2 + assert len(windows.current_reservoir) == 0 + assert len(windows.reference) == 0 + assert len(windows.exclusive_current) == 2 + 
assert list(windows.current) == [(1, 100), (2, 200)] + assert list(windows.exclusive_current) == [(1, 100), (2, 200)] + + # current_ overflow + windows.inform_data([(3, 300), (4, 400)]) + assert len(windows.current) == 4 + assert len(windows.current_reservoir) == 0 + assert len(windows.reference) == 0 + assert len(windows.exclusive_current) == 3 + assert list(windows.current) == [(1, 100), (2, 200), (3, 300), (4, 400)] + assert list(windows.exclusive_current) == [(2, 200), (3, 300), (4, 400)] + + # overflow current_ and exclusive_current + windows.inform_data([(5, 500), (6, 600)]) + assert len(windows.current) == 5 + assert len(windows.current_reservoir) == 0 + assert len(windows.reference) == 0 + assert len(windows.exclusive_current) == 3 + assert list(windows.current) == [(2, 200), (3, 300), (4, 400), (5, 500), (6, 600)] + assert list(windows.exclusive_current) == [(4, 400), (5, 500), (6, 600)] + + # trigger: DONT reset current_ but copy data to reference_ + windows.inform_trigger() + assert len(windows.current) == 5 + assert len(windows.current_reservoir) == 0 + assert len(windows.reference) == 3 + assert len(windows.exclusive_current) == 0 + assert list(windows.current) == [(2, 200), (3, 300), (4, 400), (5, 500), (6, 600)] + assert list(windows.reference) == [(4, 400), (5, 500), (6, 600)] + + windows.inform_data([(7, 700), (8, 800)]) + assert len(windows.current) == 5 + assert len(windows.current_reservoir) == 0 + assert len(windows.reference) == 3 + assert len(windows.exclusive_current) == 2 + assert list(windows.current) == [(4, 400), (5, 500), (6, 600), (7, 700), (8, 800)] + assert list(windows.reference) == [(4, 400), (5, 500), (6, 600)] + assert list(windows.exclusive_current) == [(7, 700), (8, 800)] + + # test ref overflow + windows.inform_trigger() + assert len(windows.current) == 5 + assert len(windows.current_reservoir) == 0 + assert len(windows.reference) == 3 + assert len(windows.exclusive_current) == 0 + assert list(windows.current) == [(4, 400), (5, 500), (6, 600), (7, 700), (8, 800)] + assert list(windows.reference) == [(6, 600), (7, 700), (8, 800)] diff --git a/modyn/tests/supervisor/internal/triggers/drift/detection_window/time_.py b/modyn/tests/supervisor/internal/triggers/drift/detection_window/time_.py new file mode 100644 index 000000000..64c1d5df9 --- /dev/null +++ b/modyn/tests/supervisor/internal/triggers/drift/detection_window/time_.py @@ -0,0 +1,291 @@ +from modyn.config.schema.pipeline.trigger.drift.detection_window import TimeWindowingStrategy +from modyn.supervisor.internal.triggers.drift.detection_window.time_ import TimeDetectionWindows + + +def test_time_detection_window_manager_no_overlap() -> None: + config = TimeWindowingStrategy(limit_cur="50s", limit_ref="100s", allow_overlap=False) + assert config.limit_cur == "50s" + assert config.limit_ref == "100s" + assert not config.allow_overlap + + windows = TimeDetectionWindows(config) + assert len(windows.current) == 0 + assert len(windows.current_reservoir) == 0 + assert len(windows.reference) == 0 + assert len(windows.exclusive_current) == 0 + + # partial fill current_ + windows.inform_data([(1, 15), (2, 30)]) + assert len(windows.current) == 2 + assert len(windows.current_reservoir) == 0 + assert len(windows.reference) == 0 + assert len(windows.exclusive_current) == 0 + assert list(windows.current) == [(1, 15), (2, 30)] + + # current_ overflow -> fill current_reservoir_ + windows.inform_data([(3, 45), (4, 60)]) + assert len(windows.current) == 4 + assert len(windows.current_reservoir) == 0 + assert 
len(windows.reference) == 0 + assert len(windows.exclusive_current) == 0 + assert list(windows.current) == [(1, 15), (2, 30), (3, 45), (4, 60)] + + # overflow current_ and current_reservoir_ + windows.inform_data([(5, 75), (6, 90), (7, 120)]) + assert len(windows.current) == 3 + assert len(windows.current_reservoir) == 3 + assert len(windows.reference) == 0 + assert len(windows.exclusive_current) == 0 + assert list(windows.current) == [(5, 75), (6, 90), (7, 120)] + assert list(windows.current_reservoir) == [(2, 30), (3, 45), (4, 60)] + + # trigger: reset current_ and move data to reference_ + windows.inform_trigger() + assert len(windows.current) == 0 + assert len(windows.current_reservoir) == 0 + assert len(windows.reference) == 6 + assert len(windows.exclusive_current) == 0 + assert list(windows.reference) == [ + (2, 30), + (3, 45), + (4, 60), + (5, 75), + (6, 90), + (7, 120), + ] + + windows.inform_data([(8, 135)]) + assert len(windows.current) == 1 + assert len(windows.current_reservoir) == 0 + assert len(windows.reference) == 6 + assert len(windows.exclusive_current) == 0 + assert list(windows.current) == [(8, 135)] + assert list(windows.reference) == [ + (2, 30), + (3, 45), + (4, 60), + (5, 75), + (6, 90), + (7, 120), + ] + + # test ref overflow + windows.inform_trigger() + assert len(windows.current) == 0 + assert len(windows.current_reservoir) == 0 + assert len(windows.reference) == 6 + assert len(windows.exclusive_current) == 0 + assert list(windows.reference) == [ + (3, 45), + (4, 60), + (5, 75), + (6, 90), + (7, 120), + (8, 135), + ] + + +def test_time_detection_window_manager_no_overlap_ref_smaller_cur() -> None: + config = TimeWindowingStrategy(limit_cur="100s", limit_ref="50s", allow_overlap=False) + assert config.limit_cur == "100s" + assert config.limit_ref == "50s" + assert not config.allow_overlap + + windows = TimeDetectionWindows(config) + assert len(windows.current) == 0 + assert len(windows.current_reservoir) == 0 + assert len(windows.reference) == 0 + assert len(windows.exclusive_current) == 0 + + # fill current_ + windows.inform_data([(1, 15), (2, 30)]) + assert len(windows.current) == 2 + assert len(windows.current_reservoir) == 0 + assert len(windows.reference) == 0 + assert len(windows.exclusive_current) == 0 + assert list(windows.current) == [(1, 15), (2, 30)] + + # current_ overflow + windows.inform_data([(3, 45), (4, 60), (5, 75), (6, 90), (7, 150)]) + assert len(windows.current) == 4 + assert len(windows.current_reservoir) == 0 + assert len(windows.reference) == 0 + assert len(windows.exclusive_current) == 0 + assert list(windows.current) == [(4, 60), (5, 75), (6, 90), (7, 150)] + + # trigger: reset current_ and move data to reference_ + windows.inform_trigger() + assert len(windows.current) == 0 + assert len(windows.current_reservoir) == 0 + assert len(windows.reference) == 1 + assert len(windows.exclusive_current) == 0 + assert list(windows.reference) == [(7, 150)] + + windows.inform_data([(8, 190), (9, 210)]) + assert len(windows.current) == 2 + assert len(windows.current_reservoir) == 0 + assert len(windows.reference) == 1 + assert len(windows.exclusive_current) == 0 + assert list(windows.current) == [(8, 190), (9, 210)] + assert list(windows.reference) == [(7, 150)] + + # test ref overflow + windows.inform_trigger() + assert len(windows.current) == 0 + assert len(windows.current_reservoir) == 0 + assert len(windows.reference) == 2 + assert len(windows.exclusive_current) == 0 + assert list(windows.reference) == [(8, 190), (9, 210)] + + +def 
test_time_detection_window_manager_with_overlap() -> None: + config = TimeWindowingStrategy(limit_cur="50s", limit_ref="100s", allow_overlap=True) + assert config.limit_cur == "50s" + assert config.limit_ref == "100s" + assert config.allow_overlap + + windows = TimeDetectionWindows(config) + assert len(windows.current) == 0 + assert len(windows.current_reservoir) == 0 + assert len(windows.reference) == 0 + assert len(windows.exclusive_current) == 0 + + # partial fill current_ + windows.inform_data([(1, 15), (2, 30)]) + assert len(windows.current) == 2 + assert len(windows.current_reservoir) == 0 + assert len(windows.reference) == 0 + assert len(windows.exclusive_current) == 2 + assert list(windows.current) == [(1, 15), (2, 30)] + assert list(windows.exclusive_current) == [(1, 15), (2, 30)] + + # current_ overflow -> fill current_reservoir_ + windows.inform_data([(3, 45), (4, 60)]) + assert len(windows.current) == 4 + assert len(windows.current_reservoir) == 0 + assert len(windows.reference) == 0 + assert len(windows.exclusive_current) == 4 + assert list(windows.current) == [(1, 15), (2, 30), (3, 45), (4, 60)] + assert list(windows.exclusive_current) == [(1, 15), (2, 30), (3, 45), (4, 60)] + + # overflow current_ and current_reservoir_ + windows.inform_data([(5, 75), (6, 90), (7, 120)]) + assert len(windows.current) == 3 + assert len(windows.current_reservoir) == 0 + assert len(windows.reference) == 0 + assert len(windows.exclusive_current) == 6 + assert list(windows.current) == [(5, 75), (6, 90), (7, 120)] + assert list(windows.exclusive_current) == [ + (2, 30), + (3, 45), + (4, 60), + (5, 75), + (6, 90), + (7, 120), + ] + + # trigger: reset current_ and move data to reference_ + windows.inform_trigger() + assert len(windows.current) == 3 + assert len(windows.current_reservoir) == 0 + assert len(windows.reference) == 6 + assert len(windows.exclusive_current) == 0 + assert list(windows.current) == [(5, 75), (6, 90), (7, 120)] + assert list(windows.reference) == [ + (2, 30), + (3, 45), + (4, 60), + (5, 75), + (6, 90), + (7, 120), + ] + + windows.inform_data([(8, 135)]) + assert len(windows.current) == 3 + assert len(windows.current_reservoir) == 0 + assert len(windows.reference) == 6 + assert len(windows.exclusive_current) == 1 + assert list(windows.current) == [(6, 90), (7, 120), (8, 135)] + assert list(windows.exclusive_current) == [(8, 135)] + assert list(windows.reference) == [ + (2, 30), + (3, 45), + (4, 60), + (5, 75), + (6, 90), + (7, 120), + ] + + # test ref overflow + windows.inform_trigger() + assert len(windows.current) == 3 + assert len(windows.current_reservoir) == 0 + assert len(windows.reference) == 6 + assert len(windows.exclusive_current) == 0 + assert list(windows.current) == [(6, 90), (7, 120), (8, 135)] + assert list(windows.reference) == [ + (3, 45), + (4, 60), + (5, 75), + (6, 90), + (7, 120), + (8, 135), + ] + + +def test_time_detection_window_manager_with_overlap_ref_smaller_cur() -> None: + config = TimeWindowingStrategy(limit_cur="100s", limit_ref="50s", allow_overlap=True) + assert config.limit_cur == "100s" + assert config.limit_ref == "50s" + assert config.allow_overlap + + windows = TimeDetectionWindows(config) + assert len(windows.current) == 0 + assert len(windows.current_reservoir) == 0 + assert len(windows.reference) == 0 + assert len(windows.exclusive_current) == 0 + + # fill current_ + windows.inform_data([(1, 15), (2, 30)]) + assert len(windows.current) == 2 + assert len(windows.current_reservoir) == 0 + assert len(windows.reference) == 0 + assert 
len(windows.exclusive_current) == 2 + assert list(windows.current) == [(1, 15), (2, 30)] + assert list(windows.exclusive_current) == [(1, 15), (2, 30)] + + # current_ overflow + windows.inform_data([(3, 45), (4, 60), (5, 75), (6, 90), (7, 150)]) + assert len(windows.current) == 4 + assert len(windows.current_reservoir) == 0 + assert len(windows.reference) == 0 + assert len(windows.exclusive_current) == 1 + assert list(windows.current) == [(4, 60), (5, 75), (6, 90), (7, 150)] + assert list(windows.exclusive_current) == [(7, 150)] + + # trigger: reset current_ and move data to reference_ + windows.inform_trigger() + assert len(windows.current) == 4 + assert len(windows.current_reservoir) == 0 + assert len(windows.reference) == 1 + assert len(windows.exclusive_current) == 0 + assert list(windows.current) == [(4, 60), (5, 75), (6, 90), (7, 150)] + assert list(windows.reference) == [(7, 150)] + + windows.inform_data([(8, 190), (9, 210)]) + assert len(windows.current) == 3 + assert len(windows.current_reservoir) == 0 + assert len(windows.reference) == 1 + assert len(windows.exclusive_current) == 2 + assert list(windows.current) == [(7, 150), (8, 190), (9, 210)] + assert list(windows.reference) == [(7, 150)] + assert list(windows.exclusive_current) == [(8, 190), (9, 210)] + + # test ref overflow + windows.inform_trigger() + assert len(windows.current) == 3 + assert len(windows.current_reservoir) == 0 + assert len(windows.reference) == 2 + assert len(windows.exclusive_current) == 0 + assert list(windows.current) == [(7, 150), (8, 190), (9, 210)] + assert list(windows.reference) == [(8, 190), (9, 210)] diff --git a/modyn/tests/supervisor/internal/triggers/drift/test_alibi_detector.py b/modyn/tests/supervisor/internal/triggers/drift/test_alibi_detector.py index 44afd6954..8a0a24e82 100644 --- a/modyn/tests/supervisor/internal/triggers/drift/test_alibi_detector.py +++ b/modyn/tests/supervisor/internal/triggers/drift/test_alibi_detector.py @@ -4,7 +4,8 @@ import pytest from modyn.config.schema.pipeline import AlibiDetectCVMDriftMetric, AlibiDetectKSDriftMetric, AlibiDetectMmdDriftMetric -from modyn.supervisor.internal.triggers.drift.alibi_detector import AlibiDriftDetector +from modyn.config.schema.pipeline.trigger.drift.metric import ThresholdDecisionCriterion +from modyn.supervisor.internal.triggers.drift.detector.alibi import AlibiDriftDetector @pytest.fixture @@ -14,6 +15,7 @@ def mmd_drift_metric() -> AlibiDetectMmdDriftMetric: device=None, num_permutations=100, kernel="GaussianRBF", + decision_criterion=ThresholdDecisionCriterion(threshold=0.2), ) @@ -22,7 +24,7 @@ def ks_drift_metric() -> AlibiDetectKSDriftMetric: return AlibiDetectKSDriftMetric( p_val=0.05, correction="bonferroni", - alternative_hypothesis="two-sided", + decision_criterion=ThresholdDecisionCriterion(threshold=0.2), ) @@ -31,6 +33,7 @@ def cvm_drift_metric() -> AlibiDetectCVMDriftMetric: return AlibiDetectCVMDriftMetric( p_val=0.05, correction="bonferroni", + decision_criterion=ThresholdDecisionCriterion(threshold=0.2), ) @@ -51,11 +54,11 @@ def test_alibi_detect_drift_metric( assert isinstance(ad, AlibiDriftDetector) # on h0 - results = ad.detect_drift(data_ref, data_h0) + results = ad.detect_drift(data_ref, data_h0, False) assert not results[name].is_drift # on current data - results = ad.detect_drift(data_ref, data_cur) + results = ad.detect_drift(data_ref, data_cur, False) assert results[name].is_drift assert ( 0 @@ -81,5 +84,5 @@ def test_alibi_detect_drift_metric( "cvm": cvm_drift_metric, } ) - results = 
ad.detect_drift(data_ref, data_cur) + results = ad.detect_drift(data_ref, data_cur, False) assert len(results) == 3 diff --git a/modyn/tests/supervisor/internal/triggers/drift/test_decision_policy.py b/modyn/tests/supervisor/internal/triggers/drift/test_decision_policy.py new file mode 100644 index 000000000..bed6b5130 --- /dev/null +++ b/modyn/tests/supervisor/internal/triggers/drift/test_decision_policy.py @@ -0,0 +1,63 @@ +import pytest + +from modyn.config.schema.pipeline.trigger.drift.metric import DynamicThresholdCriterion, ThresholdDecisionCriterion +from modyn.supervisor.internal.triggers.drift.decision_policy import DynamicDecisionPolicy, ThresholdDecisionPolicy + + +def test_threshold_decision_policy() -> None: + config = ThresholdDecisionCriterion(threshold=0.5) + policy = ThresholdDecisionPolicy(config) + + assert policy.evaluate_decision(0.6) + assert not policy.evaluate_decision(0.4) + + +@pytest.mark.parametrize("percentile", [0.1, 0.5, 0.9]) +def test_dynamic_decision_policy_initial(percentile: float) -> None: + config = DynamicThresholdCriterion(window_size=3, percentile=percentile) + policy = DynamicDecisionPolicy(config) + + # Initially, the deque is empty, so any value should trigger a drift + assert policy.evaluate_decision(0.5) + + +def test_dynamic_decision_policy_with_observations() -> None: + config = DynamicThresholdCriterion(window_size=3, percentile=0.5) + policy = DynamicDecisionPolicy(config) + + # Add initial observations + policy.score_observations.extend([0.4, 0.6, 0.7]) + + # Testing with various distances + assert not policy.evaluate_decision(0.3) # Less than all observations + assert policy.evaluate_decision(0.8) # Greater than all observations + assert not policy.evaluate_decision(0.5) # 0.5 is at the 50th percentile + + +def test_dynamic_decision_policy_window_size() -> None: + config = DynamicThresholdCriterion(window_size=3, percentile=0.5) + policy = DynamicDecisionPolicy(config) + + # Add observations to fill the window + policy.evaluate_decision(0.4) + policy.evaluate_decision(0.6) + policy.evaluate_decision(0.7) + + # Adding another observation should remove the oldest one (0.4) + assert policy.evaluate_decision(0.8) # Greater than all observations + assert len(policy.score_observations) == 3 # Ensure the deque is still at max length + + +def test_dynamic_decision_policy_percentile() -> None: + config = DynamicThresholdCriterion(window_size=4, percentile=0.75) + policy = DynamicDecisionPolicy(config) + + # Add observations + policy.evaluate_decision(0.4) + policy.evaluate_decision(0.6) + policy.evaluate_decision(0.7) + policy.evaluate_decision(0.9) + + assert not policy.evaluate_decision(0.5) + assert policy.evaluate_decision(0.8) + assert not policy.evaluate_decision(0.7) diff --git a/modyn/tests/supervisor/internal/triggers/drift/test_evidently_detector.py b/modyn/tests/supervisor/internal/triggers/drift/test_evidently_detector.py index 810a3c815..36d8f5302 100644 --- a/modyn/tests/supervisor/internal/triggers/drift/test_evidently_detector.py +++ b/modyn/tests/supervisor/internal/triggers/drift/test_evidently_detector.py @@ -7,7 +7,8 @@ EvidentlyRatioDriftMetric, EvidentlySimpleDistanceDriftMetric, ) -from modyn.supervisor.internal.triggers.drift.evidently_detector import EvidentlyDriftDetector +from modyn.config.schema.pipeline.trigger.drift.metric import DynamicThresholdCriterion +from modyn.supervisor.internal.triggers.drift.detector.evidently import EvidentlyDriftDetector def _add_col_prefixes(df: pd.DataFrame, prefix: str) -> pd.DataFrame: 
@@ -32,19 +33,20 @@ def df_data_cur(data_cur: np.ndarray) -> pd.DataFrame: @pytest.fixture def model_drift_metric() -> EvidentlyModelDriftMetric: - return EvidentlyModelDriftMetric(bootstrap=True) + return EvidentlyModelDriftMetric(bootstrap=False, decision_criterion=DynamicThresholdCriterion()) @pytest.fixture def ratio_drift_metric() -> EvidentlyRatioDriftMetric: - return EvidentlyRatioDriftMetric() + return EvidentlyRatioDriftMetric(decision_criterion=DynamicThresholdCriterion()) @pytest.fixture def simple_distance_drift_metric() -> EvidentlySimpleDistanceDriftMetric: return EvidentlySimpleDistanceDriftMetric( - bootstrap=True, + bootstrap=False, distance_metric="euclidean", + decision_criterion=DynamicThresholdCriterion(), ) @@ -59,17 +61,20 @@ def test_evidently_detect_drift_metric( detector = [ ("model", EvidentlyDriftDetector({"model": model_drift_metric})), ("ratio", EvidentlyDriftDetector({"ratio": ratio_drift_metric})), - ("simple_distance", EvidentlyDriftDetector({"simple_distance": simple_distance_drift_metric})), + ( + "simple_distance", + EvidentlyDriftDetector({"simple_distance": simple_distance_drift_metric}), + ), ] for name, ad in detector: assert isinstance(ad, EvidentlyDriftDetector) # on h0 - results = ad.detect_drift(df_data_ref, df_data_h0) + results = ad.detect_drift(df_data_ref, df_data_h0, False) assert not results[name].is_drift # on current data - results = ad.detect_drift(df_data_ref, df_data_cur) + results = ad.detect_drift(df_data_ref, df_data_cur, False) if name != "model": # model makes the wrong decision here assert results[name].is_drift @@ -83,5 +88,5 @@ def test_evidently_detect_drift_metric( "simple_distance": simple_distance_drift_metric, } ) - results = ad.detect_drift(df_data_ref, df_data_cur) + results = ad.detect_drift(df_data_ref, df_data_cur, False) assert len(results) == 3 diff --git a/modyn/tests/supervisor/internal/triggers/test_datadrifttrigger.py b/modyn/tests/supervisor/internal/triggers/test_datadrifttrigger.py index 242e5f59a..6903824c5 100644 --- a/modyn/tests/supervisor/internal/triggers/test_datadrifttrigger.py +++ b/modyn/tests/supervisor/internal/triggers/test_datadrifttrigger.py @@ -1,7 +1,7 @@ # pylint: disable=unused-argument, no-name-in-module, no-value-for-parameter import os import pathlib -from unittest.mock import patch +from unittest.mock import MagicMock, patch from pytest import fixture @@ -12,12 +12,20 @@ from modyn.config.schema.pipeline.trigger.drift.alibi_detect import ( AlibiDetectMmdDriftMetric, ) -from modyn.config.schema.pipeline.trigger.drift.config import ( - AmountWindowingStrategy, +from modyn.config.schema.pipeline.trigger.drift.config import AmountWindowingStrategy +from modyn.config.schema.pipeline.trigger.drift.detection_window import ( TimeWindowingStrategy, ) +from modyn.config.schema.pipeline.trigger.drift.metric import ( + DynamicThresholdCriterion, + ThresholdDecisionCriterion, +) +from modyn.config.schema.pipeline.trigger.simple.data_amount import ( + DataAmountTriggerConfig, +) from modyn.config.schema.system.config import ModynConfig from modyn.supervisor.internal.triggers import DataDriftTrigger +from modyn.supervisor.internal.triggers.amounttrigger import DataAmountTrigger from modyn.supervisor.internal.triggers.embedding_encoder_utils import ( EmbeddingEncoderDownloader, ) @@ -33,7 +41,13 @@ def drift_trigger_config() -> DataDriftTriggerConfig: return DataDriftTriggerConfig( detection_interval_data_points=42, - metrics={"model": AlibiDetectMmdDriftMetric(num_permutations=1000)}, + metrics={ + 
"mmd": AlibiDetectMmdDriftMetric( + decision_criterion=ThresholdDecisionCriterion(threshold=0.5), + num_permutations=None, + threshold=0.5, + ) + }, aggregation_strategy=MajorityVoteDriftAggregationStrategy(), ) @@ -103,13 +117,15 @@ def test_init_trigger( def test_inform_previous_model_id(drift_trigger_config: DataDriftTriggerConfig) -> None: trigger = DataDriftTrigger(drift_trigger_config) + trigger.model_updated = False # pylint: disable-next=use-implicit-booleaness-not-comparison trigger.inform_previous_model(42) assert trigger.previous_model_id == 42 + assert trigger.model_updated @patch.object(DataDriftTrigger, "_run_detection", return_value=(True, {})) -def test_inform_always_drift(test_detect_drift, drift_trigger_config: DataDriftTriggerConfig) -> None: +def test_inform_always_drift(test_detect_drift: MagicMock, drift_trigger_config: DataDriftTriggerConfig) -> None: drift_trigger_config.detection_interval_data_points = 1 trigger = DataDriftTrigger(drift_trigger_config) num_triggers = 0 @@ -139,14 +155,13 @@ def test_inform_always_drift(test_detect_drift, drift_trigger_config: DataDriftT @patch.object(DataDriftTrigger, "_run_detection", return_value=(False, {})) -def test_inform_no_drift(test_detect_no_drift, drift_trigger_config: DataDriftTriggerConfig) -> None: +def test_inform_no_drift(test_detect_no_drift: MagicMock, drift_trigger_config: DataDriftTriggerConfig) -> None: drift_trigger_config.detection_interval_data_points = 1 trigger = DataDriftTrigger(drift_trigger_config) num_triggers = 0 for _ in trigger.inform([SAMPLE, SAMPLE, SAMPLE, SAMPLE, SAMPLE]): num_triggers += 1 trigger.inform_previous_model(num_triggers) - # pylint: disable-next=use-implicit-booleaness-not-comparison assert num_triggers == 1 drift_trigger_config.detection_interval_data_points = 2 @@ -155,7 +170,6 @@ def test_inform_no_drift(test_detect_no_drift, drift_trigger_config: DataDriftTr for _ in trigger.inform([SAMPLE, SAMPLE, SAMPLE, SAMPLE, SAMPLE]): num_triggers += 1 trigger.inform_previous_model(num_triggers) - # pylint: disable-next=use-implicit-booleaness-not-comparison assert num_triggers == 1 drift_trigger_config.detection_interval_data_points = 5 @@ -171,97 +185,161 @@ def test_inform_no_drift(test_detect_no_drift, drift_trigger_config: DataDriftTr def test_update_current_window_amount_strategy( drift_trigger_config: DataDriftTriggerConfig, ) -> None: - drift_trigger_config.windowing_strategy = AmountWindowingStrategy(amount=3) + drift_trigger_config.windowing_strategy = AmountWindowingStrategy(amount_cur=3, amount_ref=3) drift_trigger_config.detection_interval_data_points = 100 trigger = DataDriftTrigger(drift_trigger_config) # Inform with less data than the window amount list(trigger.inform([(1, 100, 1), (2, 101, 1)])) - assert len(trigger._current_window) == 2, "Current window should contain 2 data points." + assert len(trigger._windows.current) == 2, "Current window should contain 2 data points." # Inform with additional data points to exceed the window size list(trigger.inform([(3, 102, 1), (4, 103, 1)])) - assert len(trigger._current_window) == 3, "Current window should not exceed 3 data points." - assert trigger._current_window[0][0] == 2, "Oldest data point should be dropped." + assert len(trigger._windows.current) == 3, "Current window should not exceed 3 data points." + assert trigger._windows.current[0][0] == 2, "Oldest data point should be dropped." 
 def test_time_windowing_strategy_update(
     drift_trigger_config: DataDriftTriggerConfig,
 ) -> None:
-    drift_trigger_config.windowing_strategy = TimeWindowingStrategy(limit="10s")
+    drift_trigger_config.windowing_strategy = TimeWindowingStrategy(limit_cur="10s", limit_ref="10s")
     trigger = DataDriftTrigger(drift_trigger_config)

     # Inform with initial data points
     list(trigger.inform([(1, 100, 1), (2, 104, 1), (3, 105, 1)]))
-    assert len(trigger._current_window) == 3, "Current window should contain 3 data points."
+    assert len(trigger._windows.current) == 3, "Current window should contain 3 data points."

     # Inform with additional data points outside the time window
     list(trigger.inform([(4, 111, 1), (5, 115, 1)]))
-    assert len(trigger._current_window) == 3, "Current window should contain only recent data within 10 seconds."
+    assert len(trigger._windows.current) == 3, "Current window should contain only recent data within 10 seconds."
     # Since the window is inclusive, we have 105 in there!
-    assert trigger._current_window[0][0] == 3, "Data points outside the time window should be dropped."
+    assert trigger._windows.current[0][0] == 3, "Data points outside the time window should be dropped."


-@patch.object(DataDriftTrigger, "_run_detection", return_value=(True, {}))
+@patch.object(DataDriftTrigger, "_run_detection", return_value=(False, {}))
 def test_update_current_window_amount_strategy_cross_inform(
+    mock_run_detection: MagicMock,
     drift_trigger_config: DataDriftTriggerConfig,
 ) -> None:
-    drift_trigger_config.windowing_strategy = AmountWindowingStrategy(amount=5)
+    drift_trigger_config.warmup_intervals = 0
+    drift_trigger_config.windowing_strategy = AmountWindowingStrategy(amount_cur=5, amount_ref=5)
     drift_trigger_config.detection_interval_data_points = 3
-    # TODO(MaxiBoether/robinholzi: If this is not set,
-    # it seems to use True, despite the default in the config being False
-    # Why could this happen?
-    drift_trigger_config.reset_current_window_on_trigger = False
     trigger = DataDriftTrigger(drift_trigger_config)

-    assert list(
-        trigger.inform(
-            [
-                (1, 100, 1),
-                (2, 100, 1),
-                (3, 100, 1),
-                (4, 100, 1),
-                (5, 100, 1),
-                (6, 100, 1),
-                (7, 100, 1),
-            ]
+    assert (
+        len(
+            list(
+                trigger.inform(
+                    [
+                        (1, 100, 1),
+                        (2, 100, 1),
+                        (3, 100, 1),
+                        (4, 100, 1),
+                        (5, 100, 1),
+                        (6, 100, 1),
+                        (7, 100, 1),
+                    ]
+                )
+            )
         )
-    ) == [2, 5]
-    assert len(trigger._current_window) == 5
-    assert trigger._total_items_in_current_window == 7
+        == 1
+    ), "Only the first batch should trigger."
+    assert len(trigger._windows.current) == 4

     assert len(list(trigger.inform([(8, 100, 1)]))) == 0
-    assert len(trigger._current_window) == 5
-    assert trigger._total_items_in_current_window == 8
-    assert trigger._current_window[0][0] == 4
-
-    assert list(trigger.inform([(9, 100, 1)])) == [0]
-    assert len(trigger._current_window) == 5
-    assert trigger._total_items_in_current_window == 9
-    assert trigger._current_window[0][0] == 5
-
-
-@patch.object(DataDriftTrigger, "_run_detection", return_value=(True, {}))
-def test_leftover_data_handling_with_reset(mock_run_detection, drift_trigger_config: DataDriftTriggerConfig) -> None:
-    drift_trigger_config.windowing_strategy = AmountWindowingStrategy(amount=50)
-    drift_trigger_config.detection_interval_data_points = 2
-    drift_trigger_config.reset_current_window_on_trigger = True
-    trigger = DataDriftTrigger(drift_trigger_config)
+    assert len(trigger._windows.current) == 5
+    assert trigger._windows.current[0][0] == 4

-    # Inform with a batch of data points triggering detection
-    list(trigger.inform([(1, 100, 1), (2, 101, 1), (3, 102, 1)]))
-    assert len(trigger._current_window) == 1, "Current window should have leftover data after detection."
-    assert trigger._current_window[0][0] == 3, "Leftover data should be the last informed data point."
+    assert len(list(trigger.inform([(9, 100, 1)]))) == 0, "Only the first batch should trigger."
+    assert len(trigger._windows.current) == 5
+    assert trigger._windows.current[0][0] == 5


-@patch.object(DataDriftTrigger, "_run_detection", return_value=(True, {}))
-def test_leftover_data_handling_without_reset(mock_run_detection, drift_trigger_config: DataDriftTriggerConfig) -> None:
-    drift_trigger_config.windowing_strategy = AmountWindowingStrategy(amount=50)
-    drift_trigger_config.detection_interval_data_points = 2
-    drift_trigger_config.reset_current_window_on_trigger = False
-    trigger = DataDriftTrigger(drift_trigger_config)
+@patch.object(
+    DataDriftTrigger,
+    "_run_detection",
+    side_effect=[(False, {})] * 5 + [(False, {}), (True, {}), (False, {})],  # first 5: warmup
+)
+def test_warmup_trigger(mock_run_detection: MagicMock) -> None:
+    trigger_config = DataDriftTriggerConfig(
+        detection_interval_data_points=5,
+        metrics={
+            "mmd": AlibiDetectMmdDriftMetric(
+                decision_criterion=DynamicThresholdCriterion(percentile=50, window_size=3),
+            )
+        },
+        aggregation_strategy=MajorityVoteDriftAggregationStrategy(),
+        windowing_strategy=AmountWindowingStrategy(amount_cur=3, amount_ref=3),
+        warmup_intervals=5,
+        warmup_policy=DataAmountTriggerConfig(num_samples=7),
+    )
+    trigger = DataDriftTrigger(trigger_config)
+    assert isinstance(trigger.warmup_trigger, DataAmountTrigger)
+    assert (
+        len(trigger._windows.current) == len(trigger._windows.reference) == len(trigger._windows.current_reservoir) == 0
+    )

-    # Inform with a batch of data points triggering detection
-    list(trigger.inform([(1, 100, 1), (2, 101, 1), (3, 102, 1)]))
-    assert len(trigger._current_window) == 3, "Current window should have leftover data after detection."
-    assert trigger._current_window[0][0] == 1, "Leftover data should be the first informed data point."
+    # Test: We add 40 samples (indices 0-39) in 8 batches of 5 samples each and inspect the trigger state after each batch.
+
+    # with `detection_interval_data_points=5` we will run a detection after every
+    # 5th sample, i.e. at the (0-based) sample indices 4, 9, 14, 19, 24, 29, 34 and 39
+
+    # Here are the reasons for the decisions we make at each of these points:
+    # - index 4: first detection: always trigger
+    # - index 9: 2nd detection: warmup trigger (warmup policy trigger at index 6)
+    # - index 14: 3rd detection: warmup trigger (warmup policy trigger at index 13)
+    # - index 19: 4th detection: no warmup trigger
+    # - index 24: 5th detection: warmup trigger (warmup policy trigger at index 20)
+    # - index 29: drift detection: run_detection --> False
+    # - index 34: drift detection: run_detection --> True
+    # - index 39: drift detection: run_detection --> False
+
+    results = list(trigger.inform([(i, 100 + i, 1) for i in range(5)]))
+    assert results == [4]
+    assert not trigger.warmup_completed
+    assert trigger.warmup_intervals[-1] == [(i, 100 + i) for i in [2, 3, 4]]  # window size 3
+    assert len(trigger._windows.reference) == 3
+    assert len(trigger._windows.current) == 0  # after a trigger the current window is empty
+
+    results = list(trigger.inform([(i, 100 + i, 1) for i in range(5, 10)]))
+    assert results == [4]  # index in last inform batch
+    assert len(trigger.warmup_intervals) == 2
+    assert not trigger.warmup_completed
+    assert trigger.warmup_intervals[-1] == [(i, 100 + i) for i in [2, 3, 4]]  # from first trigger
+    assert len(trigger._windows.reference) == 3
+    assert len(trigger._windows.current) == 0  # after a trigger the current window is empty
+
+    results = list(trigger.inform([(i, 100 + i, 1) for i in range(10, 15)]))
+    assert results == [4]
+    assert len(trigger.warmup_intervals) == 3
+    assert not trigger.warmup_completed
+    assert trigger.warmup_intervals[-1] == [(i, 100 + i) for i in [7, 8, 9]]
+    assert len(trigger._windows.reference) == 3
+    assert len(trigger._windows.current) == 0  # after a trigger the current window is empty
+
+    results = list(trigger.inform([(i, 100 + i, 1) for i in range(15, 20)]))
+    assert len(results) == 0
+    assert len(trigger.warmup_intervals) == 4
+    assert not trigger.warmup_completed
+    assert trigger.warmup_intervals[-1] == [(i, 100 + i) for i in [12, 13, 14]]
+
+    results = list(trigger.inform([(i, 100 + i, 1) for i in range(20, 25)]))
+    assert results == [4]
+    assert len(trigger.warmup_intervals) == 5
+    assert not trigger.warmup_completed
+    assert trigger.warmup_intervals[-1] == [(i, 100 + i) for i in [17, 18, 19]]
+
+    results = list(trigger.inform([(i, 100 + i, 1) for i in range(25, 30)]))
+    assert len(results) == 0
+    assert len(trigger.warmup_intervals) == 0
+    assert trigger.warmup_completed
+
+    results = list(trigger.inform([(i, 100 + i, 1) for i in range(30, 35)]))
+    assert results == [4]
+    assert len(trigger.warmup_intervals) == 0
+    assert trigger.warmup_completed
+
+    results = list(trigger.inform([(i, 100 + i, 1) for i in range(35, 40)]))
+    assert len(results) == 0
+    assert len(trigger.warmup_intervals) == 0
+    assert trigger.warmup_completed
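To cross-check the schedule encoded in `test_warmup_trigger`, the expected decision at each detection point can be recomputed from the two intervals alone. Below is a small, self-contained recomputation under the assumptions spelled out in the comments above (the first detection always triggers, detections two through five defer to the warmup policy, and afterwards the mocked `_run_detection` decides); the variable names are ad hoc:

```python
detection_interval = 5  # detection_interval_data_points
warmup_detections = 5   # warmup_intervals
warmup_every = 7        # DataAmountTriggerConfig(num_samples=7)
mocked_drift = {29: False, 34: True, 39: False}  # post-warmup _run_detection results

# 0-based sample indices at which the warmup policy fires: 6, 13, 20, 27, 34
warmup_policy_fires = {k * warmup_every - 1 for k in range(1, 6)}

for n, idx in enumerate(range(detection_interval - 1, 40, detection_interval), start=1):
    if n == 1:
        decision = True  # first detection: no baseline yet, always trigger
    elif n <= warmup_detections:
        # during warmup: trigger iff the warmup policy fired since the last detection
        since_last = range(idx - detection_interval + 1, idx + 1)
        decision = any(i in warmup_policy_fires for i in since_last)
    else:
        decision = mocked_drift[idx]
    print(f"index {idx}: {'trigger' if decision else 'no trigger'}")
```

Running this prints triggers exactly at indices 4, 9, 14, 24, and 34, matching the `results == [4]` assertions in the test above.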