Skip to content

Commit 6f117ba

Browse files
authored
Integrate Concept Drift (#108)
1 parent f5967d7 commit 6f117ba

File tree

5 files changed

+43
-29
lines changed

5 files changed

+43
-29
lines changed

src/analytics/drift/pipelines.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
import pandas as pd
2-
from typing import Dict, Union, Any
32
import json
43
from evidently.report import Report
54
from evidently.metric_preset import DataDriftPreset, TargetDriftPreset
6-
from src.schemas.driftingMetric import DataDriftTable
5+
from src.schemas.driftingMetric import DataDriftTable, ConceptDriftTable
76

87

98
def run_data_drift_pipeline(
@@ -54,15 +53,14 @@ def run_data_drift_pipeline(
5453
initial_report = json.loads(initial_report)
5554

5655
data_drift_report = {}
57-
data_drift_report["timestamp"] = initial_report["timestamp"]
5856
data_drift_report["drift_summary"] = initial_report["metrics"][1]["result"]
5957

6058
return DataDriftTable(**data_drift_report["drift_summary"])
6159

6260

6361
def run_concept_drift_pipeline(
6462
reference_dataset: pd.DataFrame, current_dataset: pd.DataFrame, target_feature: str
65-
) -> Dict[str, Union[TargetDriftPreset, str]]:
63+
) -> ConceptDriftTable:
6664
"""
6765
To estimate the categorical target drift, we compare the distribution of the target in the two datasets.
6866
This solution works for both binary and multi-class classification.
@@ -89,10 +87,12 @@ def run_concept_drift_pipeline(
8987
initial_report = drift_report.json()
9088
initial_report = json.loads(initial_report)
9189
concept_drift_report = {}
92-
concept_drift_report["timestamp"] = initial_report["timestamp"]
9390
concept_drift_report["concept_drift_summary"] = initial_report["metrics"][0][
9491
"result"
9592
]
9693
concept_drift_report["column_correlation"] = initial_report["metrics"][1]["result"]
9794

98-
return concept_drift_report
95+
return ConceptDriftTable(
96+
concept_drift_summary=concept_drift_report["concept_drift_summary"],
97+
column_correlation=concept_drift_report["column_correlation"],
98+
)

src/analytics/tests/test_pipelines.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -203,36 +203,44 @@ def test_create_data_drift_pipeline(self):
203203
)
204204

205205
def test_create_concept_drift_pipeline_drift_not_detected(self):
206-
concept_drift_report = run_concept_drift_pipeline(
207-
reference_concept_drift, current_concept_drift, "y_testing_multi"
206+
concept_drift_report = vars(
207+
run_concept_drift_pipeline(
208+
reference_concept_drift, current_concept_drift, "y_testing_multi"
209+
)
208210
)
209211
assert list(concept_drift_report.keys()) == [
210-
"timestamp",
211212
"concept_drift_summary",
212213
"column_correlation",
213214
]
214215
assert (
215-
round(concept_drift_report["concept_drift_summary"]["drift_score"], 3)
216+
round(vars(concept_drift_report["concept_drift_summary"])["drift_score"], 3)
216217
== 0.082
217218
)
218-
assert concept_drift_report["concept_drift_summary"]["drift_detected"] == False
219+
assert (
220+
vars(concept_drift_report["concept_drift_summary"])["drift_detected"]
221+
== False
222+
)
219223

220224
def test_create_concept_drift_pipeline_drift_detected(self):
221-
concept_drift_report = run_concept_drift_pipeline(
222-
reference_concept_drift_detected,
223-
current_concept_drift_detected,
224-
"discount_price__currency",
225+
concept_drift_report = vars(
226+
run_concept_drift_pipeline(
227+
reference_concept_drift_detected,
228+
current_concept_drift_detected,
229+
"discount_price__currency",
230+
)
225231
)
226232
assert list(concept_drift_report.keys()) == [
227-
"timestamp",
228233
"concept_drift_summary",
229234
"column_correlation",
230235
]
231236
assert (
232-
round(concept_drift_report["concept_drift_summary"]["drift_score"], 3)
237+
round(vars(concept_drift_report["concept_drift_summary"])["drift_score"], 3)
233238
== 0.008
234239
)
235-
assert concept_drift_report["concept_drift_summary"]["drift_detected"] == True
240+
assert (
241+
vars(concept_drift_report["concept_drift_summary"])["drift_detected"]
242+
== True
243+
)
236244

237245
def test_create_binary_classification_training_model_pipeline(self):
238246
model, eval = create_binary_classification_training_model_pipeline(

src/cron_tasks/monitoring_metrics.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@
55
from sqlalchemy.orm import sessionmaker, Session
66

77
from src import crud, entities
8-
from src.analytics.drift.pipelines import run_data_drift_pipeline
8+
from src.analytics.drift.pipelines import (
9+
run_data_drift_pipeline,
10+
run_concept_drift_pipeline,
11+
)
912
from src.analytics.metrics.pipelines import (
1013
create_binary_classification_evaluation_metrics_pipeline,
1114
create_feature_metrics_pipeline,
@@ -56,10 +59,16 @@ async def run_calculate_drifting_metrics_pipeline(
5659
data_drift_report = run_data_drift_pipeline(
5760
processed_training_dropped_target_df, processed_inference_dropped_target_df
5861
)
62+
concept_drift_report = run_concept_drift_pipeline(
63+
training_processed_df,
64+
inference_processed_df,
65+
model.prediction,
66+
)
5967

6068
new_drifting_metric = entities.DriftingMetric(
6169
timestamp=str(datetime.utcnow()),
6270
model_id=model.id,
71+
concept_drift_summary=concept_drift_report,
6372
data_drift_summary=data_drift_report,
6473
)
6574

src/entities/DriftingMetric.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from sqlalchemy import Column, Float, ForeignKey, String, DateTime, JSON
1+
from sqlalchemy import Column, ForeignKey, String, DateTime, JSON
22
from src.entities.Base import Base
33
from src.utils.id_gen import generate_uuid
44

@@ -9,8 +9,7 @@ class DriftingMetric(Base):
99
id = Column(String, primary_key=True, unique=True, default=generate_uuid)
1010
model_id = Column(String, ForeignKey("models.id", ondelete="CASCADE"))
1111
timestamp = Column(DateTime)
12-
# TODO: Fix pipeline to return a DataDriftTable first
13-
# concept_drift_summary = Column(JSON)
12+
concept_drift_summary = Column(JSON)
1413
data_drift_summary = Column(JSON)
1514
created_at = Column(DateTime)
1615
updated_at = Column(DateTime)

src/schemas/driftingMetric.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from datetime import datetime
2-
from typing import Dict, Union
2+
from typing import Dict, List, Union
33
from pydantic import BaseModel
44
from src.schemas.base import ItemBase
55

@@ -28,15 +28,15 @@ class CramerV(BaseModel):
2828

2929
column_name: str
3030
kind: str
31-
values: Dict[str, Dict[str, str]]
31+
values: Dict[str, List[str]]
3232

3333

3434
class ColumnConceptDriftCorrelationMetrics(BaseModel):
3535
"""One column concept drift correlation metrics"""
3636

3737
column_name: str
38-
current: CramerV
39-
reference: CramerV
38+
current: Dict[str, CramerV]
39+
reference: Dict[str, CramerV]
4040

4141

4242
class ColumnConceptDriftMetrics(BaseModel):
@@ -57,12 +57,10 @@ class ConceptDriftTable(BaseModel):
5757
column_correlation: ColumnConceptDriftCorrelationMetrics
5858

5959

60-
# TODO: Need to include the class of the concept drift
6160
class DriftingMetricBase(ItemBase):
6261
model_id: str
6362
timestamp: Union[str, datetime]
64-
# TODO: The pipeline needs to return a DataDriftTable. If evidently does not provide it we should create it.
65-
# concept_drift_summary: DataDriftTable
63+
concept_drift_summary: ConceptDriftTable
6664
data_drift_summary: DataDriftTable
6765

6866

0 commit comments

Comments
 (0)