Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add more powerful drift windowing strategies, warmup and dynamic thresholds #564

Merged
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
9d71dfb
Add more powerful windowing strategies, warumup and dynamic thresholds
robinholzi Jul 4, 2024
244da67
Merge branch 'main' into robinholzi/feat/more-powerful-drift-windows-…
robinholzi Jul 4, 2024
85ed4f5
Merge branch 'main' into robinholzi/feat/more-powerful-drift-windows-…
robinholzi Jul 22, 2024
4ca50ed
Add tests
robinholzi Jul 24, 2024
91e0dda
Merge branch 'main' into robinholzi/feat/more-powerful-drift-windows-…
robinholzi Jul 24, 2024
28dbb9f
Merge branch 'main' into robinholzi/feat/more-powerful-drift-windows-…
robinholzi Jul 26, 2024
6b22b6d
fix linting
robinholzi Jul 26, 2024
0d42176
fix linting
robinholzi Jul 26, 2024
0028e1b
Implement suggestions, v1
robinholzi Aug 7, 2024
9f3b70f
Merge branch 'main' into robinholzi/feat/more-powerful-drift-windows-…
robinholzi Aug 7, 2024
bbacf1e
Integrate suggestions, rename things, more tests, documentation
robinholzi Aug 8, 2024
44f29d1
Fix
robinholzi Aug 8, 2024
836af4f
Merge branch 'main' into robinholzi/feat/more-powerful-drift-windows-…
robinholzi Aug 12, 2024
f950c25
Fix
robinholzi Aug 12, 2024
8912476
Merge branch 'main' into robinholzi/feat/more-powerful-drift-windows-…
robinholzi Aug 12, 2024
c9091aa
Tests and adjustments to warmup
robinholzi Aug 12, 2024
4792fd4
Integrate suggestions, rename things
robinholzi Aug 12, 2024
ed0baaa
fix
robinholzi Aug 12, 2024
ccb6ba3
Fix
robinholzi Aug 13, 2024
46f52d3
Move averaging logic into detector
robinholzi Aug 13, 2024
a2392ad
Merge branch 'main' into robinholzi/feat/more-powerful-drift-windows-…
robinholzi Aug 13, 2024
ea036c1
Final adjustments
robinholzi Aug 14, 2024
9594453
Small fix wrt interval tests
robinholzi Aug 14, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion analytics/app/pages/plots/eval_over_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def gen_figure(
# we only want the pipeline performance (composed of the models active periods stitched together)
df_adjusted = df_adjusted[df_adjusted[composite_model_variant]]
else:
assert df_adjusted["pipeline_ref"].nunique() == 1
assert df_adjusted["pipeline_ref"].nunique() <= 1
# add the pipeline time series which is the performance of different models stitched together dep.
# w.r.t which model was active
pipeline_composite_model = df_adjusted[df_adjusted[composite_model_variant]]
Expand Down
1 change: 0 additions & 1 deletion benchmark/huffpost_kaggle/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ In this directory, you can find the files necessary to run experiments using the
The goal is to predict the tag of news given headlines.
The dataset contains more than 60k samples collected from 2012 to 2018.
Titles belonging to the same year are grouped into the same CSV file and stored together.
Each year is mapped to a year starting from 1/1/1970.
There is a total of 42 categories/classes.

> Note: The wild-time variant of the huffpost dataset has only 11 classes. This is due to the fact that
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ trigger:
metrics:
ev_mmd:
id: EvidentlyModelDriftMetric
threshold: 0.7
decision_criterion:
id: ThresholdDecisionCriterion
threshold: 0.7
aggregation_strategy:
id: MajorityVote
selection_strategy:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ trigger:
ev_mmd:
id: AlibiDetectMmdDriftMetric
num_permutations: 1000
decision_criterion:
id: DynamicThresholdCriterion

aggregation_strategy:
id: MajorityVote
selection_strategy:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ trigger:
metrics:
ev_mmd:
id: EvidentlyModelDriftMetric
threshold: 0.7
decision_criterion:
id: ThresholdDecisionCriterion
threshold: 0.7
aggregation_strategy:
id: MajorityVote
selection_strategy:
Expand Down
18 changes: 9 additions & 9 deletions docs/pipeline/TRIGGERING.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,36 +25,36 @@ classDiagram
class TimeTrigger {
}

class DataAmount {
class DataAmountTrigger {
}

}

namespace complex_triggers {

class DataDrift {
class DataDriftTrigger {
}

class CostBased {
class CostBasedTrigger {
}

class _BatchedTrigger {
<<Abstract>>
}

class EnsemblePolicy {
class EnsembleTrigger {
}

}

Trigger <|-- _BatchedTrigger
Trigger <|-- EnsemblePolicy
Trigger <|-- EnsembleTrigger

Trigger <|-- TimeTrigger
Trigger <|-- DataAmount
Trigger <|-- DataAmountTrigger

_BatchedTrigger <|-- DataDrift
_BatchedTrigger <|-- CostBased
_BatchedTrigger <|-- DataDriftTrigger
_BatchedTrigger <|-- CostBasedTrigger

EnsemblePolicy *-- "n" Trigger
EnsembleTrigger *-- "n" Trigger
```
231 changes: 231 additions & 0 deletions docs/pipeline/triggering/DRIFT_TRIGGERS.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
# Drift Triggering

## Overview

As can be seen in [TRIGGERING.md](../TRIGGERING.md) the `DataDriftTrigger` follows the same interface as the simple triggers. The main difference is that the `DataDriftTrigger` is a complex trigger that requires more information to be provided to the trigger.

It utilizes `DetectionWindows` to select samples for drift detection and `DriftDetector` to measure distance between the current and reference data distributions. The `DriftDecisionPolicy` is used to generate a binary decision based on the distance metric using a specific criterion like a threshold or dynamic threshold. Hypothesis testing can also be used to come to a decision, however, we found this to be oversensitive and not very useful in practice.

### Main Architecture

```mermaid
classDiagram
class DetectionWindows {
<<abstract>>
}

class DriftDetector

class Trigger {
<<abstract>>
+void init_trigger(TriggerContext context)
+Generator[Triggers] inform(new_data)
+void inform_previous_model(int previous_model_id)
}

class TimeTrigger {
}

class DataAmountTrigger {
}

class DriftDetector {
<<abstract>>
}

class DataDriftTrigger {
+DataDriftTriggerConfig config
+DetectionWindows _windows
+dict[MetricId, DriftDecisionPolicy] decision_policies
+dict[MetricId, DriftDetector] distance_detectors

+void init_trigger(TriggerContext context)
+Generator[triggers] inform(new_data)
+void inform_previous_model(int previous_model_id)
}

class DriftDecisionPolicy {
<<abstract>>
}

DataDriftTrigger *-- "1" DataDriftTriggerConfig
DataDriftTrigger *-- "1" DetectionWindows
DataDriftTrigger *-- "|metrics|" DriftDetector
DataDriftTrigger *-- "|metrics|" DriftDecisionPolicy

Trigger <|-- DataDriftTrigger
Trigger <|-- TimeTrigger
Trigger <|-- DataAmountTrigger
```

### DetectionWindows Hierarchy

The `DetectionWindows` class serves as the abstract base for specific windowing strategies like `AmountDetectionWindows` and `TimeDetectionWindows`. These classes are responsible for both storing the actual data windows for reference and current data and for defining a strategy for updating and managing these windows.

```mermaid
classDiagram
class DetectionWindows {
<<abstract>>
+Deque current
+Deque current_reservoir
+Deque reference
+void inform_data(list[tuple[int, int]])
+void inform_trigger()
}

class AmountDetectionWindows {
+AmountWindowingStrategy config
+void inform_data(list[tuple[int, int]])
+void inform_trigger()
}

class TimeDetectionWindows {
+TimeWindowingStrategy config
+void inform_data(list[tuple[int, int]])
+void inform_trigger()
}

DetectionWindows <|-- AmountDetectionWindows
DetectionWindows <|-- TimeDetectionWindows
```

### DriftDetector Hierarchy

The `DriftDetector` class is an abstract base class for detectors like `AlibiDriftDetector` and `EvidentlyDriftDetector`, which use different metrics to measure the distance between the current and reference data distributions. It generates distance values.
robinholzi marked this conversation as resolved.
Show resolved Hide resolved

The `BaseMetric` class hierarchy is a series of Pydantic configuration classes while the `Detectors` are actual business logic classes that implement the distance calculation.

```mermaid
classDiagram
class DriftDetector {
<<abstract>>
+dict[MetricId, DriftMetric] metrics_config
+void init_detector()
+dict[MetricId, MetricResult] detect_drift(embeddings_ref, embeddings_cur, bool is_warmup)
}

class AlibiDriftDetector
class EvidentlyDriftDetector
class BaseMetric {
decision_criterion: DecisionCriterion
}

DriftDetector <|-- AlibiDriftDetector
DriftDetector <|-- EvidentlyDriftDetector


AlibiDriftDetector *-- "|metrics|" AlibiDetectDriftMetric
EvidentlyDriftDetector *-- "|metrics|" EvidentlyDriftMetric

class AlibiDetectDriftMetric {
<<abstract>>
}

class AlibiDetectMmdDriftMetric {
}

class AlibiDetectCVMDriftMetric {
}

class AlibiDetectKSDriftMetric {
}

BaseMetric <|-- AlibiDetectDriftMetric
AlibiDetectDriftMetric <|-- AlibiDetectMmdDriftMetric
AlibiDetectDriftMetric <|-- AlibiDetectCVMDriftMetric
AlibiDetectDriftMetric <|-- AlibiDetectKSDriftMetric

class EvidentlyDriftMetric {
<<abstract>>
int num_pca_component
}

class EvidentlyModelDriftMetric {
bool bootstrap = False
float quantile_probability = 0.95
float threshold = 0.55
}

class EvidentlyRatioDriftMetric {
string component_stattest = "wasserstein"
float component_stattest_threshold = 0.1
float threshold = 0.2
}

class EvidentlySimpleDistanceDriftMetric {
string distance_metric = "euclidean"
bool bootstrap = False
float quantile_probability = 0.95
float threshold = 0.2
}

BaseMetric <|-- EvidentlyDriftMetric
EvidentlyDriftMetric <|-- EvidentlyModelDriftMetric
EvidentlyDriftMetric <|-- EvidentlyRatioDriftMetric
EvidentlyDriftMetric <|-- EvidentlySimpleDistanceDriftMetric

```

### DecisionCriterion Hierarchy

The `DecisionCriterion` class is an abstract configuration base class for criteria like `ThresholdDecisionCriterion` and `DynamicThresholdCriterion`, which define how decisions are made based on drift metrics.

```mermaid
classDiagram
class DecisionCriterion {
<<abstract>>
}

class ThresholdDecisionCriterion {
float threshold
}

class DynamicThresholdCriterion {
int window_size = 10
float percentile = 0.05
}

DecisionCriterion <|-- ThresholdDecisionCriterion
DecisionCriterion <|-- DynamicThresholdCriterion
```

### DriftDecisionPolicy Hierarchy

The `DriftDecisionPolicy` class is an abstract base class for policies like `ThresholdDecisionPolicy`, `DynamicDecisionPolicy`, and `HypothesisTestDecisionPolicy`.

Each Decision policy wraps one DriftMetric (e.g. MMD, CVM, KS) and one DecisionCriterion (e.g. Threshold, DynamicThreshold, HypothesisTest) to make a decision based on the distance metric. It e.g. observes the series of distance value measurements from it's `DriftMetric` and makes a decision after having calibrated on the seen distances.
robinholzi marked this conversation as resolved.
Show resolved Hide resolved
robinholzi marked this conversation as resolved.
Show resolved Hide resolved

Within one `DataDriftTrigger` the different results from different `DriftMetrics`'s `DriftDecisionPolicies` can be aggregated to a final decision using a voting mechanism (see `DataDriftTriggerConfig.aggregation_strategy`).

```mermaid
classDiagram
class DriftDecisionPolicy {
<<abstract>>
+bool evaluate_decision(float distance)
}

class ThresholdDecisionPolicy {
+ThresholdDecisionCriterion config
+bool evaluate_decision(float distance)
}

class DynamicDecisionPolicy {
+DynamicThresholdCriterion config
+Deque~float~ score_observations
+bool evaluate_decision(float distance)
}

class HypothesisTestDecisionPolicy {
+HypothesisTestCriterion config
+bool evaluate_decision(float distance)
}

DriftDecisionPolicy <|-- ThresholdDecisionPolicy
DriftDecisionPolicy <|-- DynamicDecisionPolicy
DriftDecisionPolicy <|-- HypothesisTestDecisionPolicy


style HypothesisTestDecisionPolicy fill:#DDDDDD,stroke:#A9A9A9,stroke-width:2px
```

This architecture provides a flexible framework for implementing various types of data drift detection mechanisms, different detection libraries, each with its own specific configuration, detection strategy, and decision-making criteria.
1 change: 1 addition & 0 deletions modyn/config/schema/pipeline/trigger/drift/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .aggregation import * # noqa
from .alibi_detect import * # noqa
from .config import * # noqa
from .detection_window import * # noqa
from .evidently import * # noqa
from .result import * # noqa
16 changes: 13 additions & 3 deletions modyn/config/schema/pipeline/trigger/drift/alibi_detect.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
# Note: we don't use the hypothesis testing in the alibi-detect metrics. However, we still keep
# the support for it in this wrapper configuration for offline experiments to still be able to
# use the hypothesis testing.

from typing import Annotated, Literal

from pydantic import Field, model_validator

from modyn.config.schema.base_model import ModynBaseModel
from modyn.config.schema.pipeline.trigger.drift.metric import BaseMetric


class _AlibiDetectBaseDriftMetric(ModynBaseModel):
class _AlibiDetectBaseDriftMetric(BaseMetric):
p_val: float = Field(0.05, description="The p-value threshold for the drift detection.")
x_ref_preprocessed: bool = Field(False)

Expand Down Expand Up @@ -36,7 +41,8 @@ class AlibiDetectMmdDriftMetric(_AlibiDetectBaseDriftMetric, AlibiDetectDeviceMi
),
)
kernel: str = Field(
"GaussianRBF", description="The kernel used for distance calculation imported from alibi_detect.utils.pytorch"
"GaussianRBF",
description="The kernel used for distance calculation imported from alibi_detect.utils.pytorch",
)
configure_kernel_from_x_ref: bool = Field(True)
threshold: float | None = Field(
Expand All @@ -62,7 +68,11 @@ def validate_threshold_permutations(self) -> "AlibiDetectMmdDriftMetric":
return self


class AlibiDetectKSDriftMetric(_AlibiDetectBaseDriftMetric, _AlibiDetectAlternativeMixin, _AlibiDetectCorrectionMixin):
class AlibiDetectKSDriftMetric(
_AlibiDetectBaseDriftMetric,
_AlibiDetectAlternativeMixin,
_AlibiDetectCorrectionMixin,
):
id: Literal["AlibiDetectKSDriftMetric"] = Field("AlibiDetectKSDriftMetric")


Expand Down
Loading
Loading