Skip to content
This repository has been archived by the owner on May 7, 2020. It is now read-only.

mnist and mnistm dataloaders #20

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
195 changes: 195 additions & 0 deletions dataset_loaders/images/mnist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
"""Mnist dataset."""
import numpy as np
import os
import time
import pickle as pkl
import errno
from PIL import Image
from dataset_loaders.parallel_loader import ThreadedDataset
from timeit import default_timer as timer
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you reorder the imports so that they are grouped according to https://www.python.org/dev/peps/pep-0008/#imports and each chunk is alphabetically ordered?



class MnistDataset(ThreadedDataset):
"""The mnist handwritten digit dataset

The dataset should be downloaded from [1] into the `shared_path`
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please replace [1] with [MNIST]_.

(that should be specified in the config.ini according to the
instructions in ../README.md).

Parameters
----------
which_set: string
A string in ['train', 'val', 'valid', 'test'], corresponding to
the set to be returned.

References
----------
[1] Mnist dataset pickle file:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please replace [1] with [MNIST]

your_username_here@elisa1.iro.umontreal.ca:/data/lisa/data/mnist/mnist_seg/
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please don't make this MILA specific. Replace with the MNIST website or the download link.


"""

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove the empty line please

name = 'mnist'

# optional arguments
data_shape = (28, 28, 3)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe MNIST data_shape should be (28, 28). Can you verify this?

Copy link
Author

@arikanev arikanev Jan 12, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah yes sorry.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No problem :)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

:)

mean = [0, 0, 0]
std = [1, 1, 1]
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to specify mean and std unless you actually compute them. Please either compute them or remove it.

max_files = 50000
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove this. It's never used and should be equivalent to self.nsamples


GTclasses = range(2)
mapping_type = 'mnist'
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is never used. Please remove.


non_void_nclasses = 2
_void_labels = []
GTclasses = range(2)
# GTclasses = GTclasses + [-1]
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove the commented line


_mask_labels = {
0: 'background',
1: 'digit',
}

_cmap = {
0: (0), # background
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cmap should return RGB triples. Please fix this.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok. To confirm: (0, 0, 0)?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Exactly.

1: (255), # digit
}

_filenames = None

def __init__(self, which_set='train', *args, **kwargs):
"""Construct the ThreadedDataset.

it also creates/copies the dataset in self.path if not already there
mnist data is in 3 directories train, test, valid)
"""

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove the empty line please

self.which_set = 'val' if which_set == 'valid' else which_set

# set file paths
if which_set == 'train':
self.image_path = os.path.join(self.path, 'train_images')
self.mask_path = os.path.join(self.path, 'train_masks')

elif which_set == 'test':
self.image_path = os.path.join(self.path, 'test_images')
self.mask_path = os.path.join(self.path, 'test_masks')

else:
self.image_path = os.path.join(self.path, 'val_images')
self.mask_path = os.path.join(self.path, 'val_masks')

super(MnistDataset, self).__init__(*args, **kwargs)

@property
def filenames(self):
"""Get file names for this set."""
if self._filenames is None:
filenames = []

for i in range(len(os.listdir(self.image_path))):
filenames.append(str(i).zfill(5) + '.png')

self._filenames = filenames

print('MnistDataset: ' + self.which_set +
' ' + str(len(filenames)) + ' files')

return self._filenames

def get_names(self):
"""Return a dict of mnist filenames."""

return {'default': self.filenames}

def load_sequence(self, sequence):
"""Load a sequence of images/frames.

Auxiliary function that loads a sequence of mnist images with
the corresponding ground truth mask and the filenames.
Returns a dict with the images in [0, 1], and their filenames.
"""
X = []
Y = []
F = []

for prefix, image in sequence:

# open mnist image, convert to numpy array
curr_mnist_im = Image.open(os.path.join(self.image_path, image))
curr_mnist_im = np.array(curr_mnist_im).astype('float32')

# append image to X
X.append(curr_mnist_im)

# append image fname to F
F.append(image)

# open mnist mask, convert to numpy array
curr_mnist_mask = Image.open(os.path.join(self.mask_path, image))
curr_mnist_mask = np.array(curr_mnist_mask).astype('int32')
curr_mnist_mask = curr_mnist_mask / 255

# append mask to Y
Y.append(curr_mnist_mask)

ret = {}
ret['data'] = np.array(X)
ret['labels'] = np.array(Y)
ret['subset'] = prefix
ret['filenames'] = np.array(F)
return ret


def test():
"""Test."""
trainiter = MnistDataset(
which_set='train',
batch_size=10,
seq_per_subset=0,
seq_length=0,
data_augm_kwargs={'crop_size': (28, 28)},
return_one_hot=True,
return_01c=True,
return_list=True,
use_threads=True)

validiter = MnistDataset(
which_set='valid',
batch_size=5,
seq_per_subset=0,
seq_length=0,
data_augm_kwargs={'crop_size': (28, 28)},
return_one_hot=True,
return_01c=True,
return_list=True,
use_threads=False)

train_nsamples = trainiter.nsamples
nbatches = trainiter.nbatches
print("Train %d" % (train_nsamples))

valid_nsamples = validiter.nsamples
print("Valid %d" % (valid_nsamples))

# Simulate training
max_epochs = 2
start_training = time.time()
for epoch in range(max_epochs):
start_epoch = time.time()
for mb in range(nbatches):
start_batch = time.time()
trainiter.next()
print("Minibatch {}: {} seg".format(mb, (time.time() -
start_batch)))
print("Epoch time: %s" % str(time.time() - start_epoch))
print("Training time: %s" % str(time.time() - start_training))


def run_tests():
"""Run tests."""
test()


if __name__ == '__main__':
run_tests()
193 changes: 193 additions & 0 deletions dataset_loaders/images/mnistm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
"""Mnist-m dataset."""
import numpy as np
import os
import time
from PIL import Image
from dataset_loaders.parallel_loader import ThreadedDataset


class MnistMDataset(ThreadedDataset):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please inherit from MnistDataset and modify only the relevant parts.

"""The mnist-m handwritten digit dataset

The dataset should be downloaded from [1]
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fix the link as done for the other file. Also please add a small description of the dataset.

Copy link
Author

@arikanev arikanev Jan 12, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will do. I want to add the download to mnistm via my website akanev.com. Should I wait to push changes until I link it, or should I push a change to just the mnistm download link later on? Or neither and link via google drive.

Copy link
Author

@arikanev arikanev Jan 12, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will push changes now and then push that later

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any advantage to download through your website rather than the original download link? I think it's best to link to an established website to avoid to potentially generate confusion in the users and to make sure the link is up to date in case the original authors notice some problem with the data.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no really established website for mnistm. I believe Yaroslav, (a MILA phd student), created it for previous research. There is one link to the old mnistm data via google drive, which I think could look sketchy. The new mnistm that I generated using a script is not available via the web right now, which is why I wanted to upload it.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, theres some errors in the code when inheriting mnistm from mnist that I need to work out. So I will just push mnist for now, and fix the mnistm issues in the meantime.


Parameters
----------
which_set: string
A string in ['train', 'val', 'valid', 'test'], corresponding to
the set to be returned.

References
----------
[1] Mnist-m dataset file:
your_username_here@elisa1.iro.umontreal.ca:/data/lisa/data/mnistm/images/

"""

name = 'mnistm'

# optional arguments
data_shape = (28, 28, 3)
mean = [0, 0, 0]
std = [1, 1, 1]
max_files = 50000

mapping_type = 'mnist'

n_classes = 2

non_void_nclasses = 2
_void_labels = []
GTclasses = range(2)
# GTclasses = GTclasses + [-1]

_mask_labels = {
0: 'background',
1: 'digit',
}

_cmap = {
0: (0), # background
1: (255), # digit
}

_filenames = None

def __init__(self, which_set='train', *args, **kwargs):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method is the same as the one in MnistDataset. There is no need to redefine it.
The same goes with all the other methods: remove the parts that do not change.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok.

"""Construct the ThreadedDataset.

it also creates/copies the dataset in self.path if not already there

mnistm_data is in 3 directories: train, test, val;
train contains 50000 images
test contains 9000 images
val contains 9000 images

"""
self.which_set = 'val' if which_set == 'valid' else which_set

# set file paths
if which_set == 'train':
self.im_path = os.path.join(self.path, 'train_images')
self.mask_path = os.path.join(self.path, 'train_masks')
elif which_set == 'test':
self.im_path = os.path.join(self.path, 'test_images')
self.mask_path = os.path.join(self.path, 'test_masks')
else:
self.im_path = os.path.join(self.path, 'val_images')
self.mask_path = os.path.join(self.path, 'val_masks')

super(MnistMDataset, self).__init__(*args, **kwargs)

@property
def filenames(self):
"""Get file names for this set."""
if self._filenames is None:
filenames = []

for i in range(len(os.listdir(self.im_path))):
filenames.append(str(i).zfill(5) + '.png')

print('MnistMDataset: ' + self.which_set +
' ' + str(len(filenames)) + ' files')

return filenames

def get_names(self):
"""Return a dict of mnist filenames."""

return {'default': self.filenames}

def load_sequence(self, sequence):
"""Load a sequence of images/frames.

Auxiliary function that loads a sequence of mnist images with
the corresponding ground truth mask and the filenames.
Returns a dict with the images in [0, 1], and their filenames.
"""
X = []
Y = []
F = []

for prefix, image in sequence:

# open mnist image, convert to numpy array
curr_mnistm_im = Image.open(os.path.join(self.im_path, image))
curr_mnistm_im = np.array(curr_mnistm_im).astype('float32')

# append image to X
X.append(curr_mnistm_im)

# append image fname to F
F.append(image)

# open mnist mask, convert to numpy array
curr_mnistm_mask = Image.open(os.path.join(self.mask_path, image))
curr_mnistm_mask = np.array(curr_mnistm_mask).astype('int32')
curr_mnistm_mask = curr_mnistm_mask / 255

# append mask to Y
Y.append(curr_mnistm_mask)

ret = {}
ret['data'] = np.array(X)
ret['labels'] = np.array(Y)
ret['subset'] = prefix
ret['filenames'] = np.array(F)
return ret


def test():
"""Test."""
trainiter = MnistMDataset(
which_set='train',
batch_size=10,
seq_per_subset=0,
seq_length=0,
data_augm_kwargs={
'crop_size': (28, 28)},
return_one_hot=True,
return_01c=True,
return_list=True,
use_threads=True)

validiter = MnistMDataset(
which_set='valid',
batch_size=5,
seq_per_subset=0,
seq_length=0,
data_augm_kwargs={
'crop_size': (28, 28)},
return_one_hot=True,
return_01c=True,
return_list=True,
use_threads=False)

train_nsamples = trainiter.nsamples
nbatches = trainiter.nbatches
print("Train %d" % (train_nsamples))

valid_nsamples = validiter.nsamples
print("Valid %d" % (valid_nsamples))

# Simulate training
max_epochs = 2
start_training = time.time()
for epoch in range(max_epochs):
start_epoch = time.time()
for mb in range(nbatches):
start_batch = time.time()
trainiter.next()
print("Minibatch {}: {} seg".format(mb, (time.time() -
start_batch)))
print("Epoch time: %s" % str(time.time() - start_epoch))
print("Training time: %s" % str(time.time() - start_training))


def run_tests():
"""Run tests."""
test()


if __name__ == '__main__':
run_tests()