This repository has been archived by the owner on Mar 22, 2024. It is now read-only.

fix classification tasks #25

Merged
merged 4 commits on Dec 23, 2020
7 changes: 4 additions & 3 deletions graphprot/Metrics.py
@@ -8,7 +8,8 @@

 def get_boolean(values, threshold, target):

-    if target == 'fnat' or target == 'bin':
+    inverse = ['fnat', 'bin']
+    if target in inverse:
         values_bool = [1 if x > threshold else 0 for x in values]
     else:
         values_bool = [1 if x < threshold else 0 for x in values]
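A minimal usage sketch of the updated check (the 'irmsd' call and its threshold are a hypothetical example, not taken from this PR): for a target listed in inverse, such as fnat, values above the threshold count as hits; for any other target the comparison is flipped.

    from graphprot.Metrics import get_boolean

    # 'fnat' is in the inverse list: higher is better, so x > threshold -> 1
    fnat_hits = get_boolean([0.1, 0.6, 0.8], threshold=0.5, target='fnat')    # [0, 1, 1]

    # any other target (e.g. an RMSD-like score): lower is better, so x < threshold -> 1
    irmsd_hits = get_boolean([2.0, 6.0, 3.5], threshold=4.0, target='irmsd')  # [1, 0, 1]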
@@ -120,13 +121,13 @@ def hitrate(self):

         idx = np.argsort(self.prediction)

-        if self.target == 'fnat' or self.target == 'bin':
+        inverse = ['fnat', 'bin']
+        if self.target in inverse:
             idx = idx[::-1]

         ground_truth_bool = get_boolean(
             self.y, self.threshold, self.target)
         ground_truth_bool = np.array(ground_truth_bool)

         hitrate = np.cumsum(ground_truth_bool[idx])

         return hitrate
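Only the target check changes here; for reference, a standalone sketch of the hit-rate idea with made-up scores and labels (not values from this PR):

    import numpy as np

    # Rank models by predicted score, then cumulatively count true hits.
    prediction = np.array([0.2, 0.9, 0.4, 0.7])   # hypothetical scores, higher = better (fnat/bin)
    truth_bool = np.array([0, 1, 0, 1])           # hypothetical ground-truth hits after thresholding

    idx = np.argsort(prediction)[::-1]            # descending order, as done for 'fnat'/'bin'
    hitrate = np.cumsum(truth_bool[idx])          # array([1, 2, 2, 2])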
37 changes: 25 additions & 12 deletions graphprot/NeuralNet.py
Expand Up @@ -79,7 +79,6 @@ def __init__(self, database, Net,
                 0).num_features).to(self.device)

         elif self.task == 'class':
-            self.classes = classes
             self.classes_idx = {i: idx for idx,
                                 i in enumerate(self.classes)}
             self.output_shape = len(self.classes)
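For context, a small illustration of the mapping this branch builds (the two-class list is a made-up example; the real values come from self.classes):

    # Hypothetical two-class setup
    classes = [0, 1]
    classes_idx = {c: idx for idx, c in enumerate(classes)}   # {0: 0, 1: 1}
    output_shape = len(classes)                               # 2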
@@ -335,19 +334,21 @@ def print_epoch_data(stage, epoch, loss, acc, time):
         print('Epoch [%04d] : %s loss %e | accuracy %s | time %1.2e sec.' % (epoch,
                                          stage, loss, acc_str, time))

-    def format_output(self, out, target):
+    def format_output(self, pred, target):
         """Format the network output depending on the task (classification/regression)."""

-        if self.task == 'class':
-            out = F.softmax(out, dim=1)

+        if self.task == 'class' :
+            pred = F.softmax(pred, dim=1)
             target = torch.tensor(
                 [self.classes_idx[int(x)] for x in target])

         else:
-            out = out.reshape(-1)

-        return out, target
+            pred = pred.reshape(-1)
+        return pred, target


     def test(self, database_test, threshold=4, hdf5='test_data.hdf5'):
         """Test the model

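A hedged sketch of what format_output now does for a classification task (the logits, labels, and two-class mapping below are made up for illustration):

    import torch
    import torch.nn.functional as F

    classes_idx = {0: 0, 1: 1}                      # hypothetical class -> index mapping
    pred = torch.tensor([[2.0, -1.0], [0.5, 1.5]])  # hypothetical raw network outputs for two samples
    target = torch.tensor([1.0, 0.0])               # hypothetical raw labels as stored with the data

    pred = F.softmax(pred, dim=1)                   # convert raw outputs to class probabilities
    target = torch.tensor([classes_idx[int(x)] for x in target])  # map labels to class indices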
@@ -380,6 +381,7 @@ def test(self, database_test, threshold=4, hdf5='test_data.hdf5'):

         self.test_out = _out
         self.test_y = _y
+
         _test_acc = self.get_metrics('test', threshold).accuracy
         self.test_acc = _test_acc
         self.test_loss = _test_loss
@@ -412,17 +414,23 @@ def eval(self, loader):

             y += d.y
             loss_val += loss_func(pred, d.y).detach().item()
-            out += pred.reshape(-1).tolist()

             # get the outputs for export
-            data['outputs'] += pred.reshape(-1).tolist()
+            if self.task == 'class':
+                pred = np.argmax(pred.detach(), axis=1)
+            else:
+                pred = pred.detach().reshape(-1)
+
+            out += pred
             data['targets'] += d.y.numpy().tolist()
+            data['outputs'] += pred.tolist()

             # get the data
             data['mol'] += d['mol']

         return out, y, loss_val, data


     def _epoch(self, epoch):
         """Run a single epoch

@@ -446,13 +454,18 @@
             loss = self.loss(pred, d.y)
             running_loss += loss.detach().item()
             loss.backward()
-            out += pred.reshape(-1).tolist()
             self.optimizer.step()

             # get the outputs for export
-            data['outputs'] += pred.reshape(-1).tolist()
-            data['targets'] += d.y.numpy().tolist()
+            if self.task == 'class':
+                pred = np.argmax(pred.detach(), axis=1)
+            else:
+                pred = pred.detach().reshape(-1)
+
+            out += pred
+            data['targets'] += d.y.numpy().tolist()
+            data['outputs'] += pred.tolist()

             # get the data
             data['mol'] += d['mol']

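The same per-batch post-processing now appears in both eval and _epoch; a small self-contained sketch of the idea, with hypothetical batch values (the argmax/reshape lines mirror the diff, everything else is illustration):

    import numpy as np
    import torch

    task = 'class'                                  # or a regression task
    out, data = [], {'outputs': [], 'targets': []}
    pred = torch.tensor([[0.2, 0.8], [0.9, 0.1]])   # hypothetical batch of class scores
    y = torch.tensor([1.0, 0.0])                    # hypothetical batch of targets

    if task == 'class':
        pred = np.argmax(pred.detach(), axis=1)     # keep the predicted class index per sample, here [1, 0]
    else:
        pred = pred.detach().reshape(-1)            # flatten raw values for regression

    out += pred                                     # accumulated per batch, as in the diff
    data['targets'] += y.numpy().tolist()
    data['outputs'] += pred.tolist()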