Skip to content

Commit

Permalink
feat: Refine triggering abstraction (#546)
Browse files Browse the repository at this point in the history
  • Loading branch information
robinholzi authored Jun 24, 2024
1 parent 59d8ee9 commit b1b71c2
Show file tree
Hide file tree
Showing 8 changed files with 214 additions and 49 deletions.
59 changes: 59 additions & 0 deletions docs/pipeline/TRIGGERING.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Modyn Triggering

Alongside the simple triggers, Modyn also provides complex triggers that can be used to trigger the training of a model.

Although the line between the two is blurry, complex triggers are generally more sophisticated and require more information to be provided to the trigger - most of the time they cannot be pre-determined via a simple configuration entry but require
re-ingested information from the ongoing pipeline run.

Some complex triggers can only make decisions in batched intervals as they cannot be efficiently computed
on a per-sample basis. Examples are the `DataDrift` and `CostBased` (not yet implemented) triggers.

Another policy type is the `EnsemblePolicy` which can be used to combine multiple triggers into a single trigger. This can be useful if multiple triggers should be evaluated before the training of a model is triggered.
One can either use pre-defined ensemble strategies like `MajorityVoteEnsembleStrategy` and `AtLeastNEnsembleStrategy`, or define a custom function that reduces the trigger decisions (one per sub-policy) to a single decision via the `CustomEnsembleStrategy`.

```mermaid
classDiagram
class Trigger {
<<Abstract>>
}
namespace simple_triggers {
class TimeTrigger {
}
class DataAmount {
}
}
namespace complex_triggers {
class DataDrift {
}
class CostBased {
}
class _BatchedTrigger {
<<Abstract>>
}
class EnsemblePolicy {
}
}
Trigger <|-- _BatchedTrigger
Trigger <|-- EnsemblePolicy
Trigger <|-- TimeTrigger
Trigger <|-- DataAmount
_BatchedTrigger <|-- DataDrift
_BatchedTrigger <|-- CostBased
EnsemblePolicy *-- "n" Trigger
```
2 changes: 1 addition & 1 deletion modyn/config/schema/pipeline/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class ModynPipelineConfig(ModynBaseModel):

training: TrainingConfig
data: DataConfig
trigger: TriggerConfig
trigger: TriggerConfig # type: ignore[valid-type]
selection_strategy: SelectionStrategy
evaluation: EvaluationConfig | None = Field(None)

Expand Down
48 changes: 0 additions & 48 deletions modyn/config/schema/pipeline/trigger.py

This file was deleted.

24 changes: 24 additions & 0 deletions modyn/config/schema/pipeline/trigger/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from __future__ import annotations

from typing import Annotated, Union

from pydantic import Field

from .data_amount import * # noqa
from .data_amount import DataAmountTriggerConfig
from .drift import * # noqa
from .drift import DataDriftTriggerConfig
from .ensemble import * # noqa
from .ensemble import EnsembleTriggerConfig
from .time import * # noqa
from .time import TimeTriggerConfig

# Discriminated union of all trigger configuration models. Every member
# declares an `id` field with a distinct Literal value; `discriminator="id"`
# tells pydantic to pick the concrete model based on that field.
TriggerConfig = Annotated[
    Union[
        TimeTriggerConfig,
        DataAmountTriggerConfig,
        DataDriftTriggerConfig,
        EnsembleTriggerConfig,
    ],
    Field(discriminator="id"),
]
9 changes: 9 additions & 0 deletions modyn/config/schema/pipeline/trigger/data_amount.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from typing import Literal

from modyn.config.schema.base_model import ModynBaseModel
from pydantic import Field


class DataAmountTriggerConfig(ModynBaseModel):
    # Configuration of a trigger that is parameterized by a sample count.

    # Literal tag identifying this trigger type.
    id: Literal["DataAmountTrigger"] = Field(default="DataAmountTrigger")

    # Required; must be at least 1.
    num_samples: int = Field(
        ge=1,
        description="The number of samples that should trigger the pipeline.",
    )
23 changes: 23 additions & 0 deletions modyn/config/schema/pipeline/trigger/drift.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from __future__ import annotations

from typing import Any, ForwardRef, Literal, Optional

from modyn.config.schema.base_model import ModynBaseModel
from pydantic import Field

# Forward reference to `TriggerConfig` by name only; this module does not
# import it.
# NOTE(review): presumably pydantic resolves this once `TriggerConfig` exists
# in the package `__init__` — confirm. Also note that a double-underscore
# name is subject to class-private name mangling when used inside a class
# body; verify the annotation still resolves if
# `from __future__ import annotations` is ever removed.
__TriggerConfig = ForwardRef("TriggerConfig", is_class=True)


class DataDriftTriggerConfig(ModynBaseModel):
    # Configuration of the data-drift trigger.

    # Literal tag identifying this trigger type.
    id: Literal["DataDriftTrigger"] = Field(default="DataDriftTrigger")

    # currently not used
    detection_interval: Optional[__TriggerConfig] = Field(  # type: ignore[valid-type]
        default=None,
        description="The Trigger policy to determine the interval at which drift detection is performed.",
    )

    # Defaults to 1000; must be at least 1.
    detection_interval_data_points: int = Field(
        default=1000,
        ge=1,
        description="The number of samples in the interval after which drift detection is performed.",
    )
    # Optional; when given it must be at least 1.
    sample_size: int | None = Field(
        default=None,
        ge=1,
        description="The number of samples used for the metric calculation.",
    )
    metric: str = Field(
        default="model",
        description="The metric used for drift detection.",
    )
    # A fresh empty dict is created per instance when omitted.
    metric_config: dict[str, Any] = Field(
        description="Configuration for the evidently metric.",
        default_factory=dict,
    )
70 changes: 70 additions & 0 deletions modyn/config/schema/pipeline/trigger/ensemble.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from __future__ import annotations

from typing import Annotated, Callable, ForwardRef, Literal, Union

from modyn.config.schema.base_model import ModynBaseModel
from pydantic import Field


class BaseEnsembleStrategy(ModynBaseModel):
    # Common base of the ensemble strategies; subclasses override
    # `aggregate_decision_func`.

    @property
    def aggregate_decision_func(self) -> Callable[[dict[str, bool]], bool]:
        """Return the callable that aggregates the decisions of the individual triggers.

        Returns:
            A callable taking a mapping of decisions (str keys, bool values) and returning a bool.

        Raises:
            NotImplementedError: Always; the base class provides no aggregation.
        """
        raise NotImplementedError()


class MajorityVoteEnsembleStrategy(BaseEnsembleStrategy):
    # Strategy whose aggregate is true when strictly more than half of the
    # decisions are true.

    # Literal tag identifying this strategy.
    id: Literal["MajorityVote"] = Field(default="MajorityVote")

    @property
    def aggregate_decision_func(self) -> Callable[[dict[str, bool]], bool]:
        """Return a callable that is true iff more than half of the decisions are true."""

        def _majority(decisions: dict[str, bool]) -> bool:
            # Summing the values counts the positive decisions.
            votes_in_favor = sum(decisions.values())
            return votes_in_favor > len(decisions) / 2

        return _majority


class AtLeastNEnsembleStrategy(BaseEnsembleStrategy):
    # Strategy whose aggregate is true when at least `n` decisions are true.

    # Literal tag identifying this strategy.
    id: Literal["AtLeastN"] = Field(default="AtLeastN")

    # Required threshold; must be at least 1.
    n: int = Field(
        ge=1,
        description="The minimum number of triggers that need to trigger for the ensemble to trigger.",
    )

    @property
    def aggregate_decision_func(self) -> Callable[[dict[str, bool]], bool]:
        """Return a callable that is true iff at least `n` of the decisions are true."""

        def _at_least_n(decisions: dict[str, bool]) -> bool:
            # `self.n` is read when the callable runs, not when it is created.
            votes_in_favor = sum(decisions.values())
            return votes_in_favor >= self.n

        return _at_least_n


class CustomEnsembleStrategy(BaseEnsembleStrategy):
    # Strategy that delegates aggregation to a caller-provided function.

    # Literal tag identifying this strategy.
    id: Literal["Custom"] = Field(default="Custom")

    # Required; no default is provided.
    aggregation_function: Callable[[dict[str, bool]], bool] = Field(
        description="The function that aggregates the decisions of the individual triggers.",
    )

    @property
    def aggregate_decision_func(self) -> Callable[[dict[str, bool]], bool]:
        """Return the configured `aggregation_function` unchanged."""
        configured_function = self.aggregation_function
        return configured_function


# Discriminated union of the available ensemble strategies. Every member
# declares an `id` field with a distinct Literal value; `discriminator="id"`
# tells pydantic to pick the concrete strategy model based on that field.
EnsembleStrategy = Annotated[
    Union[
        MajorityVoteEnsembleStrategy,
        AtLeastNEnsembleStrategy,
        CustomEnsembleStrategy,
    ],
    Field(discriminator="id"),
]

# Forward reference to `TriggerConfig` by name only; this module does not
# import it.
# NOTE(review): presumably pydantic resolves this once `TriggerConfig` exists
# in the package `__init__` — confirm. Also note that a double-underscore
# name is subject to class-private name mangling when used inside a class
# body; verify the annotation still resolves if
# `from __future__ import annotations` is ever removed.
__TriggerConfig = ForwardRef("TriggerConfig", is_class=True)


class EnsembleTriggerConfig(ModynBaseModel):
    # Configuration of a trigger that combines several sub-policies through
    # an ensemble strategy.

    # Literal tag identifying this trigger type.
    id: Literal["EnsembleTrigger"] = Field(default="EnsembleTrigger")

    # A fresh empty dict is created per instance when omitted.
    policies: dict[str, __TriggerConfig] = Field(  # type: ignore[valid-type]
        description="The policies keyed by distinct references that will be consulted for the ensemble trigger.",
        default_factory=dict,
    )

    # Required; no default is provided.
    ensemble_strategy: EnsembleStrategy = Field(
        description="The strategy that will be used to aggregate the decisions of the individual triggers.",
    )
28 changes: 28 additions & 0 deletions modyn/config/schema/pipeline/trigger/time.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from functools import cached_property
from typing import Literal

from modyn.config.schema.base_model import ModynBaseModel
from modyn.const.regex import REGEX_TIME_UNIT
from modyn.utils.utils import SECONDS_PER_UNIT
from pydantic import Field


class TimeTriggerConfig(ModynBaseModel):
    # Configuration of a trigger that is parameterized by a time interval.

    # Literal tag identifying this trigger type.
    id: Literal["TimeTrigger"] = Field("TimeTrigger")

    # Required; validated against the pattern "<digits><REGEX_TIME_UNIT>".
    every: str = Field(
        description="Interval length for the trigger as an integer followed by a time unit: s, m, h, d, w, y",
        pattern=rf"^\d+{REGEX_TIME_UNIT}$",
    )
    start_timestamp: int | None = Field(
        None,
        description=(
            # The trailing space keeps the two sentences separated once the
            # adjacent literals are concatenated.
            "The timestamp at which the triggering schedule starts. First trigger will be at start_timestamp + every. "
            "Use None to start at the first timestamp of the data."
        ),
    )

    @cached_property
    def every_seconds(self) -> int:
        """Return the interval `every` converted to seconds.

        The last character of `every` is the unit key looked up in
        `SECONDS_PER_UNIT`; the preceding characters are parsed as the
        integer count. The result is cached after the first access.

        Raises:
            ValueError: If the leading part of `every` is not an integer.
            KeyError: If the unit character is not a key of `SECONDS_PER_UNIT`.
        """
        interval = str(self.every)
        unit = interval[-1:]
        num = int(interval[:-1])
        return num * SECONDS_PER_UNIT[unit]

0 comments on commit b1b71c2

Please sign in to comment.