Skip to content

Commit 10e5f92

Browse files
author
Domantas
committed
Parameter correction
1 parent 23557f4 commit 10e5f92

File tree

1 file changed

+41
-55
lines changed

1 file changed

+41
-55
lines changed

03_train_models.py

Lines changed: 41 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,65 +1,51 @@
11
import pandas as pd
22
import numpy as np
3-
import os
43
import ast
54
import pickle
5+
import joblib
66
from datetime import datetime
77

8-
date = datetime.now().strftime("%Y-%M-%d")
8+
date = datetime.now().strftime("%Y-%m-%d")
99
input_path = f'Datasets/Translated_tokens_{date}.csv'
1010
output_path = f'Models/Models_LR_{date}.joblib'
1111
words_path = f'Frequency_models/word_frequency_{date}.picle'
12-
if os.path.isfile(input_path) and os.path.isfile(words_path) and not os.path.isfile(output_path):
1312

14-
df = pd.read_csv(input_path)
15-
pickle_in = open(words_path,"rb")
16-
words_frequency = pickle.load(pickle_in)
17-
# Models creation
18-
top = 2500
19-
from collections import Counter
20-
21-
features = np.zeros(df.shape[0] * top).reshape(df.shape[0], top)
22-
labels = np.zeros(df.shape[0])
23-
counter = 0
24-
for i, row in df.iterrows():
25-
c = [word for word, word_count in Counter(ast.literal_eval(row['tokens_en'])).most_common(top)]
26-
labels[counter] = list(set(df['main_category'].values)).index(row['main_category'])
27-
for word in c:
28-
if word in words_frequency[row['main_category']]:
29-
features[counter][words_frequency[row['main_category']].index(word)] = 1
30-
counter += 1
31-
32-
from sklearn.metrics import accuracy_score
33-
from scipy.sparse import coo_matrix
34-
X_sparse = coo_matrix(features)
35-
36-
from sklearn.utils import shuffle
37-
X, X_sparse, y = shuffle(features, X_sparse, labels, random_state=0)
38-
39-
from sklearn.model_selection import train_test_split
40-
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)
41-
42-
from sklearn.linear_model import LogisticRegression
43-
lr = LogisticRegression()
44-
lr.fit(X_train, y_train)
45-
lr_predictions = lr.predict(X_test)
46-
score = lr.score(X_test, y_test)
47-
print('LogisticRegression')
48-
print('Score: ', score)
49-
print('Top: ', top)
50-
print('Dataset length: ', df.shape[0])
51-
print()
52-
53-
from sklearn.svm import LinearSVC
54-
clf = LinearSVC()
55-
clf.fit(X_train, y_train)
56-
clf_predictions = clf.predict(X_test)
57-
score = clf.score(X_test, y_test)
58-
print('SVM')
59-
print('Score: ', score)
60-
print('Top: ', top)
61-
print('Dataset length: ', df.shape[0])
62-
63-
# Save models
64-
from sklearn.externals import joblib
65-
joblib.dump(lr, output_path)
13+
df = pd.read_csv(input_path)
14+
pickle_in = open(words_path, "rb")
15+
words_frequency = pickle.load(pickle_in)
16+
# Models creation
17+
top = 20000
18+
from collections import Counter
19+
20+
features = np.zeros(df.shape[0] * top).reshape(df.shape[0], top)
21+
labels = np.zeros(df.shape[0])
22+
counter = 0
23+
24+
print('Generating features')
25+
all_categories = list(df.main_category.unique())
26+
for i, row in df.iterrows():
27+
c = [word for word, word_count in Counter(ast.literal_eval(row['tokens_en'])).most_common(top)]
28+
labels[counter] = all_categories.index(row['main_category'])
29+
for word in c:
30+
if word in words_frequency[row['main_category']]:
31+
features[counter][words_frequency[row['main_category']].index(word)] = 1
32+
counter += 1
33+
print("Features generation done")
34+
35+
from sklearn.model_selection import train_test_split
36+
X_train, X_test, y_train, y_test = train_test_split(features, labels, test_size=0.33, random_state=42)
37+
38+
print(len(X_train), len(X_test), len(y_train), len(y_test))
39+
from sklearn.linear_model import LogisticRegression
40+
lr = LogisticRegression()
41+
lr.fit(X_train, y_train)
42+
lr_predictions = lr.predict(X_test)
43+
score = lr.score(X_test, y_test)
44+
print('LogisticRegression')
45+
print('Score: ', score)
46+
print('Top: ', top)
47+
print('Dataset length: ', df.shape[0])
48+
print()
49+
50+
# Save models
51+
joblib.dump(lr, output_path)

0 commit comments

Comments
 (0)