-
Notifications
You must be signed in to change notification settings - Fork 32
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,8 @@ | ||
import numpy as np | ||
from tqdm import tqdm | ||
from time import time | ||
import h5py | ||
import os | ||
|
||
# torch import | ||
import torch | ||
|
@@ -22,10 +24,10 @@ class NeuralNet(object): | |
|
||
def __init__(self, database, Net, | ||
node_feature=['type', 'polarity', 'bsa'], | ||
edge_feature=['dist'], target='irmsd', | ||
edge_feature=['dist'], target='irmsd', lr=0.01, | ||
batch_size=32, percent=[0.8, 0.2], index=None, database_eval=None, | ||
class_weights=None, task='class', classes=[0, 1], threshold=4, | ||
pretrained_model=None, shuffle=False): | ||
pretrained_model=None, shuffle=False, outdir='./'): | ||
|
||
# load the input data or a pretrained model | ||
# each named arguments is stored in a member vairable | ||
|
@@ -98,7 +100,7 @@ def __init__(self, database, Net, | |
|
||
# optimizer | ||
self.optimizer = torch.optim.Adam( | ||
self.model.parameters(), lr=0.01) | ||
self.model.parameters(), lr=self.lr) | ||
|
||
# laod the optimizer state if we have one | ||
if pretrained_model is not None: | ||
|
@@ -119,7 +121,7 @@ def __init__(self, database, Net, | |
self.valid_acc = [] | ||
self.valid_loss = [] | ||
|
||
def plot_loss(self): | ||
def plot_loss(self, name=''): | ||
"""Plot the loss of the model.""" | ||
|
||
nepoch = self.nepoch | ||
|
@@ -139,10 +141,10 @@ def plot_loss(self): | |
plt.xlabel("Number of epoch") | ||
plt.ylabel("Total loss") | ||
plt.legend() | ||
plt.savefig('loss_epoch.png') | ||
plt.savefig('loss_epoch{}.png'.format(name)) | ||
plt.close() | ||
|
||
def plot_acc(self): | ||
def plot_acc(self, name=''): | ||
"""Plot the accuracy of the model.""" | ||
|
||
nepoch = self.nepoch | ||
|
@@ -162,10 +164,10 @@ def plot_acc(self): | |
plt.xlabel("Number of epoch") | ||
plt.ylabel("Accuracy") | ||
plt.legend() | ||
plt.savefig('acc_epoch.png') | ||
plt.savefig('acc_epoch{}.png'.format(name)) | ||
plt.close() | ||
|
||
def plot_hit_rate(self, data='eval', threshold=4, mode='percentage'): | ||
def plot_hit_rate(self, data='eval', threshold=4, mode='percentage', name=''): | ||
"""Plots the hitrate as a function of the models' rank | ||
|
||
Args: | ||
|
@@ -191,42 +193,55 @@ def plot_hit_rate(self, data='eval', threshold=4, mode='percentage'): | |
plt.xlabel("Number of models") | ||
plt.ylabel("Hit Rate") | ||
plt.legend() | ||
plt.savefig('hitrate.png') | ||
plt.savefig('hitrate{}.png'.format(name)) | ||
plt.close() | ||
|
||
except: | ||
print('No hit rate plot could be generated for you {} task'.format( | ||
self.task)) | ||
|
||
def train(self, nepoch=1, validate=False, plot=False): | ||
def train(self, nepoch=1, validate=False, plot=False, save_model='last', hdf5='train_data.hdf5', save_epoch='intermediate'): | ||
"""Train the model | ||
|
||
Args: | ||
nepoch (int, optional): number of epochs. Defaults to 1. | ||
validate (bool, optional): perform validation. Defaults to False. | ||
plot (bool, optional): plot the results. Defaults to False. | ||
savemodel (last, best): save the model. Defaults to 'last' | ||
""" | ||
|
||
# Output file | ||
fname = os.path.join(self.outdir, hdf5) | ||
self.f5 = h5py.File(fname, 'w') | ||
|
||
# Number of epochs | ||
self.nepoch = nepoch | ||
|
||
# Loop over epochs | ||
self.data = {} | ||
for epoch in range(1, nepoch+1): | ||
|
||
# Train the model | ||
self.model.train() | ||
|
||
t0 = time() | ||
_out, _y, _loss = self._epoch(epoch) | ||
_out, _y, _loss, self.data['train'] = self._epoch(epoch) | ||
t = time() - t0 | ||
self.train_loss.append(_loss) | ||
self.train_out = _out | ||
self.train_y = _y | ||
_acc = self.get_metrics('train', self.threshold).ACC | ||
self.train_acc.append(_acc) | ||
|
||
self.print_epoch_data('train', epoch, _loss, _acc, t) | ||
# Print the loss and accuracy (training set) | ||
self.print_epoch_data( | ||
'train', epoch, _loss, _acc, t) | ||
|
||
# Validate the model | ||
if validate is True: | ||
|
||
t0 = time() | ||
_out, _y, _val_loss = self.eval(self.valid_loader) | ||
_out, _y, _val_loss, self.data['eval'] = self.eval(self.valid_loader) | ||
t = time() - t0 | ||
|
||
self.valid_loss.append(_val_loss) | ||
|
@@ -236,9 +251,42 @@ def train(self, nepoch=1, validate=False, plot=False): | |
'eval', self.threshold).ACC | ||
self.valid_acc.append(_val_acc) | ||
|
||
# Print loss and accuracy (validation set) | ||
self.print_epoch_data( | ||
'valid', epoch, _val_loss, _val_acc, t) | ||
|
||
# save the best model (i.e. lowest loss value on validation data) | ||
if save_model == 'best' : | ||
|
||
if min(self.valid_loss) == _val_loss : | ||
self.save_model(filename='t{}_y{}_b{}_e{}_lr{}.pth.tar'.format( | ||
self.task, self.target, str(self.batch_size), str(nepoch), str(self.lr))) | ||
|
||
else : | ||
# if no validation set, saves the best performing model on the traing set | ||
if save_model == 'best' : | ||
if min(self.train_loss) == _train_loss : | ||
print ('WARNING: The training set is used both for learning and model selection.') | ||
print('this may lead to training set data overfitting.') | ||
print('We advice you to use an external validation set.') | ||
self.save_model(filename='t{}_y{}_b{}_e{}_lr{}.pth.tar'.format( | ||
self.task, self.target, str(self.batch_size), str(nepoch), str(self.lr))) | ||
|
||
# Save epoch data | ||
if (save_epoch == 'all') or (epoch == nepoch) : | ||
self._export_epoch_hdf5(epoch, self.data) | ||
|
||
elif (save_epoch == 'intermediate') and (epoch%5 == 0) : | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you could add There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done, defaults is set to 5 |
||
self._export_epoch_hdf5(epoch, self.data) | ||
|
||
# Save the last model | ||
if save_model == 'last' : | ||
self.save_model(filename='t{}_y{}_b{}_e{}_lr{}.pth.tar'.format(self.task, self.target, str(self.batch_size), str(nepoch), str(self.lr))) | ||
|
||
# Close output file | ||
self.f5.close() | ||
|
||
|
||
@staticmethod | ||
def print_epoch_data(stage, epoch, loss, acc, time): | ||
"""print the data of each epoch | ||
|
@@ -272,14 +320,19 @@ def format_output(self, out, target): | |
|
||
return out, target | ||
|
||
def test(self, database_test, threshold): | ||
def test(self, database_test, threshold=4, hdf5='test_data.hdf5'): | ||
"""Test the model | ||
|
||
Args: | ||
database_test ([type]): [description] | ||
threshold ([type]): [description] | ||
""" | ||
|
||
# Output file | ||
fname = os.path.join(self.outdir, hdf5) | ||
self.f5 = h5py.File(fname, 'w') | ||
|
||
# Load the test set | ||
test_dataset = HDF5DataSet(root='./', database=database_test, | ||
node_feature=self.node_feature, edge_feature=self.edge_feature, | ||
target=self.target) | ||
|
@@ -289,14 +342,22 @@ def test(self, database_test, threshold): | |
self.test_loader = DataLoader( | ||
test_dataset) | ||
|
||
_out, _y, _test_loss = self.eval(self.test_loader) | ||
self.data = {} | ||
|
||
# run test | ||
_out, _y, _test_loss, self.data['test'] = self.eval(self.test_loader) | ||
|
||
self.test_out = _out | ||
self.test_y = _y | ||
_test_acc = self.get_metrics('test', threshold).ACC | ||
self.test_acc = _test_acc | ||
self.test_loss = _test_loss | ||
|
||
if save_prediction : | ||
self._export_epoch_hdf5(0, self.data) | ||
|
||
self.f5.close() | ||
|
||
def eval(self, loader): | ||
"""Evaluate the model | ||
|
||
|
@@ -312,16 +373,28 @@ def eval(self, loader): | |
loss_func, loss_val = self.loss, 0 | ||
out = [] | ||
y = [] | ||
for data in loader: | ||
data = data.to(self.device) | ||
pred = self.model(data) | ||
pred, data.y = self.format_output(pred, data.y) | ||
data = {'outputs': [], 'targets': [], 'mol': []} | ||
|
||
y += data.y | ||
loss_val += loss_func(pred, data.y).detach().item() | ||
for d in loader: | ||
d = d.to(self.device) | ||
pred = self.model(d) | ||
pred, d.y = self.format_output(pred, d.y) | ||
|
||
y += d.y | ||
loss_val += loss_func(pred, d.y).detach().item() | ||
out += pred.reshape(-1).tolist() | ||
|
||
return out, y, loss_val | ||
# get the outputs for export | ||
data['outputs'] += pred.reshape(-1).tolist() | ||
data['targets'] += d.y.numpy().tolist() | ||
|
||
# get the data | ||
mol = d['mol'] | ||
fname, molname = mol[0], mol[1] | ||
data['mol'] += [(f, m) for f, m in zip(fname, molname)] | ||
data['mol'] = np.array(data['mol'], dtype=object) | ||
|
||
return out, y, loss_val, data | ||
|
||
def _epoch(self, epoch): | ||
"""Run a single epoch | ||
|
@@ -333,21 +406,33 @@ def _epoch(self, epoch): | |
running_loss = 0 | ||
out = [] | ||
y = [] | ||
for data in self.train_loader: | ||
data = {'outputs': [], 'targets': [], 'mol': []} | ||
|
||
data = data.to(self.device) | ||
for d in self.train_loader: | ||
|
||
d = d.to(self.device) | ||
self.optimizer.zero_grad() | ||
pred = self.model(data) | ||
pred, data.y = self.format_output(pred, data.y) | ||
pred = self.model(d) | ||
pred, d.y = self.format_output(pred, d.y) | ||
|
||
y += data.y | ||
loss = self.loss(pred, data.y) | ||
y += d.y | ||
loss = self.loss(pred, d.y) | ||
running_loss += loss.detach().item() | ||
loss.backward() | ||
out += pred.reshape(-1).tolist() | ||
self.optimizer.step() | ||
|
||
return out, y, running_loss | ||
# get the outputs for export | ||
data['outputs'] += pred.reshape(-1).tolist() | ||
data['targets'] += d.y.numpy().tolist() | ||
|
||
# get the data | ||
mol = d['mol'] | ||
fname, molname = mol[0], mol[1] | ||
data['mol'] += [(f, m) for f, m in zip(fname, molname)] | ||
data['mol'] = np.array(data['mol'], dtype=object) | ||
|
||
return out, y, running_loss, data | ||
|
||
def get_metrics(self, data='eval', threshold=4, binary=True): | ||
"""Compute the metrics needed | ||
|
@@ -384,6 +469,7 @@ def get_metrics(self, data='eval', threshold=4, binary=True): | |
|
||
return Metrics(pred, y, self.target, threshold, binary) | ||
|
||
|
||
def plot_scatter(self): | ||
"""Scatter plot of the results""" | ||
import matplotlib.pyplot as plt | ||
|
@@ -424,6 +510,7 @@ def save_model(self, filename='model.pth.tar'): | |
'class_weight': self.class_weights, | ||
'batch_size': self.batch_size, | ||
'percent': self.percent, | ||
'lr': self.lr, | ||
'index': self.index, | ||
'shuffle': self.shuffle, | ||
'threshold': self.threshold} | ||
|
@@ -447,6 +534,7 @@ def load_params(self, filename): | |
self.target = state['target'] | ||
self.batch_size = state['batch_size'] | ||
self.percent = state['percent'] | ||
self.lr = state['lr'] | ||
self.index = state['index'] | ||
self.class_weights = state['class_weight'] | ||
self.task = state['task'] | ||
|
@@ -455,3 +543,47 @@ def load_params(self, filename): | |
self.shuffle = state['shuffle'] | ||
|
||
self.opt_loaded_state_dict = state['optimizer'] | ||
|
||
def _export_epoch_hdf5(self, epoch, data): | ||
"""Export the epoch data to the hdf5 file. | ||
Export the data of a given epoch in train/valid/test group. | ||
In each group are stored the predcited values (outputs), | ||
ground truth (targets) and molecule name (mol). | ||
Args: | ||
epoch (int): index of the epoch | ||
data (dict): data of the epoch | ||
""" | ||
|
||
# create a group | ||
grp_name = 'epoch_%04d' % epoch | ||
grp = self.f5.create_group(grp_name) | ||
|
||
grp.attrs['task'] = self.task | ||
grp.attrs['target'] = self.target | ||
grp.attrs['batch_size'] = self.batch_size | ||
|
||
# loop over the pass_type : train/valid/test | ||
for pass_type, pass_data in data.items(): | ||
|
||
# we don't want to breack the process in case of issue | ||
try: | ||
|
||
# create subgroup for the pass | ||
sg = grp.create_group(pass_type) | ||
|
||
# loop over the data : target/output/molname | ||
for data_name, data_value in pass_data.items(): | ||
|
||
# mol name is a bit different | ||
# since there are strings | ||
if data_name == 'mol': | ||
string_dt = h5py.special_dtype(vlen=str) | ||
sg.create_dataset( | ||
data_name, data=data_value, dtype=string_dt) | ||
|
||
# output/target values | ||
else: | ||
sg.create_dataset(data_name, data=data_value) | ||
|
||
except TypeError: | ||
logger.exception("Error in export epoch to hdf5") | ||
Comment on lines
+564
to
+606
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Very nice :) ! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I took that part from Deeprank;) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it looked familiar :) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
by default this will erase the file fname if it exists which can be dangerous if you rerun an experiment but want to keep the previous results. I would add a check to see if the file exists and if it does change the name with a number. train_data_001.hdf5
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done ! If train_data_001.hdf5 exists, then I change it to train_data_002.hdf5 and so on.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
great !