diff --git a/modyn/config/schema/pipeline/trigger/drift/config.py b/modyn/config/schema/pipeline/trigger/drift/config.py
index ef9f9ef2e..7629885c3 100644
--- a/modyn/config/schema/pipeline/trigger/drift/config.py
+++ b/modyn/config/schema/pipeline/trigger/drift/config.py
@@ -28,6 +28,14 @@ class DataDriftTriggerConfig(BatchedTriggerConfig):
         description="Which windowing strategy to use for current and reference data",
     )
 
+    sample_size: int | None = Field(
+        5000,
+        description=(
+            "The number of samples to use for drift detection. If the windows are bigger than this, "
+            "samples are randomly drawn from the window. None does not limit the number of samples."
+        ),
+    )
+
     metrics: dict[str, DriftMetric] = Field(
         min_length=1,
         description="The metrics used for drift detection keyed by a reference.",
diff --git a/modyn/supervisor/internal/pipeline_executor/pipeline_executor.py b/modyn/supervisor/internal/pipeline_executor/pipeline_executor.py
index 07f3fff54..1fccc2b9f 100644
--- a/modyn/supervisor/internal/pipeline_executor/pipeline_executor.py
+++ b/modyn/supervisor/internal/pipeline_executor/pipeline_executor.py
@@ -869,7 +869,9 @@ def _post_pipeline_evaluation_checkpoint(self, s: ExecutionState, log: StageLog)
             return
 
         self.logs.materialize(s.log_directory, mode="increment")
-        self.eval_executor.register_tracking_info(tracking_dfs=s.tracking, dataset_end_time=self.state.max_timestamp)
+        self.eval_executor.register_tracking_info(
+            tracking_dfs=s.tracking, dataset_end_time=self.state.current_sample_time
+        )
         self.eval_executor.create_snapshot()
 
     @pipeline_stage(PipelineStage.POST_EVALUATION, parent=PipelineStage.MAIN, log=False, track=False)
diff --git a/modyn/supervisor/internal/triggers/datadrifttrigger.py b/modyn/supervisor/internal/triggers/datadrifttrigger.py
index d3613facd..bb0912965 100644
--- a/modyn/supervisor/internal/triggers/datadrifttrigger.py
+++ b/modyn/supervisor/internal/triggers/datadrifttrigger.py
@@ -259,10 +259,16 @@ def _run_detection(
         assert len(current) > 0
 
         reference_dataloader = prepare_trigger_dataloader_fixed_keys(
-            self.dataloader_info, [key for key, _ in reference]
+            self.dataloader_info,
+            [key for key, _ in reference],
+            sample_size=self.config.sample_size,
         )
 
-        current_dataloader = prepare_trigger_dataloader_fixed_keys(self.dataloader_info, [key for key, _ in current])
+        current_dataloader = prepare_trigger_dataloader_fixed_keys(
+            self.dataloader_info,
+            [key for key, _ in current],
+            sample_size=self.config.sample_size,
+        )
 
         # Download most recent model as stateful model
         # TODO(417) Support custom model as stateful model