-
Notifications
You must be signed in to change notification settings - Fork 1
/
example-sgd-classifier.py
67 lines (41 loc) · 2.02 KB
/
example-sgd-classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import numpy as np
import pandas as pd
from sklearn.linear_model import SGDClassifier
from sklearn.metrics import accuracy_score, classification_report
from sklearn.preprocessing import StandardScaler
def retrieveData(trainingPath="training-data.csv", testPath="test-data.csv"):
    """Load the training and test sets from header-less CSV files.

    Args:
        trainingPath: Path to the training CSV. Defaults to
            ``"training-data.csv"`` (resolved against the working directory).
        testPath: Path to the test CSV. Defaults to ``"test-data.csv"``.

    Returns:
        A ``(trainingData, testData)`` tuple of NumPy arrays. Each array holds
        one row per CSV line; the first line is data, not a header.
    """
    # ``DataFrame.as_matrix()`` was removed in pandas 1.0; ``.values`` yields
    # the same ndarray and works on both old and new pandas releases.
    trainingData = pd.read_csv(trainingPath, header=None).values
    testData = pd.read_csv(testPath, header=None).values
    return trainingData, testData
def separateFeaturesAndCategories(trainingData, testData):
    """Split each data matrix into its feature columns and its category column.

    The last column of every matrix is the category; all columns before it
    are features. Categories are returned as one-column 2-D slices
    (``[:, -1:]``), not flattened vectors.

    Returns:
        ``(trainingFeatures, trainingCategories, testFeatures, testCategories)``.
    """
    def split(matrix):
        # Everything but the last column, then the last column kept 2-D.
        return matrix[:, :-1], matrix[:, -1:]

    trainFeat, trainCat = split(trainingData)
    testFeat, testCat = split(testData)
    return trainFeat, trainCat, testFeat, testCat
def scaleData(trainingFeatures, testFeatures):
    """Standardize both feature sets using statistics from the training set.

    One ``StandardScaler`` is fitted on ``trainingFeatures`` alone. That same
    fitted scaler then transforms both arrays, so the test set does not
    influence the scaling parameters.

    Returns:
        ``(scaledTrainingFeatures, scaledTestFeatures)``.
    """
    fitted = StandardScaler().fit(trainingFeatures)
    return fitted.transform(trainingFeatures), fitted.transform(testFeatures)
def classifyTestSamples(trainingFeatures, trainingCategories, testFeatures,
                        randomState=None):
    """Fit a linear SGD classifier on the training set and label the test set.

    Args:
        trainingFeatures: Training feature rows.
        trainingCategories: Training labels, one per row. A single-column 2-D
            array (as produced by slicing ``[:, -1:]``) is accepted; it is
            flattened to 1-D before fitting.
        testFeatures: Test feature rows, with the same columns as
            ``trainingFeatures``.
        randomState: Optional seed forwarded to
            ``SGDClassifier(random_state=...)``. The default ``None`` keeps
            scikit-learn's default unseeded behavior.

    Returns:
        The predicted categories, one per row of ``testFeatures``.
    """
    clf = SGDClassifier(random_state=randomState)
    # scikit-learn expects y with shape (n_samples,). A column vector triggers
    # a DataConversionWarning, so flatten it explicitly.
    clf.fit(trainingFeatures, np.ravel(trainingCategories))
    return clf.predict(testFeatures)
def gatherClassificationMetrics(testCategories, predictedCategories):
    """Print the accuracy and a per-class report for a set of predictions.

    Args:
        testCategories: True labels of the test samples.
        predictedCategories: Predicted labels for the same samples, in the
            same order.

    Prints the accuracy rounded to two decimals, followed by a blank line and
    scikit-learn's ``classification_report`` text. Returns None.
    """
    expected, predicted = testCategories, predictedCategories
    accuracy = accuracy_score(expected, predicted)
    report = classification_report(expected, predicted)
    headline = "Accuracy rate: %s\n" % round(accuracy, 2)
    print(headline)
    print(report)
def main():
    """Run the full pipeline and print metrics for the test set.

    The steps are:
      1. Load the CSV data.
      2. Split off the last column as the category.
      3. Standardize the features.
      4. Train an SGD classifier on the training split.
      5. Predict and print metrics for the test split.
    """
    trainingData, testData = retrieveData()
    trainingFeatures, trainingCategories, testFeatures, testCategories = \
        separateFeaturesAndCategories(trainingData, testData)
    scaledTrainingFeatures, scaledTestingFeatures = \
        scaleData(trainingFeatures, testFeatures)
    # Feed the *standardized* features to the classifier. SGD is sensitive to
    # feature scale, so computing the scaled arrays and then training on the
    # raw ones would defeat the purpose of scaleData().
    predictedCategories = classifyTestSamples(
        scaledTrainingFeatures, trainingCategories, scaledTestingFeatures)
    gatherClassificationMetrics(testCategories, predictedCategories)
# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()