feat: Add Avoidable misclassification based CostTrigger (#605)
robinholzi authored Aug 30, 2024
1 parent b37fc13 commit ecdd576
Showing 19 changed files with 738 additions and 316 deletions.
7 changes: 6 additions & 1 deletion modyn/config/schema/pipeline/trigger/cost/cost.py
@@ -9,6 +9,7 @@

from modyn.config.schema.pipeline.trigger.common.batched import BatchedTriggerConfig

+from ..performance.criterion import _NumberAvoidableMisclassificationCriterion
from ..performance.performance import _InternalPerformanceTriggerConfig


@@ -59,7 +60,11 @@ def conversion_factor(self) -> float:
return self.incorporation_delay_per_training_second


-class AvoidableMisclassificationCostTriggerConfig(_CostTriggerConfig, _InternalPerformanceTriggerConfig):
+class AvoidableMisclassificationCostTriggerConfig(
+    _CostTriggerConfig,
+    _InternalPerformanceTriggerConfig,
+    _NumberAvoidableMisclassificationCriterion,
+):
"""Cost aware trigger policy configuration that using the number of
avoidable misclassifications integration latency as a regret metric.
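
To illustrate how the three parent configs contribute fields to the combined model, here is a minimal, self-contained pydantic sketch. The field names are borrowed from this diff; the class names and defaults are illustrative stand-ins, not Modyn's actual definitions.

from pydantic import BaseModel

class _CostPart(BaseModel):
    cost_tracking_window_size: int = 1000  # illustrative default

class _PerformancePart(BaseModel):
    data_density_window_size: int = 20

class _CriterionPart(BaseModel):
    expected_accuracy: float | None = None
    allow_reduction: bool = False

# Multiple inheritance merges the field sets, mirroring how
# AvoidableMisclassificationCostTriggerConfig composes its three bases.
class CombinedConfig(_CostPart, _PerformancePart, _CriterionPart):
    pass

config = CombinedConfig(expected_accuracy=0.9)
assert config.cost_tracking_window_size == 1000
assert config.allow_reduction is False
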
14 changes: 7 additions & 7 deletions modyn/config/schema/pipeline/trigger/performance/criterion.py
@@ -68,13 +68,6 @@ class _NumberAvoidableMisclassificationCriterion(ModynBaseModel):
"If not set, the expected performance will be inferred dynamically with a rolling average."
),
)


-class StaticNumberAvoidableMisclassificationCriterion(_NumberAvoidableMisclassificationCriterion):
-    id: Literal["StaticNumberMisclassificationCriterion"] = Field("StaticNumberMisclassificationCriterion")
-    avoidable_misclassification_threshold: int = Field(
-        description="The threshold for the misclassification rate that will invoke a trigger."
-    )
allow_reduction: bool = Field(
False,
description=(
@@ -84,6 +77,13 @@ class StaticNumberAvoidableMisclassificationCriterion(_NumberAvoidableMisclassif
)


+class StaticNumberAvoidableMisclassificationCriterion(_NumberAvoidableMisclassificationCriterion):
+    id: Literal["StaticNumberMisclassificationCriterion"] = Field("StaticNumberMisclassificationCriterion")
+    avoidable_misclassification_threshold: int = Field(
+        description="The threshold for the misclassification rate that will invoke a trigger."
+    )


# -------------------------------------------------------------------------------------------------------------------- #
# Union #
# -------------------------------------------------------------------------------------------------------------------- #
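
As a usage sketch, the static criterion above could be instantiated as follows. The import path matches this file's location; expected_accuracy is assumed to be the optional base-class field whose (truncated) description appears above, and the threshold value is purely illustrative.

from modyn.config.schema.pipeline.trigger.performance.criterion import (
    StaticNumberAvoidableMisclassificationCriterion,
)

criterion = StaticNumberAvoidableMisclassificationCriterion(
    # illustrative threshold: trigger after ~100 estimated avoidable misclassifications
    avoidable_misclassification_threshold=100,
    # assumed optional base-class field; if unset, expected performance is
    # inferred dynamically with a rolling average (see the description above)
    expected_accuracy=0.85,
    # now inherited from _NumberAvoidableMisclassificationCriterion
    allow_reduction=False,
)
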
32 changes: 16 additions & 16 deletions modyn/config/schema/pipeline/trigger/performance/performance.py
@@ -31,7 +31,7 @@ def validate_metrics(cls, dataset: EvalDataConfig) -> EvalDataConfig:

class _InternalPerformanceTriggerConfig(BatchedTriggerConfig):
data_density_window_size: int = Field(
-        0,
+        20,
description="The window size for the data density estimation. Only used for lookahead mode.",
)
performance_triggers_window_size: int = Field(
@@ -43,17 +43,6 @@ class _InternalPerformanceTriggerConfig(BatchedTriggerConfig):
description="Configuration for the evaluation of the performance trigger."
)

-    decision_criteria: dict[str, PerformanceTriggerCriterion] = Field(
-        description=(
-            "The decision criteria to be used for the performance trigger. If any of the criteria is met, "
-            "the trigger will be executed. The criteria will be evaluated in the order they are defined. "
-            "Every criterion is linked to a metric. Some of the criteria implicitly only work on accuracy which is "
-            "the default metric that is always generated and cannot be disabled. To define a "
-            "`StaticPerformanceThresholdCriterion` on Accuracy, the evaluation config has to define the accuracy metric."
-        ),
-        min_length=1,
-    )

mode: TriggerEvaluationMode = Field(
"hindsight",
description="Whether to also consider forecasted future performance in the drift decision.",
@@ -66,6 +55,21 @@ class _InternalPerformanceTriggerConfig(BatchedTriggerConfig):
),
)


+class PerformanceTriggerConfig(_InternalPerformanceTriggerConfig):
+    id: Literal["PerformanceTrigger"] = Field("PerformanceTrigger")
+
+    decision_criteria: dict[str, PerformanceTriggerCriterion] = Field(
+        description=(
+            "The decision criteria to be used for the performance trigger. If any of the criteria is met, "
+            "the trigger will be executed. The criteria will be evaluated in the order they are defined. "
+            "Every criterion is linked to a metric. Some of the criteria implicitly only work on accuracy which is "
+            "the default metric that is always generated and cannot be disabled. To define a "
+            "`StaticPerformanceThresholdCriterion` on Accuracy, the evaluation config has to define the accuracy metric."
+        ),
+        min_length=1,
+    )

@model_validator(mode="after")
def validate_decision_criteria(self) -> "PerformanceTriggerConfig":
"""Assert that all criteria use metrics that are defined in the
@@ -77,7 +81,3 @@ def validate_decision_criteria(self) -> "PerformanceTriggerConfig":
f"Criterion {criterion.id} uses metric {criterion.metric} which is not defined in the evaluation config."
)
return self


-class PerformanceTriggerConfig(_InternalPerformanceTriggerConfig):
-    id: Literal["PerformanceTrigger"] = Field("PerformanceTrigger")
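
For illustration, a decision criterion could be wired into decision_criteria as sketched below. StaticPerformanceThresholdCriterion is named in the field description above, but its import path and field names (metric, metric_threshold) are assumptions here, not taken from this diff.

# Hypothetical sketch only: the criterion's field names are assumed.
from modyn.config.schema.pipeline.trigger.performance.criterion import (
    StaticPerformanceThresholdCriterion,  # assumed to live alongside the other criteria
)

decision_criteria = {
    "accuracy_floor": StaticPerformanceThresholdCriterion(
        metric="Accuracy",     # must also be defined in the evaluation config
        metric_threshold=0.9,  # assumed field name
    )
}
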
@@ -0,0 +1,94 @@
import logging

from typing_extensions import override

from modyn.config.schema.pipeline.trigger.cost.cost import (
AvoidableMisclassificationCostTriggerConfig,
)
from modyn.supervisor.internal.triggers.costtrigger import CostTrigger
from modyn.supervisor.internal.triggers.performance.misclassification_estimator import (
NumberAvoidableMisclassificationEstimator,
)
from modyn.supervisor.internal.triggers.performance.performancetrigger_mixin import (
PerformanceTriggerMixin,
)
from modyn.supervisor.internal.triggers.trigger import TriggerContext

logger = logging.getLogger(__name__)


class AvoidableMisclassificationCostTrigger(CostTrigger, PerformanceTriggerMixin):
"""Triggers when the avoidable misclassification cost incorporation latency
(regret) exceeds the estimated training time."""

def __init__(self, config: AvoidableMisclassificationCostTriggerConfig):
CostTrigger.__init__(self, config)
PerformanceTriggerMixin.__init__(self, config)

self.config = config
self.context: TriggerContext | None = None

self.misclassification_estimator = NumberAvoidableMisclassificationEstimator(
config.expected_accuracy, config.allow_reduction
)

@override
def init_trigger(self, context: TriggerContext) -> None:
# Call CostTrigger's init_trigger method to initialize the trigger context
CostTrigger.init_trigger(self, context)

        # Call PerformanceTriggerMixin's _init_trigger method to initialize the internal performance detection state
PerformanceTriggerMixin._init_trigger(self, context)

@override
def inform_new_model(
self,
most_recent_model_id: int,
number_samples: int | None = None,
training_time: float | None = None,
) -> None:
"""Update the cost and performance trackers with the new model
metadata."""

# Call CostTrigger's inform_new_model method to update the cost tracker
CostTrigger.inform_new_model(self, most_recent_model_id, number_samples, training_time)

        # Call the internal PerformanceTriggerMixin's _inform_new_model method to update the performance tracker
PerformanceTriggerMixin._inform_new_model(self, most_recent_model_id, self._last_detection_interval)

# ---------------------------------------------------------------------------------------------------------------- #
# INTERNAL #
# ---------------------------------------------------------------------------------------------------------------- #

@override
def _compute_regret_metric(self, batch: list[tuple[int, int]], batch_start: int, batch_duration: int) -> float:
"""Compute the regret metric for the current state of the trigger."""

self.data_density.inform_data(batch)
num_samples, num_misclassifications, evaluation_scores = self._run_evaluation(interval_data=batch)

self.performance_tracker.inform_evaluation(
num_samples=num_samples,
num_misclassifications=num_misclassifications,
evaluation_scores=evaluation_scores,
)

estimated_new_avoidable_misclassifications, _ = (
self.misclassification_estimator.estimate_avoidable_misclassifications(
update_interval_samples=self.config.evaluation_interval_data_points,
data_density=self.data_density,
performance_tracker=self.performance_tracker,
method=self.config.forecasting_method,
)
)

        # Build a latency-based regret metric from the estimated number of avoidable misclassifications.
        # Avoidable-misclassification latency is a sensible regret signal because we generally aim to trigger
        # when many misclassifications could have been avoided; we therefore try to minimize the time that
        # misclassifications remain unaddressed while an old model is still in use.
        # We chose the latency-based area-under-the-curve method because a linear model based on the absolute
        # number of avoidable misclassifications seems unstable. More advantageous non-linear regret functions
        # could be explored in the future.
return self._incorporation_latency_tracker.add_latency(
estimated_new_avoidable_misclassifications, batch_duration
)
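
This diff does not show how the CostTrigger base class turns the returned regret into a trigger decision; the sketch below is a hypothetical reading of the trade-off implied by conversion_factor (which returns incorporation_delay_per_training_second in cost.py above).

# Hypothetical decision rule, not part of this diff: trigger once the
# accumulated regret latency, expressed in "training-second equivalents"
# via the configured conversion factor, outweighs the forecasted cost of
# running a training.
def should_trigger(
    cumulative_regret_latency: float,  # e.g. misclassification-seconds
    conversion_factor: float,          # regret units that justify one training second
    estimated_training_seconds: float,
) -> bool:
    return cumulative_regret_latency / conversion_factor >= estimated_training_seconds
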
2 changes: 2 additions & 0 deletions modyn/supervisor/internal/triggers/batchedtrigger.py
@@ -22,6 +22,7 @@ def __init__(self, config: BatchedTriggerConfig) -> None:

        # allows detecting drift at a fixed interval
self._sample_left_until_detection = config.evaluation_interval_data_points
+        self._last_detection_interval: list[tuple[int, int]] = []

self._leftover_data: list[tuple[int, int]] = []
"""Stores data that was not processed in the last inform call because
@@ -69,6 +70,7 @@ def inform(
# ----------------------------------------------- Detection ---------------------------------------------- #

triggered = self._evaluate_batch(next_detection_interval, trigger_candidate_idx)
+        self._last_detection_interval = next_detection_interval

# ----------------------------------------------- Response ----------------------------------------------- #

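
For context, here is a simplified, self-contained sketch of the fixed-size detection-interval batching that BatchedTrigger.inform performs around _last_detection_interval; the names are illustrative and the real method additionally produces trigger responses.

# Simplified sketch of fixed-size interval batching with leftover handling.
def detection_intervals(
    new_data: list[tuple[int, int]],
    interval_size: int,
    leftover: list[tuple[int, int]],
) -> list[list[tuple[int, int]]]:
    pool = leftover + new_data
    num_full = len(pool) // interval_size
    intervals = [pool[i * interval_size:(i + 1) * interval_size] for i in range(num_full)]
    leftover[:] = pool[num_full * interval_size:]  # carry the unfilled remainder to the next call
    return intervals
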

This file was deleted.

@@ -31,8 +31,11 @@ def add_latency(self, regret: float, batch_duration: float) -> float:
Returns:
Most recent cumulative regret value.
"""

+        # newly arrived `regret` has existed for `batch_duration / 2` seconds on average;
+        # old regret persists for the entire `batch_duration`
+        self._cumulative_latency_regret += self._current_regret * batch_duration + regret * (batch_duration / 2.0)
        self._current_regret += regret
-        self._cumulative_latency_regret += self._current_regret * batch_duration

return self._cumulative_latency_regret

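
To make the accounting above concrete, here is a self-contained sketch with a small worked example; the class is a simplified stand-in for IncorporationLatencyTracker.

class LatencyTrackerSketch:
    """Area-under-the-curve regret accounting, as in add_latency above."""

    def __init__(self) -> None:
        self._current_regret = 0.0
        self._cumulative_latency_regret = 0.0

    def add_latency(self, regret: float, batch_duration: float) -> float:
        # old regret persists for the whole batch; new regret arrives
        # uniformly, so it exists for batch_duration / 2 on average
        self._cumulative_latency_regret += (
            self._current_regret * batch_duration + regret * (batch_duration / 2.0)
        )
        self._current_regret += regret
        return self._cumulative_latency_regret

tracker = LatencyTrackerSketch()
assert tracker.add_latency(regret=4.0, batch_duration=10.0) == 20.0  # 4 * (10 / 2)
assert tracker.add_latency(regret=0.0, batch_duration=5.0) == 40.0   # 20 + 4 * 5
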
@@ -29,13 +29,10 @@ def __init__(self, config: CostTriggerConfig):
self.config = config
self.context: TriggerContext | None = None

-        self._sample_left_until_detection = config.evaluation_interval_data_points
self._triggered_once = False
self._previous_batch_end_time: int | None = None
-        self._leftover_data: list[tuple[int, int]] = []
-        """Stores data that was not processed in the last inform call because
-        the detection interval was not filled."""

# cost information
self._unincorporated_samples = 0
self._cost_tracker = CostTracker(config.cost_tracking_window_size)
self._incorporation_latency_tracker = IncorporationLatencyTracker()
@@ -7,7 +7,7 @@
from modyn.config.schema.pipeline.trigger.cost.cost import (
AvoidableMisclassificationCostTriggerConfig,
)
-from modyn.supervisor.internal.triggers.cost.costtrigger import CostTrigger
+from modyn.supervisor.internal.triggers.costtrigger import CostTrigger

logger = logging.getLogger(__name__)
