Skip to content
This repository has been archived by the owner on Mar 22, 2024. It is now read-only.

Commit

Permalink
Merge pull request #61 from DeepRank/fix_train_loss_bug
Browse files Browse the repository at this point in the history
fix small typo
  • Loading branch information
manonreau authored Nov 10, 2021
2 parents 00df4c1 + c026e34 commit 2f88c5f
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions deeprank_gnn/NeuralNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def __init__(self, database, Net,
f" task='class',"
f" shuffle=True,"
f" percent=[0.8, 0.2])")

if self.task == 'class' and self.threshold == None:
print('the threshold for accuracy computation is set to {}'.format(self.classes[1]))
self.threshold = self.classes[1]
Expand Down Expand Up @@ -212,7 +212,7 @@ def put_model_to_device(self, dataset, Net):
print(torch.cuda.get_device_name(0))

self.num_edge_features = len(self.edge_feature)

# regression mode
if self.task == 'reg':
self.model = Net(dataset.get(
Expand Down Expand Up @@ -279,17 +279,17 @@ def train(self, nepoch=1, validate=False, save_model='last', hdf5='train_data.hd

# Open output file for writting
with h5py.File(fname, 'w') as self.f5:

# Number of epochs
self.nepoch = nepoch

# Loop over epochs
self.data = {}
for epoch in range(1, nepoch+1):

# Train the model
self.model.train()

t0 = time()
_out, _y, _loss, self.data['train'] = self._epoch(epoch)
t = time() - t0
Expand All @@ -298,7 +298,7 @@ def train(self, nepoch=1, validate=False, save_model='last', hdf5='train_data.hd
self.train_y = _y
_acc = self.get_metrics('train', self.threshold).accuracy
self.train_acc.append(_acc)

# Print the loss and accuracy (training set)
self.print_epoch_data(
'train', epoch, _loss, _acc, t)
Expand Down Expand Up @@ -332,7 +332,7 @@ def train(self, nepoch=1, validate=False, save_model='last', hdf5='train_data.hd
else:
# if no validation set, saves the best performing model on the traing set
if save_model == 'best':
if min(self.train_loss) == _train_loss:
if min(self.train_loss) == _loss:
print(
'WARNING: The training set is used both for learning and model selection.')
print(
Expand All @@ -345,7 +345,7 @@ def train(self, nepoch=1, validate=False, save_model='last', hdf5='train_data.hd
# Save epoch data
if (save_epoch == 'all') or (epoch == nepoch):
self._export_epoch_hdf5(epoch, self.data)

elif (save_epoch == 'intermediate') and (epoch % save_every == 0):
self._export_epoch_hdf5(epoch, self.data)

Expand All @@ -354,7 +354,7 @@ def train(self, nepoch=1, validate=False, save_model='last', hdf5='train_data.hd
self.save_model(filename='t{}_y{}_b{}_e{}_lr{}.pth.tar'.format(
self.task, self.target, str(self.batch_size), str(nepoch), str(self.lr)))


def test(self, database_test=None, threshold=4, hdf5='test_data.hdf5'):
"""
Tests the model
Expand Down Expand Up @@ -396,9 +396,9 @@ def test(self, database_test=None, threshold=4, hdf5='test_data.hdf5'):
# Run test
_out, _y, _test_loss, self.data['test'] = self.eval(
self.test_loader)

self.test_out = _out

if len(_y) == 0:
self.test_y = None
self.test_acc = None
Expand All @@ -409,7 +409,7 @@ def test(self, database_test=None, threshold=4, hdf5='test_data.hdf5'):

self.test_loss = _test_loss
self._export_epoch_hdf5(0, self.data)


def eval(self, loader):
"""
Expand Down

0 comments on commit 2f88c5f

Please sign in to comment.