-
Notifications
You must be signed in to change notification settings - Fork 63
Isbi em stacks crossvalidation argument #8
base: master
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for your contribution!!
Can you please make sure your code respects the PEP8 coding style rules?
The easiest way to do so is to install pep8 with pip install pep8
(with --local
if you don't want to install it globally) and run pep8 sbi_em_stacks.py.py
to check the code. Once you are done please commit the changes and I'll review the PR.
Thanks!
I already applied the coding style modifications. |
c6a8d70
to
ff0bbfe
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for the PR and sorry for keeping you waiting for so long!
Here is my review, there are a few modifications that I think will make the code easier to understand. Once you implement them we can merge it!
Thanks!
whereas 15\% will be used for validation. | ||
For example, if split=0.85, 85\% of the images will be used | ||
for training, whereas 15\% will be used for validation. | ||
crossval: int |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
int or None
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also please rename as crossval_nfolds
for training, whereas 15\% will be used for validation. | ||
crossval: int | ||
If it is set to None, to cross-validation is used. An int specifying | ||
in how many folds we want to split our data. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When None cross-validation is disabled. Else, represents the number of folds we data will be split into.
@@ -52,20 +58,52 @@ class IsbiEmStacksDataset(ThreadedDataset): | |||
1: (255, 255, 255)} # Membranes | |||
_mask_labels = {0: 'Non-membranes', 1: 'Membranes'} | |||
|
|||
def __init__(self, which_set='train', split=0.85, *args, **kwargs): | |||
def __init__(self, which_set='train', split=0.60, crossval=5, fold=3, rand_perm=None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The description of rand_perm
is missing.
in how many folds we want to split our data. | ||
fold: int | ||
An int specifying which fold we want. If fold=1, images from 0 to 5 | ||
will be used as validation. If fold=2, images from 6 to 11, and so on. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fold
should be zero-based. Please change it to:
An int specifying which fold to use for validation. If fold=0, images from 0 to 5 will be used for validation. If fold=1, images from 6 to 11, and so on.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Rename to valid_fold
elif self.which_set == "test": | ||
self.start = 0 | ||
self.end = 30 | ||
self.middle_fold = False # False by default |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please remove the comment, the code is already clear.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please remove self.middle_fold altogether (see comment below)
img_per_fold = int(30/crossval) | ||
# start and end index for validation fold | ||
start = (fold-1)*img_per_fold | ||
end = fold*img_per_fold |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
make it (fold + 1) * img_per_fold
Edit: disregard, see comment below.
For example, if split=0.85, 85\% of the images will be used for training, | ||
whereas 15\% will be used for validation. | ||
For example, if split=0.85, 85\% of the images will be used | ||
for training, whereas 15\% will be used for validation. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add "Will be ignored if crossval_nfolds is not None"
elif self.which_set == "val": | ||
self.start = start | ||
self.end = end | ||
elif self.which_set == "test": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think 'test' makes sense in the cross-validation case. I suggest to raise a ValueError if which_set is 'test' here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
More generally, I suggest to replace L75-93 with the following approach:
if self.which_set == "train":
self.start_1 = 0
self.end_1 = fold * img_per_fold
self.start_2 = (fold+1) * img_per_fold
self.end_2 = 30
elif self.which_set == "val":
self.start_1 = fold * img_per_fold
self.end_1 = self.start_2 = self.end_2 = (fold+1) * img_per_fold
elif self.which_set == "test":
raise ValueError('Cannot perform cross-validation on test.')
and then replace self.get_names with
return {'default': self.rand_indexes[self.start_1:self.end_1] +
self.rand_indexes[self.start_2:self.end_2]}
if rand_perm is not None: | ||
self.rand_indexes=rand_perm | ||
else: | ||
self.rand_indexes=range(0,30) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
self.rand_indexes = range(30)
# if validation is a middle fold, concatenate separated train folds | ||
return {'default': self.rand_indexes[range(0, self.start)+range(self.end, 30)].tolist()} | ||
else: | ||
return {'default': self.rand_indexes[range(self.start, self.end)].tolist()} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Replace with the suggested command (see comment above)
Hello @ArantxaCasanova! Thanks for updating the PR. Cheers ! There are no PEP8 issues in this Pull Request. 🍻 Comment last updated on July 05, 2017 at 16:58 Hours UTC |
Thanks for the revision! I corrected what you suggested. I still had to use "self.crossval = True" to control the "get_names()" function without making it too messy. Let me know what you think! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for fixing the code and please excuse me for the embarrassingly long time it took me to review your new commit.
There are only two minor things I'd like you to fix, and then we can merge! Thank you very much for your contribution :)
return {'default': ( | ||
self.rand_indices[self.start_1:self.end_1] | ||
).tolist() + ( | ||
self.rand_indices[self.start_2:self.end_2]).tolist()} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The indentation makes it difficult to read the code.
This should respect the PEP8 while maintaining the readability of the code:
return {'default':
self.rand_indices[self.start_1:self.end_1]).tolist() + (
self.rand_indices[self.start_2:self.end_2]).tolist()}
self.middle_fold = False # False by default | ||
if crossval is not None: # if cross-validation is used | ||
if crossval_nfolds is not None: | ||
self.crossval = True |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would prefer not to add this attribute to the class since the dataset objects are already a bit cluttered :)
I suggest to remove it and check for hasattr(self, 'rand_indices')
in get_names()
if self.middle_fold: | ||
# if validation is a middle fold, concatenate separated train folds | ||
return {'default': self.rand_indexes[range(0, self.start)+range(self.end, 30)].tolist()} | ||
if self.crossval: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See comment above
To be able to separate de data in folds and select a specific one for validation.