-
Notifications
You must be signed in to change notification settings - Fork 63
mnist and mnistm dataloaders #20
base: master
Are you sure you want to change the base?
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,195 @@ | ||
"""Mnist dataset.""" | ||
import numpy as np | ||
import os | ||
import time | ||
import pickle as pkl | ||
import errno | ||
from PIL import Image | ||
from dataset_loaders.parallel_loader import ThreadedDataset | ||
from timeit import default_timer as timer | ||
|
||
|
||
class MnistDataset(ThreadedDataset): | ||
"""The mnist handwritten digit dataset | ||
|
||
The dataset should be downloaded from [1] into the `shared_path` | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please replace |
||
(that should be specified in the config.ini according to the | ||
instructions in ../README.md). | ||
|
||
Parameters | ||
---------- | ||
which_set: string | ||
A string in ['train', 'val', 'valid', 'test'], corresponding to | ||
the set to be returned. | ||
|
||
References | ||
---------- | ||
[1] Mnist dataset pickle file: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please replace |
||
your_username_here@elisa1.iro.umontreal.ca:/data/lisa/data/mnist/mnist_seg/ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please don't make this MILA specific. Replace with the MNIST website or the download link. |
||
|
||
""" | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove the empty line please |
||
name = 'mnist' | ||
|
||
# optional arguments | ||
data_shape = (28, 28, 3) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I believe MNIST data_shape should be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah yes sorry. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No problem :) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. :) |
||
mean = [0, 0, 0] | ||
std = [1, 1, 1] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No need to specify mean and std unless you actually compute them. Please either compute them or remove it. |
||
max_files = 50000 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please remove this. It's never used and should be equivalent to |
||
|
||
GTclasses = range(2) | ||
mapping_type = 'mnist' | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is never used. Please remove. |
||
|
||
non_void_nclasses = 2 | ||
_void_labels = [] | ||
GTclasses = range(2) | ||
# GTclasses = GTclasses + [-1] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please remove the commented line |
||
|
||
_mask_labels = { | ||
0: 'background', | ||
1: 'digit', | ||
} | ||
|
||
_cmap = { | ||
0: (0), # background | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. cmap should return RGB triples. Please fix this. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok. To confirm: (0, 0, 0)? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Exactly. |
||
1: (255), # digit | ||
} | ||
|
||
_filenames = None | ||
|
||
def __init__(self, which_set='train', *args, **kwargs): | ||
"""Construct the ThreadedDataset. | ||
|
||
it also creates/copies the dataset in self.path if not already there | ||
mnist data is in 3 directories train, test, valid) | ||
""" | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove the empty line please |
||
self.which_set = 'val' if which_set == 'valid' else which_set | ||
|
||
# set file paths | ||
if which_set == 'train': | ||
self.image_path = os.path.join(self.path, 'train_images') | ||
self.mask_path = os.path.join(self.path, 'train_masks') | ||
|
||
elif which_set == 'test': | ||
self.image_path = os.path.join(self.path, 'test_images') | ||
self.mask_path = os.path.join(self.path, 'test_masks') | ||
|
||
else: | ||
self.image_path = os.path.join(self.path, 'val_images') | ||
self.mask_path = os.path.join(self.path, 'val_masks') | ||
|
||
super(MnistDataset, self).__init__(*args, **kwargs) | ||
|
||
@property | ||
def filenames(self): | ||
"""Get file names for this set.""" | ||
if self._filenames is None: | ||
filenames = [] | ||
|
||
for i in range(len(os.listdir(self.image_path))): | ||
filenames.append(str(i).zfill(5) + '.png') | ||
|
||
self._filenames = filenames | ||
|
||
print('MnistDataset: ' + self.which_set + | ||
' ' + str(len(filenames)) + ' files') | ||
|
||
return self._filenames | ||
|
||
def get_names(self): | ||
"""Return a dict of mnist filenames.""" | ||
|
||
return {'default': self.filenames} | ||
|
||
def load_sequence(self, sequence): | ||
"""Load a sequence of images/frames. | ||
|
||
Auxiliary function that loads a sequence of mnist images with | ||
the corresponding ground truth mask and the filenames. | ||
Returns a dict with the images in [0, 1], and their filenames. | ||
""" | ||
X = [] | ||
Y = [] | ||
F = [] | ||
|
||
for prefix, image in sequence: | ||
|
||
# open mnist image, convert to numpy array | ||
curr_mnist_im = Image.open(os.path.join(self.image_path, image)) | ||
curr_mnist_im = np.array(curr_mnist_im).astype('float32') | ||
|
||
# append image to X | ||
X.append(curr_mnist_im) | ||
|
||
# append image fname to F | ||
F.append(image) | ||
|
||
# open mnist mask, convert to numpy array | ||
curr_mnist_mask = Image.open(os.path.join(self.mask_path, image)) | ||
curr_mnist_mask = np.array(curr_mnist_mask).astype('int32') | ||
curr_mnist_mask = curr_mnist_mask / 255 | ||
|
||
# append mask to Y | ||
Y.append(curr_mnist_mask) | ||
|
||
ret = {} | ||
ret['data'] = np.array(X) | ||
ret['labels'] = np.array(Y) | ||
ret['subset'] = prefix | ||
ret['filenames'] = np.array(F) | ||
return ret | ||
|
||
|
||
def test(): | ||
"""Test.""" | ||
trainiter = MnistDataset( | ||
which_set='train', | ||
batch_size=10, | ||
seq_per_subset=0, | ||
seq_length=0, | ||
data_augm_kwargs={'crop_size': (28, 28)}, | ||
return_one_hot=True, | ||
return_01c=True, | ||
return_list=True, | ||
use_threads=True) | ||
|
||
validiter = MnistDataset( | ||
which_set='valid', | ||
batch_size=5, | ||
seq_per_subset=0, | ||
seq_length=0, | ||
data_augm_kwargs={'crop_size': (28, 28)}, | ||
return_one_hot=True, | ||
return_01c=True, | ||
return_list=True, | ||
use_threads=False) | ||
|
||
train_nsamples = trainiter.nsamples | ||
nbatches = trainiter.nbatches | ||
print("Train %d" % (train_nsamples)) | ||
|
||
valid_nsamples = validiter.nsamples | ||
print("Valid %d" % (valid_nsamples)) | ||
|
||
# Simulate training | ||
max_epochs = 2 | ||
start_training = time.time() | ||
for epoch in range(max_epochs): | ||
start_epoch = time.time() | ||
for mb in range(nbatches): | ||
start_batch = time.time() | ||
trainiter.next() | ||
print("Minibatch {}: {} seg".format(mb, (time.time() - | ||
start_batch))) | ||
print("Epoch time: %s" % str(time.time() - start_epoch)) | ||
print("Training time: %s" % str(time.time() - start_training)) | ||
|
||
|
||
def run_tests(): | ||
"""Run tests.""" | ||
test() | ||
|
||
|
||
if __name__ == '__main__': | ||
run_tests() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,193 @@ | ||
"""Mnist-m dataset.""" | ||
import numpy as np | ||
import os | ||
import time | ||
from PIL import Image | ||
from dataset_loaders.parallel_loader import ThreadedDataset | ||
|
||
|
||
class MnistMDataset(ThreadedDataset): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please inherit from |
||
"""The mnist-m handwritten digit dataset | ||
|
||
The dataset should be downloaded from [1] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fix the link as done for the other file. Also please add a small description of the dataset. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Will do. I want to add the download to mnistm via my website akanev.com. Should I wait to push changes until I link it, or should I push a change to just the mnistm download link later on? Or neither and link via google drive. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Will push changes now and then push that later There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there any advantage to download through your website rather than the original download link? I think it's best to link to an established website to avoid to potentially generate confusion in the users and to make sure the link is up to date in case the original authors notice some problem with the data. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is no really established website for mnistm. I believe Yaroslav, (a MILA phd student), created it for previous research. There is one link to the old mnistm data via google drive, which I think could look sketchy. The new mnistm that I generated using a script is not available via the web right now, which is why I wanted to upload it. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also, theres some errors in the code when inheriting mnistm from mnist that I need to work out. So I will just push mnist for now, and fix the mnistm issues in the meantime. |
||
|
||
Parameters | ||
---------- | ||
which_set: string | ||
A string in ['train', 'val', 'valid', 'test'], corresponding to | ||
the set to be returned. | ||
|
||
References | ||
---------- | ||
[1] Mnist-m dataset file: | ||
your_username_here@elisa1.iro.umontreal.ca:/data/lisa/data/mnistm/images/ | ||
|
||
""" | ||
|
||
name = 'mnistm' | ||
|
||
# optional arguments | ||
data_shape = (28, 28, 3) | ||
mean = [0, 0, 0] | ||
std = [1, 1, 1] | ||
max_files = 50000 | ||
|
||
mapping_type = 'mnist' | ||
|
||
n_classes = 2 | ||
|
||
non_void_nclasses = 2 | ||
_void_labels = [] | ||
GTclasses = range(2) | ||
# GTclasses = GTclasses + [-1] | ||
|
||
_mask_labels = { | ||
0: 'background', | ||
1: 'digit', | ||
} | ||
|
||
_cmap = { | ||
0: (0), # background | ||
1: (255), # digit | ||
} | ||
|
||
_filenames = None | ||
|
||
def __init__(self, which_set='train', *args, **kwargs): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This method is the same as the one in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok. |
||
"""Construct the ThreadedDataset. | ||
|
||
it also creates/copies the dataset in self.path if not already there | ||
|
||
mnistm_data is in 3 directories: train, test, val; | ||
train contains 50000 images | ||
test contains 9000 images | ||
val contains 9000 images | ||
|
||
""" | ||
self.which_set = 'val' if which_set == 'valid' else which_set | ||
|
||
# set file paths | ||
if which_set == 'train': | ||
self.im_path = os.path.join(self.path, 'train_images') | ||
self.mask_path = os.path.join(self.path, 'train_masks') | ||
elif which_set == 'test': | ||
self.im_path = os.path.join(self.path, 'test_images') | ||
self.mask_path = os.path.join(self.path, 'test_masks') | ||
else: | ||
self.im_path = os.path.join(self.path, 'val_images') | ||
self.mask_path = os.path.join(self.path, 'val_masks') | ||
|
||
super(MnistMDataset, self).__init__(*args, **kwargs) | ||
|
||
@property | ||
def filenames(self): | ||
"""Get file names for this set.""" | ||
if self._filenames is None: | ||
filenames = [] | ||
|
||
for i in range(len(os.listdir(self.im_path))): | ||
filenames.append(str(i).zfill(5) + '.png') | ||
|
||
print('MnistMDataset: ' + self.which_set + | ||
' ' + str(len(filenames)) + ' files') | ||
|
||
return filenames | ||
|
||
def get_names(self): | ||
"""Return a dict of mnist filenames.""" | ||
|
||
return {'default': self.filenames} | ||
|
||
def load_sequence(self, sequence): | ||
"""Load a sequence of images/frames. | ||
|
||
Auxiliary function that loads a sequence of mnist images with | ||
the corresponding ground truth mask and the filenames. | ||
Returns a dict with the images in [0, 1], and their filenames. | ||
""" | ||
X = [] | ||
Y = [] | ||
F = [] | ||
|
||
for prefix, image in sequence: | ||
|
||
# open mnist image, convert to numpy array | ||
curr_mnistm_im = Image.open(os.path.join(self.im_path, image)) | ||
curr_mnistm_im = np.array(curr_mnistm_im).astype('float32') | ||
|
||
# append image to X | ||
X.append(curr_mnistm_im) | ||
|
||
# append image fname to F | ||
F.append(image) | ||
|
||
# open mnist mask, convert to numpy array | ||
curr_mnistm_mask = Image.open(os.path.join(self.mask_path, image)) | ||
curr_mnistm_mask = np.array(curr_mnistm_mask).astype('int32') | ||
curr_mnistm_mask = curr_mnistm_mask / 255 | ||
|
||
# append mask to Y | ||
Y.append(curr_mnistm_mask) | ||
|
||
ret = {} | ||
ret['data'] = np.array(X) | ||
ret['labels'] = np.array(Y) | ||
ret['subset'] = prefix | ||
ret['filenames'] = np.array(F) | ||
return ret | ||
|
||
|
||
def test(): | ||
"""Test.""" | ||
trainiter = MnistMDataset( | ||
which_set='train', | ||
batch_size=10, | ||
seq_per_subset=0, | ||
seq_length=0, | ||
data_augm_kwargs={ | ||
'crop_size': (28, 28)}, | ||
return_one_hot=True, | ||
return_01c=True, | ||
return_list=True, | ||
use_threads=True) | ||
|
||
validiter = MnistMDataset( | ||
which_set='valid', | ||
batch_size=5, | ||
seq_per_subset=0, | ||
seq_length=0, | ||
data_augm_kwargs={ | ||
'crop_size': (28, 28)}, | ||
return_one_hot=True, | ||
return_01c=True, | ||
return_list=True, | ||
use_threads=False) | ||
|
||
train_nsamples = trainiter.nsamples | ||
nbatches = trainiter.nbatches | ||
print("Train %d" % (train_nsamples)) | ||
|
||
valid_nsamples = validiter.nsamples | ||
print("Valid %d" % (valid_nsamples)) | ||
|
||
# Simulate training | ||
max_epochs = 2 | ||
start_training = time.time() | ||
for epoch in range(max_epochs): | ||
start_epoch = time.time() | ||
for mb in range(nbatches): | ||
start_batch = time.time() | ||
trainiter.next() | ||
print("Minibatch {}: {} seg".format(mb, (time.time() - | ||
start_batch))) | ||
print("Epoch time: %s" % str(time.time() - start_epoch)) | ||
print("Training time: %s" % str(time.time() - start_training)) | ||
|
||
|
||
def run_tests(): | ||
"""Run tests.""" | ||
test() | ||
|
||
|
||
if __name__ == '__main__': | ||
run_tests() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you reorder the imports so that they are grouped according to https://www.python.org/dev/peps/pep-0008/#imports and each chunk is alphabetically ordered?