Skip to content
This repository has been archived by the owner on Mar 22, 2024. It is now read-only.

Commit

Permalink
Merge pull request #23 from DeepRank/fix_test_pretrained
Browse files: browse the repository at this point in the history
fix test and pretrained model loading
  • Loading branch information
manonreau authored Dec 11, 2020
2 parents 6f5181b + 8438ddf commit 83713bf
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 43 deletions.
74 changes: 37 additions & 37 deletions graphprot/Metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,46 +6,46 @@
from sklearn.metrics import confusion_matrix


def get_boolean(values, threshold, target):
    """Binarize a sequence of scores against a threshold.

    For the targets 'fnat' and 'bin' a value strictly greater than
    ``threshold`` maps to 1; for every other target a value strictly
    lower than ``threshold`` maps to 1. All remaining values (including
    values equal to the threshold) map to 0.

    Arguments
        values: iterable of numeric scores
        threshold: cut-off used to split the values
        target: name of the target metric (e.g. 'irmsd', 'fnat', 'bin')

    Returns
        list of 0/1 integers, one per input value
    """
    # The comparison direction depends only on the target, so decide it
    # once instead of duplicating the comprehension in both branches.
    higher_is_positive = target in ('fnat', 'bin')

    return [
        1 if (x > threshold if higher_is_positive else x < threshold) else 0
        for x in values
    ]


def get_comparison(prediction, ground_truth, binary=True, classes=(0, 1)):
    """Count false/true positives and negatives from a confusion matrix.

    The confusion matrix is built with ``confusion_matrix(ground_truth,
    prediction, labels=classes)``; the per-class counts are derived from
    its column sums, row sums and diagonal.

    Arguments
        prediction: predicted labels
        ground_truth: reference labels
        binary: if true, return only the counts of the class at index 1
                of ``classes`` (requires at least two classes);
                otherwise return one array of counts per quantity
        classes: labels used to index the confusion matrix

    Returns
        (false_positive, false_negative, true_positive, true_negative),
        as scalars when ``binary`` is true, as arrays otherwise
    """
    CM = confusion_matrix(ground_truth, prediction, labels=classes)

    true_positive = np.diag(CM)
    # Column sums count every sample predicted as a class, row sums every
    # sample that truly belongs to it; removing the diagonal leaves the
    # misclassified ones.
    false_positive = CM.sum(axis=0) - true_positive
    false_negative = CM.sum(axis=1) - true_positive
    true_negative = CM.sum() - (false_positive + false_negative + true_positive)

    if binary:
        return (false_positive[1], false_negative[1],
                true_positive[1], true_negative[1])

    return false_positive, false_negative, true_positive, true_negative


class Metrics(object):

def __init__(self, y_pred, y_hat, target, threshold=4, binary=True):
def __init__(self, prediction, y, target, threshold=4, binary=True):
'''Master class from which all the other metrics are computed
Arguments
y_pred: predicted values
y_hat: ground truth
prediction: predicted values
y: ground truth
target: irmsd, fnat, class, bin
threshold: threshold used to split the data into a binary vector
binary: transform the data in binary vectors
'''

self.y_pred = y_pred
self.y_hat = y_hat
self.prediction = prediction
self.y = y
self.binary = binary
self.target = target
self.threshold = threshold
Expand All @@ -54,77 +54,77 @@ def __init__(self, y_pred, y_hat, target, threshold=4, binary=True):

if self.binary == True:

y_pred_CM = get_boolean(
self.y_pred, self.threshold, self.target)
y_hat_CM = get_boolean(
self.y_hat, self.threshold, self.target)
prediction_bool = get_boolean(
self.prediction, self.threshold, self.target)
y_bool = get_boolean(
self.y, self.threshold, self.target)
classes = [0, 1]

FP, FN, TP, TN = get_comparison(
y_pred_CM, y_hat_CM, self.binary, classes=classes)
false_positive, false_negative, true_positive, true_negative = get_comparison(
prediction_bool, y_bool, self.binary, classes=classes)

else:
if self.target == 'class':
classes = [1, 2, 3, 4, 5]
else:
classes = [0, 1]

FP, FN, TP, TN = get_comparison(
self.y_pred, self.y_hat, self.binary, classes=classes)
false_positive, false_negative, true_positive, true_negative = get_comparison(
self.prediction, self.y, self.binary, classes=classes)

try:
# Sensitivity, hit rate, recall, or true positive rate
self.TPR = TP/(TP+FN)
self.sensitivity = true_positive/(true_positive+false_negative)
except:
self.TPR = None
self.sensitivity = None

try:
# Specificity or true negative rate
self.TNR = TN/(TN+FP)
self.specificity = true_negative/(true_negative+false_positive)
except:
self.TNR = None
self.specificity = None

try:
# Precision or positive predictive value
self.PPV = TP/(TP+FP)
self.precision = true_positive/(true_positive+false_positive)
except:
self.PPV = None
self.precision = None

try:
# Negative predictive value
self.NPV = TN/(TN+FN)
self.NPV = true_negative/(true_negative+false_negative)
except:
self.NPV = None

try:
# Fall out or false positive rate
self.FPR = FP/(FP+TN)
self.FPR = false_positive/(false_positive+true_negative)
except:
self.FPR = None

try:
# False negative rate
self.FNR = FN/(TP+FN)
self.FNR = false_negative/(true_positive+false_negative)
except:
self.FNR = None

try:
# False discovery rate
self.FDR = FP/(TP+FP)
self.FDR = false_positive/(true_positive+false_positive)
except:
self.FDR = None

self.ACC = (TP+TN)/(TP+FP+FN+TN)
self.accuracy = (true_positive+true_negative)/(true_positive+false_positive+false_negative+true_negative)

def HitRate(self):
def hitrate(self):

idx = np.argsort(self.y_pred)
idx = np.argsort(self.prediction)

if self.target == 'fnat' or self.target == 'bin':
idx = idx[::-1]

ground_truth_bool = get_boolean(
self.y_hat, self.threshold, self.target)
self.y, self.threshold, self.target)
ground_truth_bool = np.array(ground_truth_bool)

hitrate = np.cumsum(ground_truth_bool[idx])
Expand Down
11 changes: 5 additions & 6 deletions graphprot/NeuralNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def plot_hit_rate(self, data='eval', threshold=4, mode='percentage', name=''):

try:

hitrate = self.get_metrics(data, threshold).HitRate()
hitrate = self.get_metrics(data, threshold).hitrate()

nb_models = len(hitrate)
X = range(1, nb_models + 1)
Expand Down Expand Up @@ -258,7 +258,7 @@ def train(self, nepoch=1, validate=False, plot=False, save_model='last', hdf5='t
self.train_loss.append(_loss)
self.train_out = _out
self.train_y = _y
_acc = self.get_metrics('train', self.threshold).ACC
_acc = self.get_metrics('train', self.threshold).accuracy
self.train_acc.append(_acc)

# Print the loss and accuracy (training set)
Expand All @@ -276,7 +276,7 @@ def train(self, nepoch=1, validate=False, plot=False, save_model='last', hdf5='t
self.valid_out = _out
self.valid_y = _y
_val_acc = self.get_metrics(
'eval', self.threshold).ACC
'eval', self.threshold).accuracy
self.valid_acc.append(_val_acc)

# Print loss and accuracy (validation set)
Expand Down Expand Up @@ -338,7 +338,7 @@ def print_epoch_data(stage, epoch, loss, acc, time):
def format_output(self, out, target):
"""Format the network output depending on the task (classification/regression)."""

if self.task == 'class' :
if self.task == 'class':
out = F.softmax(out, dim=1)
target = torch.tensor(
[self.classes_idx[int(x)] for x in target])
Expand Down Expand Up @@ -380,8 +380,7 @@ def test(self, database_test, threshold=4, hdf5='test_data.hdf5'):

self.test_out = _out
self.test_y = _y
print(_out, _y)
_test_acc = self.get_metrics('test', threshold).ACC
_test_acc = self.get_metrics('test', threshold).accuracy
self.test_acc = _test_acc
self.test_loss = _test_loss

Expand Down

0 comments on commit 83713bf

Please sign in to comment.