Skip to content
This repository has been archived by the owner on Mar 22, 2024. It is now read-only.

Commit

Permalink
Merge pull request #25 from DeepRank/class
Browse files Browse the repository at this point in the history
fix classification tasks
  • Loading branch information
manonreau authored Dec 23, 2020
2 parents 83713bf + 695c768 commit 486baa1
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 15 deletions.
7 changes: 4 additions & 3 deletions graphprot/Metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

def get_boolean(values, threshold, target):

if target == 'fnat' or target == 'bin':
inverse = ['fnat', 'bin']
if target in inverse:
values_bool = [1 if x > threshold else 0 for x in values]
else:
values_bool = [1 if x < threshold else 0 for x in values]
Expand Down Expand Up @@ -120,13 +121,13 @@ def hitrate(self):

idx = np.argsort(self.prediction)

if self.target == 'fnat' or self.target == 'bin':
inverse = ['fnat', 'bin']
if self.target in inverse:
idx = idx[::-1]

ground_truth_bool = get_boolean(
self.y, self.threshold, self.target)
ground_truth_bool = np.array(ground_truth_bool)

hitrate = np.cumsum(ground_truth_bool[idx])

return hitrate
37 changes: 25 additions & 12 deletions graphprot/NeuralNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ def __init__(self, database, Net,
0).num_features).to(self.device)

elif self.task == 'class':
self.classes = classes
self.classes_idx = {i: idx for idx,
i in enumerate(self.classes)}
self.output_shape = len(self.classes)
Expand Down Expand Up @@ -335,19 +334,21 @@ def print_epoch_data(stage, epoch, loss, acc, time):
print('Epoch [%04d] : %s loss %e | accuracy %s | time %1.2e sec.' % (epoch,
stage, loss, acc_str, time))

def format_output(self, out, target):
def format_output(self, pred, target):
"""Format the network output depending on the task (classification/regression)."""

if self.task == 'class':
out = F.softmax(out, dim=1)

if self.task == 'class' :
pred = F.softmax(pred, dim=1)
target = torch.tensor(
[self.classes_idx[int(x)] for x in target])

else:
out = out.reshape(-1)

return out, target
pred = pred.reshape(-1)
return pred, target


def test(self, database_test, threshold=4, hdf5='test_data.hdf5'):
"""Test the model
Expand Down Expand Up @@ -380,6 +381,7 @@ def test(self, database_test, threshold=4, hdf5='test_data.hdf5'):

self.test_out = _out
self.test_y = _y

_test_acc = self.get_metrics('test', threshold).accuracy
self.test_acc = _test_acc
self.test_loss = _test_loss
Expand Down Expand Up @@ -412,17 +414,23 @@ def eval(self, loader):

y += d.y
loss_val += loss_func(pred, d.y).detach().item()
out += pred.reshape(-1).tolist()

# get the outputs for export
data['outputs'] += pred.reshape(-1).tolist()
if self.task == 'class':
pred = np.argmax(pred.detach(), axis=1)
else:
pred = pred.detach().reshape(-1)

out += pred
data['targets'] += d.y.numpy().tolist()
data['outputs'] += pred.tolist()

# get the data
data['mol'] += d['mol']

return out, y, loss_val, data


def _epoch(self, epoch):
"""Run a single epoch
Expand All @@ -446,13 +454,18 @@ def _epoch(self, epoch):
loss = self.loss(pred, d.y)
running_loss += loss.detach().item()
loss.backward()
out += pred.reshape(-1).tolist()
self.optimizer.step()

# get the outputs for export
data['outputs'] += pred.reshape(-1).tolist()
data['targets'] += d.y.numpy().tolist()
if self.task == 'class':
pred = np.argmax(pred.detach(), axis=1)
else:
pred = pred.detach().reshape(-1)

out += pred
data['targets'] += d.y.numpy().tolist()
data['outputs'] += pred.tolist()

# get the data
data['mol'] += d['mol']

Expand Down

0 comments on commit 486baa1

Please sign in to comment.