evaluate.py

import json
import math
import os
import pickle
import sys
import pandas as pd
from sklearn import metrics
from sklearn import tree
from dvclive import Live
from matplotlib import pyplot as plt


def evaluate(model, matrix, split, live, save_path):
    """
    Dump all evaluation metrics and plots for given datasets.

    Args:
        model (sklearn.ensemble.RandomForestClassifier): Trained classifier.
        matrix (scipy.sparse.csr_matrix): Input matrix.
        split (str): Dataset name.
        live (dvclive.Live): Dvclive instance.
        save_path (str): Path to save the metrics.
    """
    labels = matrix[:, 1].toarray().astype(int)
    x = matrix[:, 2:]

    predictions_by_class = model.predict_proba(x)
    predictions = predictions_by_class[:, 1]
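    # predict_proba returns one probability column per class, ordered by
    # model.classes_; taking column 1 as the positive-class score assumes a
    # binary classifier whose classes are [0, 1], as in this example.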

    # Use dvclive to log a few simple metrics...
    avg_prec = metrics.average_precision_score(labels, predictions)
    roc_auc = metrics.roc_auc_score(labels, predictions)
    if not live.summary:
        live.summary = {"avg_prec": {}, "roc_auc": {}}
    live.summary["avg_prec"][split] = avg_prec
    live.summary["roc_auc"][split] = roc_auc

    # ... and plots...
    # ... like an ROC plot...
    live.log_sklearn_plot("roc", labels, predictions, name=f"roc/{split}")
    # ... and a precision-recall plot...
    # ... which passes `drop_intermediate=True` to the sklearn method...
    live.log_sklearn_plot(
        "precision_recall",
        labels,
        predictions,
        name=f"prc/{split}",
        drop_intermediate=True,
    )
    # ... and a confusion matrix plot.
    live.log_sklearn_plot(
        "confusion_matrix",
        labels.squeeze(),
        predictions_by_class.argmax(-1),
        name=f"cm/{split}",
    )
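
# With the call pattern above, DVCLive typically materializes the plot data as
# JSON under the Live directory, e.g. eval/plots/sklearn/roc/train.json and
# eval/plots/sklearn/cm/test.json (the exact layout depends on the DVCLive
# version), which DVC can then render with `dvc plots show`.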


def save_importance_plot(live, model, feature_names):
    """
    Save feature importance plot.

    Args:
        live (dvclive.Live): DVCLive instance.
        model (sklearn.ensemble.RandomForestClassifier): Trained classifier.
        feature_names (list): List of feature names.
    """
    fig, axes = plt.subplots(dpi=100)
    fig.subplots_adjust(bottom=0.2, top=0.95)
    axes.set_ylabel("Mean decrease in impurity")

    importances = model.feature_importances_
    forest_importances = pd.Series(importances, index=feature_names).nlargest(n=30)
    forest_importances.plot.bar(ax=axes)

    live.log_image("importance.png", fig)
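    # `live.log_image` stores the rendered figure under the Live directory
    # (typically eval/plots/images/importance.png, depending on the DVCLive
    # version).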


def main():
    EVAL_PATH = "eval"

    if len(sys.argv) != 3:
        sys.stderr.write("Arguments error. Usage:\n")
        sys.stderr.write("\tpython evaluate.py model features\n")
        sys.exit(1)

    model_file = sys.argv[1]
    train_file = os.path.join(sys.argv[2], "train.pkl")
    test_file = os.path.join(sys.argv[2], "test.pkl")

    # Load model and data.
    with open(model_file, "rb") as fd:
        model = pickle.load(fd)

    with open(train_file, "rb") as fd:
        train, feature_names = pickle.load(fd)

    with open(test_file, "rb") as fd:
        test, _ = pickle.load(fd)

    # Evaluate train and test datasets.
    with Live(EVAL_PATH, dvcyaml=False) as live:
        evaluate(model, train, "train", live, save_path=EVAL_PATH)
        evaluate(model, test, "test", live, save_path=EVAL_PATH)

        # Dump feature importance plot.
        save_importance_plot(live, model, feature_names)


if __name__ == "__main__":
    main()
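
# Example invocation (hypothetical paths; adjust to your project layout):
#
#     python evaluate.py model.pkl data/features
#
# where `model.pkl` is the pickled classifier produced by the training stage
# and `data/features/` contains the `train.pkl` and `test.pkl` feature
# matrices expected above. Metrics and plots are written under the `eval/`
# directory for DVC to track.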