Commit b2516ec

Add IrisModel class for training a Logistic Regression model on the Iris dataset
1 parent 9ab2abf commit b2516ec

File tree

1 file changed: +139 -0 lines changed

src/classifier.py

Lines changed: 139 additions & 0 deletions
@@ -0,0 +1,139 @@
"""
Train an ML model on the Iris dataset.
"""
import json

import git
import joblib
import numpy as np
import optuna
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# Resolve paths relative to the repository root so the script works from any
# working directory. These paths assume the best_params/ and saved_model/
# directories already exist at the repository root.
base_path = git.Repo('.', search_parent_directories=True).working_tree_dir
params_path = base_path + "/best_params/" + "best_hyper_params" + ".json"
model_path = base_path + "/saved_model/" + "saved_model.joblib"


class IrisModel:
    """
    Train a Logistic Regression model on the Iris dataset.
    """
    def __init__(self):
        self.model = None

    def hyperparameter_tuning(self, X, y):
        def objective(trial):
            # Hyperparameters to tune for Logistic Regression.
            params = {
                'C': trial.suggest_float('C', 0.01, 10, log=True),
                'penalty': trial.suggest_categorical('penalty', ['l1', 'l2']),
                'solver': trial.suggest_categorical('solver', ['liblinear', 'saga'])
            }
            model = LogisticRegression(**params)
            model.fit(X, y)
            return model.score(X, y)

        study = optuna.create_study(direction='maximize')
        study.optimize(objective, n_trials=100)
        # Export the best params as JSON so later runs can reuse them.
        self.hyper_params_path = params_path
        with open(self.hyper_params_path, 'w') as f:
            json.dump(study.best_params, f)
        self.best_params = study.best_params

    def train(self, X: np.ndarray, y: np.ndarray) -> None:
        """
        Trains the logistic regression model.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            The input data.
        y : array-like of shape (n_samples,)
            The target values.

        Returns
        -------
        None
        """
        # Reuse previously tuned hyperparameters if they exist on disk;
        # otherwise run hyperparameter tuning first.
        self.hyper_params_path = params_path
        try:
            with open(self.hyper_params_path, 'r') as f:
                self.best_params = json.load(f)
            print("Best params loaded")
        except (FileNotFoundError, json.JSONDecodeError):
            print("Best params not found. Hyperparameter tuning...")
            self.hyperparameter_tuning(X, y)
        self.model = LogisticRegression(**self.best_params)
        self.model.fit(X, y)
        print("Model trained")

    def predict(self, X: np.ndarray) -> np.ndarray:
        """
        Predicts the target values for the input data.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            The input data.

        Returns
        -------
        y : array-like of shape (n_samples,)
            The predicted target values.
        """
        return self.model.predict(X)

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        """
        Predicts the class probabilities for the input data.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            The input data.

        Returns
        -------
        y : array-like of shape (n_samples, n_classes)
            The predicted class probabilities.
        """
        return self.model.predict_proba(X)

    def save(self):
        """
        Saves the model to disk.
        """
        joblib.dump(self.model, model_path)

    def load(self):
        """
        Loads the model from disk.
        """
        self.model = joblib.load(model_path)


if __name__ == "__main__":
    # Load data.
    iris = load_iris()
    X = iris.data
    y = iris.target

    # Split data into train and test sets.
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

    # Train the model.
    model = IrisModel()
    model.train(X_train, y_train)

    # Evaluate on the held-out test set.
    print("Test score:", model.model.score(X_test, y_test))

    # Save the model.
    model.save()

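For reference, a minimal usage sketch of how the saved artifact could be reloaded and used elsewhere, assuming src/ is on the import path and a model has already been trained and saved by the script above:

# Minimal sketch; assumes src/ is importable and saved_model/saved_model.joblib exists.
import numpy as np

from classifier import IrisModel

model = IrisModel()
model.load()  # loads saved_model/saved_model.joblib from the repository root

# One measurement: sepal length, sepal width, petal length, petal width (cm).
sample = np.array([[5.1, 3.5, 1.4, 0.2]])
print("Predicted class:", model.predict(sample))
print("Class probabilities:", model.predict_proba(sample))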