train.py
import os
import re
from collections import Counter

import joblib
from bs4 import BeautifulSoup
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import classification_report, accuracy_score, precision_score, recall_score, f1_score
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.svm import SVC
from sklearn.utils import resample


def extract_text_from_html(html):
    """Strip HTML markup and return the plain text content."""
    soup = BeautifulSoup(html, 'html.parser')
    text = soup.get_text(separator=' ')
    # Collapse extra whitespace and newlines (raw string avoids an
    # invalid-escape warning for \s on newer Python versions)
    clean_text = re.sub(r'\s+', ' ', text).strip()
    return clean_text


def extract_features(text):
    """Extract boolean keyword features (English and Italian) from email text.

    Note: this function is defined but not currently used by the TF-IDF
    pipeline in the main block below.
    """
    features = {
        'contains_greeting': bool(re.search(r'\b(hi|hello|dear|ciao|salve|gentile)\b', text, re.IGNORECASE)),
        'contains_signature': bool(re.search(r'\b(best regards|sincerely|thank you|cordiali saluti|distinti saluti|grazie)\b', text, re.IGNORECASE)),
        'contains_attachment': bool(re.search(r'\b(attachment|attached|see attached|allegato|allegati|vedi allegato)\b', text, re.IGNORECASE)),
        'contains_specific_keyword': bool(re.search(r'\b(refund|confirmation|invoice|receipt|survey|rimborso|conferma|fattura|ricevuta|sondaggio)\b', text, re.IGNORECASE)),
        'contains_language_specific_phrases': bool(
            re.search(r'\b(automated message|system generated|do not reply|this email was sent automatically|messaggio automatico|generato automaticamente|non rispondere|questa email è stata inviata automaticamente)\b',
                      text, re.IGNORECASE)),
        'contains_urgency_phrases': bool(re.search(r'\b(urgent|as soon as possible|reply by|deadline|time-sensitive|urgente|al più presto|rispondere entro|scadenza|tempo sensibile)\b', text, re.IGNORECASE)),
        'contains_customer_specific_info': bool(re.search(r'\b(customer name|account number|order details|membership|nome del cliente|numero di conto|dettagli dell\'ordine|membri)\b', text, re.IGNORECASE))
    }
    return features
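
# Sketch (assumption, not part of the original flow): one way to combine the
# handcrafted features above with the TF-IDF text features would be a
# FeatureUnion, e.g.:
#
#   from sklearn.feature_extraction import DictVectorizer
#   from sklearn.pipeline import FeatureUnion
#   from sklearn.preprocessing import FunctionTransformer
#
#   combined_features = FeatureUnion([
#       ('tfidf', TfidfVectorizer()),
#       ('handcrafted', Pipeline([
#           ('extract', FunctionTransformer(
#               lambda texts: [extract_features(t) for t in texts])),
#           ('vectorize', DictVectorizer()),
#       ])),
#   ])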


if __name__ == '__main__':
    # Load training data
    automated_emails_dir = 'data/1. automated'
    human_emails_dir = 'data/2. human'

    automated_emails = []
    for filename in os.listdir(automated_emails_dir):
        with open(os.path.join(automated_emails_dir, filename), 'r', encoding='utf-8') as file:
            html = file.read()
            text = extract_text_from_html(html)
            automated_emails.append((text, 'automated'))

    human_emails = []
    for filename in os.listdir(human_emails_dir):
        with open(os.path.join(human_emails_dir, filename), 'r', encoding='utf-8') as file:
            html = file.read()
            text = extract_text_from_html(html)
            human_emails.append((text, 'human'))
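
    # Caveat: os.listdir also returns subdirectories and hidden files, so this
    # assumes the data folders contain only HTML email files.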

    # Check the class distribution
    class_distribution = Counter([label for _, label in automated_emails + human_emails])
    print("Original Class Distribution:", class_distribution)

    # Oversample the minority class
    max_class_count = max(class_distribution.values())
    if class_distribution['human'] > class_distribution['automated']:
        automated_emails_resampled = resample(automated_emails, replace=True, n_samples=max_class_count, random_state=42)
        balanced_emails = automated_emails_resampled + human_emails
    else:
        human_emails_resampled = resample(human_emails, replace=True, n_samples=max_class_count, random_state=42)
        balanced_emails = automated_emails + human_emails_resampled

    # Check the class distribution after oversampling
    class_distribution = Counter([label for _, label in balanced_emails])
    print("Class Distribution after Oversampling:", class_distribution)

    # Prepare training and test data
    balanced_texts = [text for text, _ in balanced_emails]
    balanced_labels = [label for _, label in balanced_emails]
    text_train, text_test, label_train, label_test = train_test_split(balanced_texts, balanced_labels, test_size=0.2, stratify=balanced_labels, random_state=42)

    # Create a pipeline with TF-IDF vectorization and SVM classifier
    pipeline = Pipeline([
        ('tfidf', TfidfVectorizer()),
        ('svm', SVC(probability=True))
    ])
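
    # Note: probability=True makes SVC fit an internal cross-validated Platt
    # scaling model so predict_proba is available below, at extra training cost.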

    # Train the pipeline
    pipeline.fit(text_train, label_train)

    # Test the pipeline
    predicted_probabilities = pipeline.predict_proba(text_test)

    # Save the trained model
    joblib.dump(pipeline, 'email_classification_model.joblib')
    print("Model saved successfully.")

    # Get the predicted labels from the probability argmax, so they are
    # consistent with the probabilities printed below
    predicted_labels = pipeline.classes_[predicted_probabilities.argmax(axis=1)]

    # Compute evaluation metrics (treating 'automated' as the positive class)
    accuracy = accuracy_score(label_test, predicted_labels)
    precision = precision_score(label_test, predicted_labels, pos_label='automated')
    recall = recall_score(label_test, predicted_labels, pos_label='automated')
    f1 = f1_score(label_test, predicted_labels, pos_label='automated')
    classification_report_output = classification_report(label_test, predicted_labels)

    # Print evaluation metrics
    print(f"Accuracy: {accuracy}")
    print(f"Precision: {precision}")
    print(f"Recall: {recall}")
    print(f"F1-Score: {f1}")
    print(f"\nClassification Report:\n{classification_report_output}")

    # Print examples, classifications, and probabilities
    print("Examples from the Test Set:")
    # Probability columns follow pipeline.classes_, which is sorted
    # alphabetically: index 0 is 'automated', index 1 is 'human'
    for text, label, predicted_label, probabilities in zip(text_test[:5], label_test[:5], predicted_labels[:5], predicted_probabilities[:5]):
        print(f"Text: {text}")
        print(f"True Label: {label}")
        print(f"Predicted Label: {predicted_label}")
        print(f"Probabilities: {round(probabilities[0]*100)}% automated, {round(probabilities[1]*100)}% human")
        print("-----------------------")