Commit aff3df9

finish training
1 parent 484168a commit aff3df9

3 files changed, +53 -29 lines

clickhouse/utils.py

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 import base64
 import pickle
-from client import clickhouse_client
+from clickhouse.client import clickhouse_client
 
 
 """

data/load_data.py

Lines changed: 28 additions & 25 deletions
@@ -1,6 +1,8 @@
 import pandas as pd
 import numpy as np
 from constants import columns, date_col, agg_cols
+from clickhouse.client import clickhouse_client
+
 import warnings
 warnings.filterwarnings("ignore")
 
@@ -171,14 +173,14 @@ def set_target_columns(df):
 def remove_date_col(df):
     """Drop the date columun"""
     df.drop([date_col],axis=1,inplace=True)
-
+    return df
 
 def get_cols_withohut_pid(df):
     """Return the columns EXCEPT OF PID"""
     cols=df.drop("pid",axis=1).columns
-    return cols
+    return df, cols
 
-def create_target_traffic_by_target_columns(target_columns): # TODO take the data from the db
+def create_target_traffic_by_target_columns(df, target_columns):
     """Extract the traffic for next hours"""
     next_hrs = []
     for hr in [1,4,8,12,24,72,168]: # TODO dynamical value, as well dynamical for the database
@@ -188,25 +190,26 @@ def create_target_traffic_by_target_columns(target_columns): # TODO take the
     return next_hrs
 
 
-#Pre-processing
-df = read_data_csv()
-df = sort_df_by_date_col(date_col, df)
-df = convert_df_to_datetime(df)
-df = filter_df_by_specific_date(df, time_delta_years=1)
-df = filter_df_with_most_frequent_pid(df)
-df = replace_null_values(df)
-df, cat_features = categorize_features(df)
-df = extract_date_components(df, date_col)
-df = add_traffic_table(df)
-df = convert_cat_features_to_dummies(df, cat_features)
-df = combine_all_pids(df, date_col, agg_cols)
-
-# Setting data fro predictions
-target_columns = set_target_columns(df)
-remove_date_col(df)
-cols = get_cols_withohut_pid(df)
-next_hrs = create_target_traffic_by_target_columns(target_columns)
-
-# Clear N/A
-df= df.dropna()
-
+def pre_process_data():
+    #Pre-processing
+    df = read_data_csv()
+    df = sort_df_by_date_col(date_col, df)
+    df = convert_df_to_datetime(df)
+    df = filter_df_by_specific_date(df, time_delta_years=1)
+    df = filter_df_with_most_frequent_pid(df)
+    df = replace_null_values(df)
+    df, cat_features = categorize_features(df)
+    df = extract_date_components(df, date_col)
+    df = add_traffic_table(df)
+    df = convert_cat_features_to_dummies(df, cat_features)
+    df = combine_all_pids(df, date_col, agg_cols)
+
+    # Setting data fro predictions
+    target_columns = set_target_columns(df)
+    df = remove_date_col(df)
+    df, cols = get_cols_withohut_pid(df)
+    next_hrs = create_target_traffic_by_target_columns(df, target_columns)
+
+    # Clear N/A
+    df= df.dropna()
+    return df, cat_features, cols, next_hrs
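
With this change, importing data.load_data no longer runs the whole preprocessing pipeline at import time; callers invoke it explicitly and unpack the results, as scripts/run_training.py now does:

from data.load_data import pre_process_data

# Runs the CSV -> cleaned-DataFrame pipeline and returns the pieces the
# training step needs: the DataFrame, categorical feature names, feature
# columns (without pid), and the list of next-hour target column names.
df, cat_features, cols, next_hrs = pre_process_data()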

scripts/run_training.py

Lines changed: 24 additions & 3 deletions
@@ -3,11 +3,32 @@
 import os
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
 
-from data.load_data import df, cols, next_hrs
+from clickhouse.client import clickhouse_client
+from clickhouse.utils import serialize_model
+from data.load_data import pre_process_data
 from models.train_model import train_model, save_model
 from datetime import datetime
 
 def train():
-    """Celery task which is called for a model training"""
+    """Celery task which is called for a model training
+    - Gets data from ``load_data`` module
+    - Saves model to ``.pkl`` format
+    - Serialises model to ``base64`` encoding
+    - Inserts data to the DB
+    """
+    timestamp = datetime.now()
+    file_path = f'pickled_model_{timestamp}'
+
+    df, cat_features, cols, next_hrs = pre_process_data()
     model = train_model(df,cols, next_hrs)
-    save_model(f'pickled_model_{datetime.now()}', model)
+    save_model(file_path, model)
+
+    print("cat_features:", cat_features)
+    print("cols:", cols)
+    print("next_hrs:", next_hrs)
+
+    serialized_model = serialize_model(file_path)
+    training_tmp_data = [(cat_features, cols.to_list(), next_hrs, serialized_model)]
+
+    clickhouse_client.insert_data('training_tmp', training_tmp_data)
+    print("Training has been completed and data is inserted to the database!")

0 commit comments
