Skip to content
This repository has been archived by the owner on Mar 22, 2024. It is now read-only.

add export options #18

Merged
merged 6 commits into from
Nov 25, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions graphprot/DataSet.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,9 @@ def load_one_graph(self, fname, mol):
data.internal_edge_index = internal_edge_index
data.internal_edge_attr = internal_edge_attr

# mol name
data.mol = mol

# cluster
if 'clustering' in grp.keys():
if self.clustering_method in grp['clustering'].keys():
Expand Down
190 changes: 161 additions & 29 deletions graphprot/NeuralNet.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import numpy as np
from tqdm import tqdm
from time import time
import h5py
import os

# torch import
import torch
Expand All @@ -22,10 +24,10 @@ class NeuralNet(object):

def __init__(self, database, Net,
node_feature=['type', 'polarity', 'bsa'],
edge_feature=['dist'], target='irmsd',
edge_feature=['dist'], target='irmsd', lr=0.01,
batch_size=32, percent=[0.8, 0.2], index=None, database_eval=None,
class_weights=None, task='class', classes=[0, 1], threshold=4,
pretrained_model=None, shuffle=False):
pretrained_model=None, shuffle=False, outdir='./'):

# load the input data or a pretrained model
# each named arguments is stored in a member vairable
Expand Down Expand Up @@ -98,7 +100,7 @@ def __init__(self, database, Net,

# optimizer
self.optimizer = torch.optim.Adam(
self.model.parameters(), lr=0.01)
self.model.parameters(), lr=self.lr)

# laod the optimizer state if we have one
if pretrained_model is not None:
Expand All @@ -119,7 +121,7 @@ def __init__(self, database, Net,
self.valid_acc = []
self.valid_loss = []

def plot_loss(self):
def plot_loss(self, name=''):
"""Plot the loss of the model."""

nepoch = self.nepoch
Expand All @@ -139,10 +141,10 @@ def plot_loss(self):
plt.xlabel("Number of epoch")
plt.ylabel("Total loss")
plt.legend()
plt.savefig('loss_epoch.png')
plt.savefig('loss_epoch{}.png'.format(name))
plt.close()

def plot_acc(self):
def plot_acc(self, name=''):
"""Plot the accuracy of the model."""

nepoch = self.nepoch
Expand All @@ -162,10 +164,10 @@ def plot_acc(self):
plt.xlabel("Number of epoch")
plt.ylabel("Accuracy")
plt.legend()
plt.savefig('acc_epoch.png')
plt.savefig('acc_epoch{}.png'.format(name))
plt.close()

def plot_hit_rate(self, data='eval', threshold=4, mode='percentage'):
def plot_hit_rate(self, data='eval', threshold=4, mode='percentage', name=''):
"""Plots the hitrate as a function of the models' rank

Args:
Expand All @@ -191,42 +193,55 @@ def plot_hit_rate(self, data='eval', threshold=4, mode='percentage'):
plt.xlabel("Number of models")
plt.ylabel("Hit Rate")
plt.legend()
plt.savefig('hitrate.png')
plt.savefig('hitrate{}.png'.format(name))
plt.close()

except:
print('No hit rate plot could be generated for you {} task'.format(
self.task))

def train(self, nepoch=1, validate=False, plot=False):
def train(self, nepoch=1, validate=False, plot=False, save_model='last', hdf5='train_data.hdf5', save_epoch='intermediate'):
"""Train the model

Args:
nepoch (int, optional): number of epochs. Defaults to 1.
validate (bool, optional): perform validation. Defaults to False.
plot (bool, optional): plot the results. Defaults to False.
savemodel (last, best): save the model. Defaults to 'last'
"""

# Output file
fname = os.path.join(self.outdir, hdf5)
self.f5 = h5py.File(fname, 'w')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By default this will erase the file `fname` if it already exists, which can be dangerous if you rerun an experiment but want to keep the previous results. I would add a check: if the file exists, append a number to the name, e.g. train_data_001.hdf5.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done! If train_data_001.hdf5 exists, the name is changed to train_data_002.hdf5, and so on.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

great !


# Number of epochs
self.nepoch = nepoch

# Loop over epochs
self.data = {}
for epoch in range(1, nepoch+1):

# Train the model
self.model.train()

t0 = time()
_out, _y, _loss = self._epoch(epoch)
_out, _y, _loss, self.data['train'] = self._epoch(epoch)
t = time() - t0
self.train_loss.append(_loss)
self.train_out = _out
self.train_y = _y
_acc = self.get_metrics('train', self.threshold).ACC
self.train_acc.append(_acc)

self.print_epoch_data('train', epoch, _loss, _acc, t)
# Print the loss and accuracy (training set)
self.print_epoch_data(
'train', epoch, _loss, _acc, t)

# Validate the model
if validate is True:

t0 = time()
_out, _y, _val_loss = self.eval(self.valid_loader)
_out, _y, _val_loss, self.data['eval'] = self.eval(self.valid_loader)
t = time() - t0

self.valid_loss.append(_val_loss)
Expand All @@ -236,9 +251,42 @@ def train(self, nepoch=1, validate=False, plot=False):
'eval', self.threshold).ACC
self.valid_acc.append(_val_acc)

# Print loss and accuracy (validation set)
self.print_epoch_data(
'valid', epoch, _val_loss, _val_acc, t)

# save the best model (i.e. lowest loss value on validation data)
if save_model == 'best' :

if min(self.valid_loss) == _val_loss :
self.save_model(filename='t{}_y{}_b{}_e{}_lr{}.pth.tar'.format(
self.task, self.target, str(self.batch_size), str(nepoch), str(self.lr)))

else :
# if no validation set, saves the best performing model on the traing set
if save_model == 'best' :
if min(self.train_loss) == _train_loss :
print ('WARNING: The training set is used both for learning and model selection.')
print('this may lead to training set data overfitting.')
print('We advice you to use an external validation set.')
self.save_model(filename='t{}_y{}_b{}_e{}_lr{}.pth.tar'.format(
self.task, self.target, str(self.batch_size), str(nepoch), str(self.lr)))

# Save epoch data
if (save_epoch == 'all') or (epoch == nepoch) :
self._export_epoch_hdf5(epoch, self.data)

elif (save_epoch == 'intermediate') and (epoch%5 == 0) :
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you could add save_every=n as named argument of the function so that we can decide how often we save

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, defaults is set to 5

self._export_epoch_hdf5(epoch, self.data)

# Save the last model
if save_model == 'last' :
self.save_model(filename='t{}_y{}_b{}_e{}_lr{}.pth.tar'.format(self.task, self.target, str(self.batch_size), str(nepoch), str(self.lr)))

# Close output file
self.f5.close()


@staticmethod
def print_epoch_data(stage, epoch, loss, acc, time):
"""print the data of each epoch
Expand Down Expand Up @@ -272,14 +320,19 @@ def format_output(self, out, target):

return out, target

def test(self, database_test, threshold):
def test(self, database_test, threshold=4, hdf5='test_data.hdf5'):
"""Test the model

Args:
database_test ([type]): [description]
threshold ([type]): [description]
"""

# Output file
fname = os.path.join(self.outdir, hdf5)
self.f5 = h5py.File(fname, 'w')

# Load the test set
test_dataset = HDF5DataSet(root='./', database=database_test,
node_feature=self.node_feature, edge_feature=self.edge_feature,
target=self.target)
Expand All @@ -289,14 +342,22 @@ def test(self, database_test, threshold):
self.test_loader = DataLoader(
test_dataset)

_out, _y, _test_loss = self.eval(self.test_loader)
self.data = {}

# run test
_out, _y, _test_loss, self.data['test'] = self.eval(self.test_loader)

self.test_out = _out
self.test_y = _y
_test_acc = self.get_metrics('test', threshold).ACC
self.test_acc = _test_acc
self.test_loss = _test_loss

if save_prediction :
self._export_epoch_hdf5(0, self.data)

self.f5.close()

def eval(self, loader):
"""Evaluate the model

Expand All @@ -312,16 +373,28 @@ def eval(self, loader):
loss_func, loss_val = self.loss, 0
out = []
y = []
for data in loader:
data = data.to(self.device)
pred = self.model(data)
pred, data.y = self.format_output(pred, data.y)
data = {'outputs': [], 'targets': [], 'mol': []}

y += data.y
loss_val += loss_func(pred, data.y).detach().item()
for d in loader:
d = d.to(self.device)
pred = self.model(d)
pred, d.y = self.format_output(pred, d.y)

y += d.y
loss_val += loss_func(pred, d.y).detach().item()
out += pred.reshape(-1).tolist()

return out, y, loss_val
# get the outputs for export
data['outputs'] += pred.reshape(-1).tolist()
data['targets'] += d.y.numpy().tolist()

# get the data
mol = d['mol']
fname, molname = mol[0], mol[1]
data['mol'] += [(f, m) for f, m in zip(fname, molname)]
data['mol'] = np.array(data['mol'], dtype=object)

return out, y, loss_val, data

def _epoch(self, epoch):
"""Run a single epoch
Expand All @@ -333,21 +406,33 @@ def _epoch(self, epoch):
running_loss = 0
out = []
y = []
for data in self.train_loader:
data = {'outputs': [], 'targets': [], 'mol': []}

data = data.to(self.device)
for d in self.train_loader:

d = d.to(self.device)
self.optimizer.zero_grad()
pred = self.model(data)
pred, data.y = self.format_output(pred, data.y)
pred = self.model(d)
pred, d.y = self.format_output(pred, d.y)

y += data.y
loss = self.loss(pred, data.y)
y += d.y
loss = self.loss(pred, d.y)
running_loss += loss.detach().item()
loss.backward()
out += pred.reshape(-1).tolist()
self.optimizer.step()

return out, y, running_loss
# get the outputs for export
data['outputs'] += pred.reshape(-1).tolist()
data['targets'] += d.y.numpy().tolist()

# get the data
mol = d['mol']
fname, molname = mol[0], mol[1]
data['mol'] += [(f, m) for f, m in zip(fname, molname)]
data['mol'] = np.array(data['mol'], dtype=object)

return out, y, running_loss, data

def get_metrics(self, data='eval', threshold=4, binary=True):
"""Compute the metrics needed
Expand Down Expand Up @@ -384,6 +469,7 @@ def get_metrics(self, data='eval', threshold=4, binary=True):

return Metrics(pred, y, self.target, threshold, binary)


def plot_scatter(self):
"""Scatter plot of the results"""
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -424,6 +510,7 @@ def save_model(self, filename='model.pth.tar'):
'class_weight': self.class_weights,
'batch_size': self.batch_size,
'percent': self.percent,
'lr': self.lr,
'index': self.index,
'shuffle': self.shuffle,
'threshold': self.threshold}
Expand All @@ -447,6 +534,7 @@ def load_params(self, filename):
self.target = state['target']
self.batch_size = state['batch_size']
self.percent = state['percent']
self.lr = state['lr']
self.index = state['index']
self.class_weights = state['class_weight']
self.task = state['task']
Expand All @@ -455,3 +543,47 @@ def load_params(self, filename):
self.shuffle = state['shuffle']

self.opt_loaded_state_dict = state['optimizer']

def _export_epoch_hdf5(self, epoch, data):
    """Export one epoch's data to the already-open HDF5 file (``self.f5``).

    Creates a group ``epoch_XXXX`` and, inside it, one subgroup per pass
    type (train/valid/test). Each subgroup stores the predicted values
    (``outputs``), the ground truth (``targets``) and the molecule names
    (``mol``).

    Args:
        epoch (int): index of the epoch, used to name the group.
        data (dict): mapping of pass type -> dict with keys
            'outputs', 'targets' and 'mol' (as built by _epoch/eval).
    """

    # create a group named after the epoch, e.g. 'epoch_0003'
    grp_name = 'epoch_%04d' % epoch
    grp = self.f5.create_group(grp_name)

    # record the run settings as group attributes
    grp.attrs['task'] = self.task
    grp.attrs['target'] = self.target
    grp.attrs['batch_size'] = self.batch_size

    # loop over the pass_type : train/valid/test
    for pass_type, pass_data in data.items():

        # we don't want to break the export process in case of issue
        try:

            # create subgroup for the pass
            sg = grp.create_group(pass_type)

            # loop over the data : target/output/molname
            for data_name, data_value in pass_data.items():

                # mol name is a bit different
                # since there are strings: use a variable-length
                # string dtype so h5py can store them
                if data_name == 'mol':
                    string_dt = h5py.special_dtype(vlen=str)
                    sg.create_dataset(
                        data_name, data=data_value, dtype=string_dt)

                # output/target values (numeric arrays)
                else:
                    sg.create_dataset(data_name, data=data_value)

        except TypeError:
            # NOTE(review): `logger` is not imported in any hunk visible
            # in this diff — if no module-level logger exists, this line
            # raises NameError and defeats the "don't break" intent.
            # Confirm a `logging.getLogger(__name__)` is defined.
            logger.exception("Error in export epoch to hdf5")
Comment on lines +564 to +606
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice :) !

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I took that part from Deeprank;)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it looked familiar :)