-
Notifications
You must be signed in to change notification settings - Fork 32
add export options #18
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great addition ! One small remark about opening the h5 file with 'w'. It automatically erase any file with that name which can be dangerous. To avoid future issue where we accidentally erase data we want to keep I would change that. Otherwise very good !!
""" | ||
|
||
# Output file | ||
fname = os.path.join(self.outdir, hdf5) | ||
self.f5 = h5py.File(fname, 'w') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
by default this will erase the file fname if it exists which can be dangerous if you rerun an experiment but want to keep the previous results. I would add a check to see if the file exists and if it does change the name with a number. train_data_001.hdf5
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done ! If train_data_001.hdf5 exists, then I change it to train_data_002.hdf5 and so on.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
great !
graphprot/NeuralNet.py
Outdated
if (save_epoch == 'all') or (epoch == nepoch) : | ||
self._export_epoch_hdf5(epoch, self.data) | ||
|
||
elif (save_epoch == 'intermediate') and (epoch%5 == 0) : |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you could add save_every=n
as named argument of the function so that we can decide how often we save
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done, defaults is set to 5
def _export_epoch_hdf5(self, epoch, data): | ||
"""Export the epoch data to the hdf5 file. | ||
Export the data of a given epoch in train/valid/test group. | ||
In each group are stored the predcited values (outputs), | ||
ground truth (targets) and molecule name (mol). | ||
Args: | ||
epoch (int): index of the epoch | ||
data (dict): data of the epoch | ||
""" | ||
|
||
# create a group | ||
grp_name = 'epoch_%04d' % epoch | ||
grp = self.f5.create_group(grp_name) | ||
|
||
grp.attrs['task'] = self.task | ||
grp.attrs['target'] = self.target | ||
grp.attrs['batch_size'] = self.batch_size | ||
|
||
# loop over the pass_type : train/valid/test | ||
for pass_type, pass_data in data.items(): | ||
|
||
# we don't want to breack the process in case of issue | ||
try: | ||
|
||
# create subgroup for the pass | ||
sg = grp.create_group(pass_type) | ||
|
||
# loop over the data : target/output/molname | ||
for data_name, data_value in pass_data.items(): | ||
|
||
# mol name is a bit different | ||
# since there are strings | ||
if data_name == 'mol': | ||
string_dt = h5py.special_dtype(vlen=str) | ||
sg.create_dataset( | ||
data_name, data=data_value, dtype=string_dt) | ||
|
||
# output/target values | ||
else: | ||
sg.create_dataset(data_name, data=data_value) | ||
|
||
except TypeError: | ||
logger.exception("Error in export epoch to hdf5") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very nice :) !
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I took that part from Deeprank;)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it looked familiar :)
""" | ||
|
||
# Output file | ||
fname = os.path.join(self.outdir, hdf5) | ||
self.f5 = h5py.File(fname, 'w') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
great !
def _export_epoch_hdf5(self, epoch, data): | ||
"""Export the epoch data to the hdf5 file. | ||
Export the data of a given epoch in train/valid/test group. | ||
In each group are stored the predcited values (outputs), | ||
ground truth (targets) and molecule name (mol). | ||
Args: | ||
epoch (int): index of the epoch | ||
data (dict): data of the epoch | ||
""" | ||
|
||
# create a group | ||
grp_name = 'epoch_%04d' % epoch | ||
grp = self.f5.create_group(grp_name) | ||
|
||
grp.attrs['task'] = self.task | ||
grp.attrs['target'] = self.target | ||
grp.attrs['batch_size'] = self.batch_size | ||
|
||
# loop over the pass_type : train/valid/test | ||
for pass_type, pass_data in data.items(): | ||
|
||
# we don't want to breack the process in case of issue | ||
try: | ||
|
||
# create subgroup for the pass | ||
sg = grp.create_group(pass_type) | ||
|
||
# loop over the data : target/output/molname | ||
for data_name, data_value in pass_data.items(): | ||
|
||
# mol name is a bit different | ||
# since there are strings | ||
if data_name == 'mol': | ||
string_dt = h5py.special_dtype(vlen=str) | ||
sg.create_dataset( | ||
data_name, data=data_value, dtype=string_dt) | ||
|
||
# output/target values | ||
else: | ||
sg.create_dataset(data_name, data=data_value) | ||
|
||
except TypeError: | ||
logger.exception("Error in export epoch to hdf5") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it looked familiar :)
Add different export options:
Export training, evaluation and test data in HDF5 (mol name, prediction, y)
Chose data export frequency (all or intermediate (every 5 epochs for now))
Export the best or the last model