-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Add Avoidable misclassification based
CostTrigger
(#605)
- Loading branch information
1 parent
b37fc13
commit ecdd576
Showing
19 changed files
with
738 additions
and
316 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
94 changes: 94 additions & 0 deletions
94
modyn/supervisor/internal/triggers/avoidablemissclassification_costtrigger.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
import logging | ||
|
||
from typing_extensions import override | ||
|
||
from modyn.config.schema.pipeline.trigger.cost.cost import ( | ||
AvoidableMisclassificationCostTriggerConfig, | ||
) | ||
from modyn.supervisor.internal.triggers.costtrigger import CostTrigger | ||
from modyn.supervisor.internal.triggers.performance.misclassification_estimator import ( | ||
NumberAvoidableMisclassificationEstimator, | ||
) | ||
from modyn.supervisor.internal.triggers.performance.performancetrigger_mixin import ( | ||
PerformanceTriggerMixin, | ||
) | ||
from modyn.supervisor.internal.triggers.trigger import TriggerContext | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class AvoidableMisclassificationCostTrigger(CostTrigger, PerformanceTriggerMixin): | ||
"""Triggers when the avoidable misclassification cost incorporation latency | ||
(regret) exceeds the estimated training time.""" | ||
|
||
def __init__(self, config: AvoidableMisclassificationCostTriggerConfig): | ||
CostTrigger.__init__(self, config) | ||
PerformanceTriggerMixin.__init__(self, config) | ||
|
||
self.config = config | ||
self.context: TriggerContext | None = None | ||
|
||
self.misclassification_estimator = NumberAvoidableMisclassificationEstimator( | ||
config.expected_accuracy, config.allow_reduction | ||
) | ||
|
||
@override | ||
def init_trigger(self, context: TriggerContext) -> None: | ||
# Call CostTrigger's init_trigger method to initialize the trigger context | ||
CostTrigger.init_trigger(self, context) | ||
|
||
# Call PerformanceTriggerMixin's init_trigger method to initialize the internal performance detection state | ||
PerformanceTriggerMixin._init_trigger(self, context) | ||
|
||
@override | ||
def inform_new_model( | ||
self, | ||
most_recent_model_id: int, | ||
number_samples: int | None = None, | ||
training_time: float | None = None, | ||
) -> None: | ||
"""Update the cost and performance trackers with the new model | ||
metadata.""" | ||
|
||
# Call CostTrigger's inform_new_model method to update the cost tracker | ||
CostTrigger.inform_new_model(self, most_recent_model_id, number_samples, training_time) | ||
|
||
# Call the internal PerformanceTriggerMixin's inform_new_model method to update the performance tracker | ||
PerformanceTriggerMixin._inform_new_model(self, most_recent_model_id, self._last_detection_interval) | ||
|
||
# ---------------------------------------------------------------------------------------------------------------- # | ||
# INTERNAL # | ||
# ---------------------------------------------------------------------------------------------------------------- # | ||
|
||
@override | ||
def _compute_regret_metric(self, batch: list[tuple[int, int]], batch_start: int, batch_duration: int) -> float: | ||
"""Compute the regret metric for the current state of the trigger.""" | ||
|
||
self.data_density.inform_data(batch) | ||
num_samples, num_misclassifications, evaluation_scores = self._run_evaluation(interval_data=batch) | ||
|
||
self.performance_tracker.inform_evaluation( | ||
num_samples=num_samples, | ||
num_misclassifications=num_misclassifications, | ||
evaluation_scores=evaluation_scores, | ||
) | ||
|
||
estimated_new_avoidable_misclassifications, _ = ( | ||
self.misclassification_estimator.estimate_avoidable_misclassifications( | ||
update_interval_samples=self.config.evaluation_interval_data_points, | ||
data_density=self.data_density, | ||
performance_tracker=self.performance_tracker, | ||
method=self.config.forecasting_method, | ||
) | ||
) | ||
|
||
# Let's build a latency regret metrics based on the estimated number of avoidable misclassifications. | ||
# Using avoidable misclassification latency makes sense because we generally aim to to trigger | ||
# when many misclassifications could have been avoided. We therefore try to minimize the time that | ||
# misclassifications remain unaddressed while an old model is still in use. | ||
# We chose the latency based area under curve method as a linear model based on the absolute number of | ||
# avoidable misclassifications seems unstable. More advantageous regret non-linear regret functions | ||
# could be explored in the future. | ||
return self._incorporation_latency_tracker.add_latency( | ||
estimated_new_avoidable_misclassifications, batch_duration | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
50 changes: 0 additions & 50 deletions
50
modyn/supervisor/internal/triggers/cost/avoidablemissclassification_costtrigger.py
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.