import os
from typing import Tuple, List, Dict

import torch
import torch.utils.data as data
from PIL import Image
from torch.utils.data import Subset, DataLoader

from dataset import train_val_datasets, default_loader, map_subset_name, TransformSubset


def discover_dataset(root: str, verbose: bool = True) -> Tuple[List[Tuple[str, str]], Dict[str, List[int]]]:
    """Walk the train/ and test/ folders under ``root`` and collect samples.

    The ground truth of each sample is the part of its filename before the
    first "-". Returns a list of (path, ground truth) tuples together with a
    mapping from subset name to the indices belonging to that subset.
    """
    images = []
    subset_map = {}
    root_dir = os.path.expanduser(root)
    idx = 0

    subset_dirs = ["train", "test"]

    for subset_dir in subset_dirs:
        indices = []
        d = os.path.join(root_dir, subset_dir)
        for im_file in os.listdir(d):
            gt = im_file.split("-")[0]
            im_path = os.path.join(d, im_file)
            item = (im_path, gt)
            images.append(item)
            indices.append(idx)
            idx += 1
            if verbose:
                print(item)
        if verbose:
            print("Subset had {} files in it.".format(len(indices)))
        subset_map[subset_dir] = indices

    return images, subset_map
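
# Example (illustrative, hypothetical filenames): for a layout such as
#   root/train/7-0001-1.png
#   root/test/3-0042-1.png
# discover_dataset("root") returns
#   images     == [("root/train/7-0001-1.png", "7"), ("root/test/3-0042-1.png", "3")]
#   subset_map == {"train": [0], "test": [1]}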


class CVL(data.Dataset):
    """A generic data loader where the samples are arranged in this way: ::

        root/train/gt-id-xxt.ext

        root/test/gt-id-xxt.ext

    Args:
        root (string): Root directory path.
        loader (callable): A function to load a sample given its path.
        transform (Dict[str, callable], optional):
            A dict from subset names to functions which transform the input images.
        target_transform (callable, optional):
            A function/transform that takes in a target and returns a transformed version.
        subset_name_map ('auto' or dict[str, str] or None):
            Either a dict which maps the folder names to chosen subset names
            (e.g. train, test), or 'auto', in which case any folder name that
            contains {train, test} as a substring is mapped to that subset name
            (e.g. a_train -> train); folder names not matching this pattern are
            left untouched. 'auto' works for the standard CAR-A and CAR-B
            datasets. If None is given, the subset names are not changed.
        train_val_split (float): Ratio at which to split the training data.
            Must be greater than 0 and at most 1.
            If equal to 1, no split is done. If not equal to 1, a subset named
            'train' must exist after mapping; it is then split into a new
            'train' subset and a 'val' subset, overriding the original 'train'
            subset.

    Attributes:
        samples (list): List of (sample path, ground truth) tuples
    """

    def __init__(self, root, loader=default_loader, transform=None, target_transform=None,
                 subset_name_map='auto', train_val_split: float = 0.8, verbose: bool = False):
        samples, subset_to_idx = discover_dataset(root, verbose=verbose)
        if len(samples) == 0:
            raise RuntimeError("Found 0 files in subfolders of: " + root)

        self.root = root
        self.loader = loader

        self.samples = samples

        self.transform = transform
        self.target_transform = target_transform

        self.subsets = self.create_subsets(subset_to_idx, subset_name_map)
        assert 0.0 < train_val_split <= 1.0
        if train_val_split != 1.0:
            assert 'train' in self.subsets
            self.subsets['train'], self.subsets['val'] = train_val_datasets(self.subsets['train'], train_val_split)

    def create_subsets(self, subset_map: Dict[str, List[int]],
                       subset_name_map) -> Dict[str, Subset]:
        subsets = {}
        for subset_name, indices in subset_map.items():
            subset_name = map_subset_name(subset_name, subset_name_map)
            transform = self.transform[subset_name] if self.transform else None
            target_transform = self.target_transform[subset_name] if self.target_transform else None
            subset = TransformSubset(self, indices, transform, target_transform)
            subsets[subset_name] = subset
        return subsets

    def __getitem__(self, index: int) -> Tuple[Image.Image, str]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where target is the ground-truth string of
                the sample. Transforms are applied by the per-subset
                TransformSubset wrappers, not here.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        return sample, target

    def __len__(self) -> int:
        return len(self.samples)

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of total datapoints: {}\n'.format(self.__len__())
        fmt_str += '    Root Location: {}\n'.format(self.root)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = '    Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        fmt_str += '\n    Subsets:\n'
        for name, subset in self.subsets.items():
            fmt_str += '        {}: number of datapoints: {}\n'.format(name, len(subset))
        return fmt_str

    def statistics(self) -> str:
        # Load every image once instead of re-iterating the dataset per metric.
        sizes = [(img.width, img.height) for img, gt in self]
        widths = [w for w, _ in sizes]
        heights = [h for _, h in sizes]
        fmt_str = "Max Width: {}\n".format(max(widths))
        fmt_str += "Max Height: {}\n".format(max(heights))
        fmt_str += "Min Width: {}\n".format(min(widths))
        fmt_str += "Min Height: {}\n".format(min(heights))
        fmt_str += "Avg Width: {}\n".format(sum(widths) / float(len(self)))
        fmt_str += "Avg Height: {}\n".format(sum(heights) / float(len(self)))
        fmt_str += "Avg Aspect: {}\n".format(sum(w / h for w, h in sizes) / float(len(self)))
        return fmt_str

    def mean_and_std(self) -> Tuple[torch.Tensor, torch.Tensor]:
        loader = DataLoader(
            self.subsets['train'],
            batch_size=10,
            num_workers=1,
            shuffle=False
        )
        mean = torch.full((3,), 0.0)
        std = torch.full((3,), 0.0)
        nb_samples = 0.
        for batch, gt in loader:
            batch_samples = batch.size(0)
            # Flatten the spatial dimensions: (N, C, H, W) -> (N, C, H * W).
            batch = batch.view(batch_samples, batch.size(1), -1)
            mean += batch.mean(2).sum(0)
            # Averaging per-image stds approximates the dataset-wide std.
            std += batch.std(2).sum(0)
            nb_samples += batch_samples
        mean /= nb_samples
        std /= nb_samples
        return mean, std
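

# Minimal usage sketch (not part of the original module): the dataset path and
# the Resize dimensions are assumptions for illustration. Batching requires the
# per-subset transforms to produce equally sized tensors, hence the Resize.
if __name__ == "__main__":
    from torchvision import transforms

    tfms = {
        name: transforms.Compose([transforms.Resize((32, 128)), transforms.ToTensor()])
        for name in ("train", "test")
    }
    cvl = CVL(root="./data/cvl", transform=tfms, train_val_split=0.8)
    print(cvl)

    train_loader = DataLoader(cvl.subsets['train'], batch_size=16, shuffle=True)
    images, targets = next(iter(train_loader))  # targets is a list of gt strings
    print(images.shape, targets[:4])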