Skip to content

Commit 204b915

Browse files
committed
Update _Dist/NeuralNetworks/c_BasicNN
1 parent 83069e8 commit 204b915

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

_Dist/NeuralNetworks/c_BasicNN/NNCore.py

+4
Original file line numberDiff line numberDiff line change
@@ -633,5 +633,9 @@ def predict(self, x, get_raw=False, verbose=True):
633633
tensor = "_output" if self.is_regression or get_raw else "_prob_output"
634634
return self._calculate(x, tensor, "Predict", verbose=verbose)
635635

636+
def predict_classes(self, x, get_raw=False, verbose=True):
637+
pred = self.predict(x, get_raw, verbose)
638+
return np.argmax(pred, axis=1)
639+
636640
def evaluate(self, x, y, verbose=False):
637641
print("{}: {:8.6}".format(self.metric_name, self.metric(y, self.predict(x, verbose=verbose))))

_Dist/NeuralNetworks/c_BasicNN/NNUtils.py

+17
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import numpy as np
22
import tensorflow as tf
3+
4+
from scipy import interp
35
from sklearn import metrics
46

57

@@ -28,6 +30,8 @@ class Metrics:
2830
"mse": -1, "ber": -1,
2931
"log_loss": -1
3032
}
33+
require_prob = {name: False for name in sign_dict}
34+
require_prob["auc"] = True
3135

3236
@staticmethod
3337
def check_shape(y, binary=False):
@@ -54,6 +58,19 @@ def auc(y, pred):
5458
Metrics.check_shape(pred, True)
5559
)
5660

61+
@staticmethod
62+
def multi_auc(y, pred):
63+
n_classes = y.shape[1]
64+
fpr, tpr = [None] * n_classes, [None] * n_classes
65+
for i in range(n_classes):
66+
fpr[i], tpr[i], _ = metrics.roc_curve(y[:, i], pred[:, i])
67+
new_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))
68+
new_tpr = np.zeros_like(new_fpr)
69+
for i in range(n_classes):
70+
new_tpr += interp(new_fpr, fpr[i], tpr[i])
71+
new_tpr /= n_classes
72+
return metrics.auc(new_fpr, new_tpr)
73+
5774
@staticmethod
5875
def acc(y, pred):
5976
return np.mean(Metrics.check_shape(y) == Metrics.check_shape(pred))

0 commit comments

Comments
 (0)