
Commit cb969d6

Add logging, pre-commit hooks, and transactional clearing of the db before insertion
1 parent 89010a2 commit cb969d6

18 files changed, +317 −132 lines changed

.env

Lines changed: 1 addition & 1 deletion
@@ -2,4 +2,4 @@ CLICKHOUSE_HOST=localhost
 CLICKHOUSE_PORT=9000
 CLICKHOUSE_USER=default
 CLICKHOUSE_PASSWORD=
-CLICKHOUSE_DATABASE=test_db
+CLICKHOUSE_DATABASE=test_db

.gitignore

Lines changed: 1 addition & 1 deletion
@@ -4,4 +4,4 @@ venv
 analytics-obfuscated-faked.csv
 __pycache__/
 remove_pycache.sh
-tasks.txt
+tasks.txt

.pre-commit-config.yaml

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
+# .pre-commit-config.yaml
+repos:
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v4.5.0
+    hooks:
+      - id: trailing-whitespace
+      - id: end-of-file-fixer
+      - id: check-yaml
+      - id: check-added-large-files
+
+  - repo: https://github.com/psf/black
+    rev: 24.4.2
+    hooks:
+      - id: black
+        args: ['--line-length=79']
+
+  - repo: https://github.com/charliermarsh/ruff-pre-commit
+    rev: v0.4.7
+    hooks:
+      - id: ruff
+        args: ['--fix', '--line-length=79']
+
+  - repo: https://github.com/PyCQA/flake8
+    rev: 7.0.0
+    hooks:
+      - id: flake8
+
+  - repo: https://github.com/pre-commit/mirrors-mypy
+    rev: v1.10.0
+    hooks:
+      - id: mypy
+        args: ["--ignore-missing-imports"]
app.py

Lines changed: 24 additions & 12 deletions
@@ -6,6 +6,7 @@
 
 app = FastAPI()
 
+
 class TimeFrameEnum(str, Enum):
     next_1_hour = "next_1_hour"
     next_4_hour = "next_4_hour"
@@ -16,33 +17,43 @@ class TimeFrameEnum(str, Enum):
     next_168_hour = "next_168_hour"
 
 
-
 @app.get("/predict/")
 def get_predictions(pid: str, timeframe: TimeFrameEnum):
     """Get predictions for the specified `pid` and `timeframe`"""
 
-    project_exists_query = f"SELECT pid FROM predictions WHERE pid = '{pid}' LIMIT 1"
+    project_exists_query = (
+        f"SELECT pid FROM predictions WHERE pid = '{pid}' LIMIT 1"
+    )
     project = clickhouse_client.execute_query(project_exists_query)
-
+
     if not project:
         raise HTTPException(status_code=404, detail="Project does not exist.")
-
-    prediction_query = f"SELECT {timeframe} FROM predictions WHERE pid = '{pid}'"
+
+    prediction_query = (
+        f"SELECT {timeframe} FROM predictions WHERE pid = '{pid}'"
+    )
     result = clickhouse_client.execute_query(prediction_query)
-
+
     if not result or not result[0][0]:
-        raise HTTPException(status_code=404, detail="Data not found. Prediction is not available.")
-
+        raise HTTPException(
+            status_code=404,
+            detail="Data not found. Prediction is not available.",
+        )
+
     prediction_data = json.loads(result[0][0])
-
+
     if not prediction_data:
-        raise HTTPException(status_code=404, detail="Data not found. Prediction is not available.")
-
+        raise HTTPException(
+            status_code=404,
+            detail="Data not found. Prediction is not available.",
+        )
+
     return {timeframe: prediction_data}
-
+
 
 if __name__ == "__main__":
     import uvicorn
+
     uvicorn.run(app, host="0.0.0.0", port=8000)
 
 
@@ -52,6 +63,7 @@ def trigger_training():
     run_training_module.delay()
     return {"message": "Training module triggered"}
 
+
 @app.post("/run_prediction/")
 def trigger_prediction():
     """Trigger the prediction module via Celery"""

celery_tasks/celery_config.py

Lines changed: 6 additions & 3 deletions
@@ -3,15 +3,18 @@
 celery_app = Celery(
     "swetrix-ai-celery",
     broker="redis://localhost:6379/0",
-    backend="redis://localhost:6379/0"
+    backend="redis://localhost:6379/0",
 )
 
 celery_app.conf.update(
     result_expires=3600,
     task_serializer="json",
     result_serializer="json",
-    accept_content=["json"]
+    accept_content=["json"],
+    task_soft_time_limit=3600,  # 1 hour soft time limit
+    task_time_limit=3700,  # 1 hour 10 minutes hard time limit
 )
 
 from celery_tasks.tasks import *
-# celery_app.autodiscover_tasks(['celery_tasks'])
+
+# celery_app.autodiscover_tasks(['celery_tasks'])
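
With the new task_soft_time_limit/task_time_limit settings, a task receives SoftTimeLimitExceeded after 3600 s and is killed outright at 3700 s. A minimal sketch of how a task could react to the soft limit; the train() call and the cleanup behaviour are assumptions for illustration, not part of this commit.

    from celery.exceptions import SoftTimeLimitExceeded

    from celery_tasks.celery_config import celery_app


    @celery_app.task
    def run_training_module_safe():
        # Hypothetical variant of run_training_module: the soft limit fires inside
        # the task, leaving ~100 s before the hard limit kills the worker process.
        try:
            train()  # assumed long-running training entry point
        except SoftTimeLimitExceeded:
            print("Training exceeded the soft time limit; aborting cleanly")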

celery_tasks/tasks.py

Lines changed: 1 addition & 1 deletion
@@ -10,4 +10,4 @@ def run_training_module():
 
 @celery_app.task
 def run_prediction_module():
-    predict()
+    predict()

clickhouse/client.py

Lines changed: 15 additions & 10 deletions
@@ -2,29 +2,34 @@
 from dotenv import load_dotenv
 from clickhouse_driver import Client
 
+
 class ClickHouseClient:
     def __init__(self):
         load_dotenv()
-        self.host = os.getenv('CLICKHOUSE_HOST')
-        self.port = os.getenv('CLICKHOUSE_PORT')
-        self.user = os.getenv('CLICKHOUSE_USER')
-        self.password = os.getenv('CLICKHOUSE_PASSWORD')
-        self.database = os.getenv('CLICKHOUSE_DATABASE')
-
+        self.host = os.getenv("CLICKHOUSE_HOST")
+        self.port = os.getenv("CLICKHOUSE_PORT")
+        self.user = os.getenv("CLICKHOUSE_USER")
+        self.password = os.getenv("CLICKHOUSE_PASSWORD")
+        self.database = os.getenv("CLICKHOUSE_DATABASE")
+
         self.client = Client(
             host=self.host,
             port=self.port,
             user=self.user,
             password=self.password,
-            database=self.database
+            database=self.database,
         )
-
+
     def execute_query(self, query: str):
         return self.client.execute(query)
-
+
     def insert_data(self, table: str, data: list):
         self.client.execute(f"INSERT INTO {table} VALUES", data)
 
+    def drop_all_data_from_table(self, table_name: str):
+        """Drop all data from the table as we require to store only one record in the meantime"""
+        query = f"TRUNCATE TABLE {table_name}"
+        self.execute_query(query)
 
-clickhouse_client = ClickHouseClient()
 
+clickhouse_client = ClickHouseClient()
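
The added drop_all_data_from_table helper is simply a TRUNCATE TABLE statement issued through the shared client, so clearing and re-checking the predictions table looks roughly like this (assuming the table created by migrations_tables.py exists):

    from clickhouse.client import clickhouse_client

    # Wipe the single-snapshot table before a fresh insert.
    clickhouse_client.drop_all_data_from_table("predictions")

    # execute_query returns a list of tuples, e.g. [(0,)] once the table is empty.
    print(clickhouse_client.execute_query("SELECT count() FROM predictions"))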

clickhouse/migrations_tables.py

Lines changed: 6 additions & 5 deletions
@@ -1,18 +1,19 @@
 from client import ClickHouseClient
 
+
 def create_tables():
     client = ClickHouseClient()
-
+
     training_tmp_query = """
     CREATE TABLE IF NOT EXISTS training_tmp (
         cat_features Array(String),
         cols Array(String),
         next_hrs Array(String),
-        model String
+        model String
     ) ENGINE = MergeTree()
     ORDER BY tuple()
     """
-
+
     predictions_query = """
     CREATE TABLE IF NOT EXISTS predictions (
         pid String,
@@ -26,12 +27,12 @@ def create_tables():
     ) ENGINE = MergeTree()
     ORDER BY pid
     """
-
+
     client.execute_query(training_tmp_query)
     client.execute_query(predictions_query)
 
 
 client = ClickHouseClient()
 
 if __name__ == "__main__":
-    create_tables()
+    create_tables()
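
Both CREATE TABLE statements use IF NOT EXISTS, so create_tables() is safe to run repeatedly. The module is written to be executed from inside the clickhouse/ directory (it imports client without a package prefix); the equivalent programmatic call is simply:

    # Run from inside the clickhouse/ directory, matching the module's own import style.
    from migrations_tables import create_tables

    create_tables()  # idempotent: IF NOT EXISTS makes repeated runs a no-op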

clickhouse/utils.py

Lines changed: 18 additions & 12 deletions
@@ -1,46 +1,52 @@
 import base64
 import pickle
 import json
-from data.serialisation import serialise_predictions, serialise_data_for_clickhouse
+from data.serialisation import (
+    serialise_predictions,
+    serialise_data_for_clickhouse,
+)
 from clickhouse.client import clickhouse_client
 
 
 """
 Clickhouse does not support the pickled objects yet, and it is a problem.
-There is a solution to use `base64` encoding, store the model as a string and then decode it and use as a pickle object
+There is a solution to use `base64` encoding, store the model as a string and then decode it and use as a pickle object
 
-Though it is a subject of discussion in the future. I personally prefer to store the model in S3 bucket, but this will require an
-additional time for development which we do not have to test the model completely in production.
+Though it is a subject of discussion in the future. I personally prefer to store the model in S3 bucket, but this will require an
+additional time for development which we do not have, as the priority is to test the model in production.
"""
 
-def serialize_model(file_path):
-    with open(file_path, 'rb') as f:
-        pickled_model = f.read()
-    base64_model = base64.b64encode(pickled_model).decode('utf-8')
+
+def serialize_model(model):
+    pickled_model = pickle.dumps(model)
+    base64_model = base64.b64encode(pickled_model).decode("utf-8")
     return base64_model
 
 
 def deserialize_model(base64_model):
-    pickled_model = base64.b64decode(base64_model.encode('utf-8'))
+    pickled_model = base64.b64decode(base64_model.encode("utf-8"))
     model = pickle.loads(pickled_model)
     return model
 
 
 def fetch_model():
-    result = clickhouse_client.execute_query("SELECT model FROM training_tmp LIMIT 1")
+    """Get the serialized model from the database for predictions"""
+    result = clickhouse_client.execute_query("SELECT model FROM training_tmp")
     if result:
         serialized_model = result[0][0]
         model = deserialize_model(serialized_model)
         return model
     else:
         print("No model found")
         return None
-
+
 
 def insert_predictions(predictions):
     """Insert serialised JSON data into the predictions table"""
     predictions_data = json.loads(predictions)
     processed_data = serialise_predictions(predictions_data)
     serialized_data = serialise_data_for_clickhouse(processed_data)
-    clickhouse_client.insert_data('predictions', serialized_data)
 
+    # Drop previous(not relevant) data before the insertion of new predictions
+    clickhouse_client.drop_all_data_from_table("predictions")
+    clickhouse_client.insert_data("predictions", serialized_data)
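
Since serialize_model now pickles the in-memory model and base64-encodes it for storage in a String column, the round trip through deserialize_model must be lossless. A quick sketch with a throwaway object standing in for the trained model:

    from clickhouse.utils import serialize_model, deserialize_model

    model = {"weights": [0.1, 0.2, 0.3]}  # any picklable object stands in for the real model

    encoded = serialize_model(model)       # pickle -> base64 string, storable in ClickHouse
    restored = deserialize_model(encoded)  # base64 string -> pickle -> original object

    assert restored == model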

constants.py

Lines changed: 23 additions & 3 deletions
@@ -1,3 +1,23 @@
-columns = ("psid","sid","pid","pg","prev","dv","br","os","lc","ref","so","me","ca","cc","rg","ct","sdur","unique","created")
-agg_cols = ["year","month","day","day_of_week","hour","pid"]
-date_col = 'created'
+columns = (
+    "psid",
+    "sid",
+    "pid",
+    "pg",
+    "prev",
+    "dv",
+    "br",
+    "os",
+    "lc",
+    "ref",
+    "so",
+    "me",
+    "ca",
+    "cc",
+    "rg",
+    "ct",
+    "sdur",
+    "unique",
+    "created",
+)
+agg_cols = ["year", "month", "day", "day_of_week", "hour", "pid"]
+date_col = "created"
