-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Refine triggering abstraction (#546)
- Loading branch information
1 parent
59d8ee9
commit b1b71c2
Showing
8 changed files
with
214 additions
and
49 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
# Modyn Triggering | ||
|
||
Alongside the simple triggers, Modyn also provides complex triggers that can be used to trigger the training of a model. | ||
|
||
Despite the line being blurry, complex triggers are generally more sophisticated and require more information to be provided to the trigger - they most of the time cannot be pre-determined via a simple configuration entry but require | ||
reinjested information from the ongoing pipeline run. | ||
|
||
Some complex triggers can only make decisions in batched intervals as they are cannot be efficiently computed | ||
on a per-sample basis. Here we can find the `DataDrift` and `CostBased` (not yet implemented) based triggers. | ||
|
||
Another policy type is the `EnsemblePolicy` which can be used to combine multiple triggers into a single trigger. This can be useful if multiple triggers should be evaluated before the training of a model is triggered. | ||
One can either use pre-defined ensemble strategies like `MajorityVote` and `AtLeastNEnsembleStrategy` or define custom functions that reduce a list of trigger decisions (one per sub-policy) and make decisions on them freely via the CustomEnsembleStrategy. | ||
|
||
```mermaid | ||
classDiagram | ||
class Trigger { | ||
<<Abstract>> | ||
} | ||
namespace simple_triggers { | ||
class TimeTrigger { | ||
} | ||
class DataAmount { | ||
} | ||
} | ||
namespace complex_triggers { | ||
class DataDrift { | ||
} | ||
class CostBased { | ||
} | ||
class _BatchedTrigger { | ||
<<Abstract>> | ||
} | ||
class EnsemblePolicy { | ||
} | ||
} | ||
Trigger <|-- _BatchedTrigger | ||
Trigger <|-- EnsemblePolicy | ||
Trigger <|-- TimeTrigger | ||
Trigger <|-- DataAmount | ||
_BatchedTrigger <|-- DataDrift | ||
_BatchedTrigger <|-- CostBased | ||
EnsemblePolicy *-- "n" Trigger | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Annotated, Union | ||
|
||
from pydantic import Field | ||
|
||
from .data_amount import * # noqa | ||
from .data_amount import DataAmountTriggerConfig | ||
from .drift import * # noqa | ||
from .drift import DataDriftTriggerConfig | ||
from .ensemble import * # noqa | ||
from .ensemble import EnsembleTriggerConfig | ||
from .time import * # noqa | ||
from .time import TimeTriggerConfig | ||
|
||
TriggerConfig = Annotated[ | ||
Union[ | ||
TimeTriggerConfig, | ||
DataAmountTriggerConfig, | ||
DataDriftTriggerConfig, | ||
EnsembleTriggerConfig, | ||
], | ||
Field(discriminator="id"), | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from typing import Literal | ||
|
||
from modyn.config.schema.base_model import ModynBaseModel | ||
from pydantic import Field | ||
|
||
|
||
class DataAmountTriggerConfig(ModynBaseModel): | ||
id: Literal["DataAmountTrigger"] = Field("DataAmountTrigger") | ||
num_samples: int = Field(description="The number of samples that should trigger the pipeline.", ge=1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Any, ForwardRef, Literal, Optional | ||
|
||
from modyn.config.schema.base_model import ModynBaseModel | ||
from pydantic import Field | ||
|
||
__TriggerConfig = ForwardRef("TriggerConfig", is_class=True) | ||
|
||
|
||
class DataDriftTriggerConfig(ModynBaseModel): | ||
id: Literal["DataDriftTrigger"] = Field("DataDriftTrigger") | ||
|
||
detection_interval: Optional[__TriggerConfig] = Field( # type: ignore[valid-type] | ||
None, description="The Trigger policy to determine the interval at which drift detection is performed." | ||
) # currently not used | ||
|
||
detection_interval_data_points: int = Field( | ||
1000, description="The number of samples in the interval after which drift detection is performed.", ge=1 | ||
) | ||
sample_size: int | None = Field(None, description="The number of samples used for the metric calculation.", ge=1) | ||
metric: str = Field("model", description="The metric used for drift detection.") | ||
metric_config: dict[str, Any] = Field(default_factory=dict, description="Configuration for the evidently metric.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Annotated, Callable, ForwardRef, Literal, Union | ||
|
||
from modyn.config.schema.base_model import ModynBaseModel | ||
from pydantic import Field | ||
|
||
|
||
class BaseEnsembleStrategy(ModynBaseModel): | ||
|
||
@property | ||
def aggregate_decision_func(self) -> Callable[[dict[str, bool]], bool]: | ||
"""Returns: | ||
Function that aggregates the decisions of the individual triggers.""" | ||
raise NotImplementedError | ||
|
||
|
||
class MajorityVoteEnsembleStrategy(BaseEnsembleStrategy): | ||
id: Literal["MajorityVote"] = Field("MajorityVote") | ||
|
||
@property | ||
def aggregate_decision_func(self) -> Callable[[dict[str, bool]], bool]: | ||
return lambda decisions: sum(decisions.values()) > len(decisions) / 2 | ||
|
||
|
||
class AtLeastNEnsembleStrategy(BaseEnsembleStrategy): | ||
id: Literal["AtLeastN"] = Field("AtLeastN") | ||
|
||
n: int = Field(description="The minimum number of triggers that need to trigger for the ensemble to trigger.", ge=1) | ||
|
||
@property | ||
def aggregate_decision_func(self) -> Callable[[dict[str, bool]], bool]: | ||
return lambda decisions: sum(decisions.values()) >= self.n | ||
|
||
|
||
class CustomEnsembleStrategy(BaseEnsembleStrategy): | ||
id: Literal["Custom"] = Field("Custom") | ||
|
||
aggregation_function: Callable[[dict[str, bool]], bool] = Field( | ||
description="The function that aggregates the decisions of the individual triggers." | ||
) | ||
|
||
@property | ||
def aggregate_decision_func(self) -> Callable[[dict[str, bool]], bool]: | ||
return self.aggregation_function | ||
|
||
|
||
EnsembleStrategy = Annotated[ | ||
Union[ | ||
MajorityVoteEnsembleStrategy, | ||
AtLeastNEnsembleStrategy, | ||
CustomEnsembleStrategy, | ||
], | ||
Field(discriminator="id"), | ||
] | ||
|
||
__TriggerConfig = ForwardRef("TriggerConfig", is_class=True) | ||
|
||
|
||
class EnsembleTriggerConfig(ModynBaseModel): | ||
id: Literal["EnsembleTrigger"] = Field("EnsembleTrigger") | ||
|
||
policies: dict[str, __TriggerConfig] = Field( # type: ignore[valid-type] | ||
default_factory=dict, | ||
description="The policies keyed by distinct references that will be consulted for the ensemble trigger.", | ||
) | ||
|
||
ensemble_strategy: EnsembleStrategy = Field( | ||
description="The strategy that will be used to aggregate the decisions of the individual triggers." | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
from functools import cached_property | ||
from typing import Literal | ||
|
||
from modyn.config.schema.base_model import ModynBaseModel | ||
from modyn.const.regex import REGEX_TIME_UNIT | ||
from modyn.utils.utils import SECONDS_PER_UNIT | ||
from pydantic import Field | ||
|
||
|
||
class TimeTriggerConfig(ModynBaseModel): | ||
id: Literal["TimeTrigger"] = Field("TimeTrigger") | ||
every: str = Field( | ||
description="Interval length for the trigger as an integer followed by a time unit: s, m, h, d, w, y", | ||
pattern=rf"^\d+{REGEX_TIME_UNIT}$", | ||
) | ||
start_timestamp: int | None = Field( | ||
None, | ||
description=( | ||
"The timestamp at which the triggering schedule starts. First trigger will be at start_timestamp + every." | ||
"Use None to start at the first timestamp of the data." | ||
), | ||
) | ||
|
||
@cached_property | ||
def every_seconds(self) -> int: | ||
unit = str(self.every)[-1:] | ||
num = int(str(self.every)[:-1]) | ||
return num * SECONDS_PER_UNIT[unit] |