Skip to content

Commit

Permalink
Fix minor bug in new data provider
Browse files (browse the repository at this point in the history)
  • Loading branch information
AntreasAntoniou committed Jun 12, 2018
1 parent 8443a55 commit 81da766
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 26 deletions.
49 changes: 24 additions & 25 deletions data.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -19,8 +19,9 @@ def augment_image(image, k, channels):

class MatchingNetworkDatasetParallel(Dataset):
def __init__(self, batch_size, reverse_channels, num_of_gpus, image_height, image_width, image_channels,
train_val_test_split, num_classes_per_set, num_samples_per_class, seed=100,
reset_stored_filepaths=False, labels_as_int=False):
train_val_test_split, num_classes_per_set, num_samples_per_class,
data_path, dataset_name, indexes_of_folders_indicating_class,
seed=100, reset_stored_filepaths=False, labels_as_int=False):
"""
:param batch_size: The batch size to use for the data loader
:param last_training_class_index: The final index for the training set, used to restrict the training set
Expand All @@ -30,6 +31,9 @@ def __init__(self, batch_size, reverse_channels, num_of_gpus, image_height, imag
:param num_of_gpus: Number of gpus to use for training
:param gen_batches: How many batches to use from the validation set for the end of epoch generations
"""
self.data_path = data_path
self.dataset_name = dataset_name
self.indexes_of_folders_indicating_class = indexes_of_folders_indicating_class
self.labels_as_int = labels_as_int
self.train_val_test_split = train_val_test_split
self.current_dataset_name = "train"
Expand Down Expand Up @@ -171,7 +175,11 @@ def get_label_from_index(self, index):
return index_to_label_name[index]

def get_label_from_path(self, filepath):
raise NotImplementedError
label_bits = filepath.split("/")
label = "_".join([label_bits[idx] for idx in self.indexes_of_folders_indicating_class])
if self.labels_as_int:
label = int(label)
return label

def load_image(self, image_path, channels):

Expand Down Expand Up @@ -418,36 +426,27 @@ def sample_iter_data(self, sample, num_gpus, batch_size, samples_per_iter):

class FolderMatchingNetworkDatasetParallel(MatchingNetworkDatasetParallel):
def __init__(self, name, num_of_gpus, batch_size, image_height, image_width, image_channels,
train_val_test_split, data_path, index_of_folder_indicating_class, reset_stored_filepaths,
train_val_test_split, data_path, indexes_of_folders_indicating_class, reset_stored_filepaths,
num_samples_per_class, num_classes_per_set, labels_as_int, reverse_channels):

self.data_path = os.path.abspath(data_path)
self.dataset_name = name
self.indeces_of_folders_indicating_class_list = index_of_folder_indicating_class

super(FolderMatchingNetworkDatasetParallel, self).__init__(
batch_size=batch_size, reverse_channels=reverse_channels,
num_of_gpus=num_of_gpus, image_height=image_height,
image_width=image_width, image_channels=image_channels,
train_val_test_split=train_val_test_split, reset_stored_filepaths=reset_stored_filepaths,
num_classes_per_set=num_classes_per_set, num_samples_per_class=num_samples_per_class,
labels_as_int=labels_as_int)
labels_as_int=labels_as_int, data_path=os.path.abspath(data_path), dataset_name=name,
indexes_of_folders_indicating_class=indexes_of_folders_indicating_class)

def get_label_from_path(self, filepath):
label_bits = filepath.split("/")
label = "_".join([label_bits[idx] for idx in self.indeces_of_folders_indicating_class_list])
if self.labels_as_int:
label = int(label)
return label

class FolderDatasetLoader(MatchingNetworkLoader):
def __init__(self, name, batch_size, image_height, image_width, image_channels, data_path, train_val_test_split,
num_of_gpus=1, samples_per_iter=1, num_workers=4, index_of_folder_indicating_class=[-2, -3],
num_of_gpus=1, samples_per_iter=1, num_workers=4, indexes_of_folders_indicating_class=[-2],
reset_stored_filepaths=False, num_samples_per_class=1, num_classes_per_set=20, reverse_channels=False,
seed=100, label_as_int=False):

self.name = name
self.index_of_folder_indicating_class = index_of_folder_indicating_class
self.indexes_of_folders_indicating_class = indexes_of_folders_indicating_class
self.reset_stored_filepaths = reset_stored_filepaths
super(FolderDatasetLoader, self).__init__(name, num_of_gpus, batch_size, image_height, image_width, image_channels, num_classes_per_set, data_path,
num_samples_per_class, train_val_test_split,
Expand All @@ -457,11 +456,11 @@ def get_dataset(self, batch_size, reverse_channels, num_of_gpus, image_height, i
train_val_test_split, num_classes_per_set, num_samples_per_class, seed,
reset_stored_filepaths, data_path, labels_as_int):
return FolderMatchingNetworkDatasetParallel(name=self.name, num_of_gpus=num_of_gpus, batch_size=batch_size,
image_height=image_height, image_width=image_width,
image_channels=image_channels,
train_val_test_split=train_val_test_split, data_path=data_path,
index_of_folder_indicating_class=self.index_of_folder_indicating_class,
reset_stored_filepaths=self.reset_stored_filepaths,
num_samples_per_class=num_samples_per_class,
num_classes_per_set=num_classes_per_set, labels_as_int=labels_as_int,
reverse_channels=reverse_channels)
image_height=image_height, image_width=image_width,
image_channels=image_channels,
train_val_test_split=train_val_test_split, data_path=data_path,
indexes_of_folders_indicating_class=self.indexes_of_folders_indicating_class,
reset_stored_filepaths=self.reset_stored_filepaths,
num_samples_per_class=num_samples_per_class,
num_classes_per_set=num_classes_per_set, labels_as_int=labels_as_int,
reverse_channels=reverse_channels)
2 changes: 1 addition & 1 deletion train_one_shot_learning_matching_network.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -15,7 +15,7 @@
train_val_test_split=(1200/1622, 211/1622, 211/1622),
samples_per_iter=1, num_workers=4,
data_path="datasets/omniglot_dataset", name="omniglot_dataset",
index_of_folder_indicating_class=[-2, -3], reset_stored_filepaths=True,
indexes_of_folders_indicating_class=[-2, -3], reset_stored_filepaths=True,
num_samples_per_class=args.samples_per_class,
num_classes_per_set=args.classes_per_set, label_as_int=False)

Expand Down

0 comments on commit 81da766

Please sign in to comment.