Skip to content

Commit 7649714

Browse files
committed
add support to cvl
1 parent d0cf86e commit 7649714

File tree

4 files changed

+256
-78
lines changed

4 files changed

+256
-78
lines changed

car_dataset.py

Lines changed: 1 addition & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -10,45 +10,7 @@
1010
from sklearn.model_selection import train_test_split
1111
from torch.utils.data import Subset, Dataset, DataLoader
1212

13-
14-
def train_val_datasets(dataset: Dataset, val_split: float = 0.5, shuffle: bool = True) -> Tuple[Dataset, Dataset]:
    """
    Partition a dataset into a training part and a validation part.

    :param dataset: the dataset to partition
    :param val_split: fraction of the samples that goes into the first (training) subset;
                      the remaining ``1 - val_split`` goes into the validation subset
    :param shuffle: whether the indices are shuffled before the split
    :return: the tuple (train subset, validation subset)
    """
    all_indices = np.arange(len(dataset))
    # Both sizes are passed explicitly; they are complementary by construction.
    train_idx, valid_idx = train_test_split(
        all_indices,
        train_size=val_split,
        test_size=1 - val_split,
        shuffle=shuffle,
    )
    return Subset(dataset, train_idx), Subset(dataset, valid_idx)
28-
29-
30-
def pil_loader(path):
    """Load the image stored at ``path`` with PIL and return it converted to RGB."""
    # The file is opened here (instead of handing the path to PIL) to avoid a
    # ResourceWarning, see https://github.com/python-pillow/Pillow/issues/835
    with open(path, 'rb') as handle:
        return Image.open(handle).convert('RGB')
35-
36-
37-
def accimage_loader(path):
    """Load the image at ``path`` with accimage, using the PIL loader if that raises IOError."""
    import accimage
    try:
        image = accimage.Image(path)
    except IOError:
        # Possibly a decoding problem: retry with PIL.Image instead.
        image = pil_loader(path)
    return image
44-
45-
46-
def default_loader(path):
    """Load the image at ``path`` with the loader matching torchvision's image backend."""
    from torchvision import get_image_backend
    # 'accimage' selects the accimage loader; every other backend uses PIL.
    loader = accimage_loader if get_image_backend() == 'accimage' else pil_loader
    return loader(path)
13+
from dataset import train_val_datasets, pil_loader, accimage_loader, default_loader, map_subset_name, TransformSubset
5214

5315

5416
def discover_dataset(dir: str, verbose: bool = True) -> Tuple[List[Tuple[str, str]], Dict[str, List[str]]]:
@@ -88,44 +50,6 @@ def discover_dataset(dir: str, verbose: bool = True) -> Tuple[List[Tuple[str, st
8850
return images, subset_map
8951

9052

91-
def map_subset_name(subset, subset_name_map):
    """
    Translate a raw subset (folder) name into the subset name to use.

    :param subset: the raw subset name
    :param subset_name_map: a falsy value (name is returned unchanged), the string 'auto'
                            (returns "train" or "test" if that word is contained in the name,
                            checked in this order), or a mapping from raw names to subset names
    :return: the mapped name, or ``subset`` itself if no rule applies
    """
    if subset_name_map:
        if subset_name_map == 'auto':
            for candidate in ("train", "test"):
                if candidate in subset:
                    return candidate
        elif subset in subset_name_map:
            return subset_name_map[subset]
    # Nothing matched (or no mapping was given): keep the original name.
    return subset
102-
103-
104-
class TransformSubset(data.Subset):
    """
    A Subset that can additionally transform the samples and targets it returns.

    Arguments:
        dataset (Dataset): The whole Dataset
        indices (sequence): Indices in the whole set selected for subset
        transform (callable): Applied to the sample if not None
        target_transform (callable): Applied to the target if not None
    """

    def __init__(self, dataset, indices, transform=None, target_transform=None):
        super().__init__(dataset, indices)
        self.transform = transform
        self.target_transform = target_transform

    @staticmethod
    def _apply(fn, value):
        """Return ``fn(value)``, or ``value`` unchanged when no function is set."""
        return value if fn is None else fn(value)

    def __getitem__(self, idx):
        # The underlying dataset is expected to yield (sample, target) pairs.
        sample, target = super().__getitem__(idx)
        sample = self._apply(self.transform, sample)
        target = self._apply(self.target_transform, target)
        return sample, target
127-
128-
12953
class CAR(data.Dataset):
13054
"""A generic data loader where the samples are arranged in this way: ::
13155

cvl_dataset.py

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
import os
2+
import os.path
3+
from typing import Tuple, List, Dict
4+
from warnings import warn
5+
6+
import numpy as np
7+
import torch
8+
import torch.utils.data as data
9+
from PIL import Image
10+
from sklearn.model_selection import train_test_split
11+
from torch.utils.data import Subset, Dataset, DataLoader
12+
13+
from dataset import train_val_datasets, pil_loader, accimage_loader, default_loader, map_subset_name, TransformSubset
14+
15+
16+
def discover_dataset(dir: str, verbose: bool = True) -> Tuple[List[Tuple[str, str]], Dict[str, List[int]]]:
    """
    Collect all samples found in the ``train`` and ``test`` folders below ``dir``.

    The ground truth of a sample is the part of its file name before the first ``-``.

    :param dir: root directory of the dataset (a leading ``~`` is expanded)
    :param verbose: print the number of files found in each subset
    :return: the tuple (images, subset_map); ``images`` is a list of (image path, ground truth)
             tuples and ``subset_map`` maps each subset name to the indices of its samples
             in ``images``
    """
    root_dir = os.path.expanduser(dir)
    images = []
    subset_map = {}

    for subset_name in ("train", "test"):
        subset_dir = os.path.join(root_dir, subset_name)
        first_idx = len(images)
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # sample indices reproducible between runs and machines.
        for im_file in sorted(os.listdir(subset_dir)):
            gt = im_file.split("-")[0]
            images.append((os.path.join(subset_dir, im_file), gt))
        indices = list(range(first_idx, len(images)))
        if verbose:
            print("Subset had {} files in it.".format(len(indices)))
        subset_map[subset_name] = indices

    return images, subset_map
40+
41+
42+
class CVL(data.Dataset):
    """A data loader for the CVL dataset where the samples are arranged in this way: ::

        root/train/gt-id-xxt.ext

        root/test/gt-id-xxt.ext

    Args:
        root (string): Root directory path.
        loader (callable): A function to load a sample given its path.
        transform (Dict[str, callable], optional):
            A dict from subset names to functions which transform the input images.
        target_transform (Dict[str, callable], optional):
            A dict from subset names to functions which transform the targets.
        subset_name_map ('auto' or dict[str, str] or None):
            Either a dict which maps the folder to some chosen subset names (e.g. train, test).
            If 'auto' it will be checked if {train, test} is a substring
            of the subset name and will then be used. Subset names not matching this pattern are not touched.
                e.g: a_train -> train
            If None is given the subset names are not changed.
        train_val_split (float): Ratio at which to perform train_val_split.
            Must be greater 0 and smaller or equal than 1
            If equal to 1, no split is done.
            If unequal 1, a subset with name 'train' must exist after mapping.
            If it exists, two subsets 'train' and 'val' will be added to this subset.
            'train' subset is overridden.
        verbose (bool): Passed on to the dataset discovery.

     Attributes:
        samples (list): List of (sample path, ground truth) tuples
        subsets (dict): Maps subset names to the corresponding subsets of this dataset
    """

    def __init__(self, root, loader=default_loader, transform=None, target_transform=None,
                 subset_name_map='auto', train_val_split: float = 0.8, verbose: bool = False):
        samples, subset_to_idx = discover_dataset(root, verbose=verbose)
        if len(samples) == 0:
            raise RuntimeError("Found 0 files in subfolders of: " + root)

        self.root = root
        self.loader = loader

        self.samples = samples

        self.transform = transform
        self.target_transform = target_transform

        self.subsets = self.create_subsets(subset_to_idx, subset_name_map)
        assert 0.0 < train_val_split <= 1.0, "train_val_split must be in (0, 1]"
        if train_val_split != 1.0:
            assert 'train' in self.subsets, "a 'train' subset is required for the train/val split"
            # NOTE(review): both parts are carved out of the existing 'train' subset, so
            # 'val' presumably goes through the 'train' transform as well - confirm intended.
            self.subsets['train'], self.subsets['val'] = train_val_datasets(self.subsets['train'], train_val_split)

    def create_subsets(self, subset_map: Dict[str, List[int]],
                       subset_name_map) -> Dict[str, Subset]:
        """
        Build one TransformSubset per entry of ``subset_map``.

        Args:
            subset_map: Maps raw subset names to the sample indices belonging to them.
            subset_name_map: Name mapping rule, see the class documentation.

        Returns:
            dict: Maps the (mapped) subset names to their subsets.
        """
        subsets = {}
        for subset_name, indices in subset_map.items():
            subset_name = map_subset_name(subset_name, subset_name_map)
            # Transforms are looked up per subset name, so if given they must
            # contain an entry for every subset.
            transform = self.transform[subset_name] if self.transform else None
            target_transform = self.target_transform[subset_name] if self.target_transform else None
            subsets[subset_name] = TransformSubset(self, indices, transform, target_transform)
        return subsets

    def __getitem__(self, index: int) -> Tuple[Image.Image, str]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where sample is the loaded image and target is the
                ground truth taken from the file name.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        return sample, target

    def __len__(self) -> int:
        return len(self.samples)

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of total datapoints: {}\n'.format(self.__len__())
        fmt_str += '    Root Location: {}\n'.format(self.root)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = '    Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        fmt_str += '\n\tSubsets: \n'
        for name, subset in self.subsets.items():
            fmt_str += '\t\t{}: number of datapoints: {}\n'.format(name, len(subset))
        return fmt_str

    def statistics(self) -> str:
        """
        Summarise the image sizes of the whole dataset.

        Returns:
            str: Max/min/average width and height and the average aspect ratio, one per line.
        """
        # Every access loads the image from disk, so gather the sizes in a
        # single pass instead of iterating the dataset once per statistic.
        widths = []
        heights = []
        for img, _ in self:
            widths.append(img.width)
            heights.append(img.height)
        count = float(len(self))

        fmt_str = "Max Width: {}\n".format(max(widths))
        fmt_str += "Max Height: {}\n".format(max(heights))
        fmt_str += "Min Width: {}\n".format(min(widths))
        fmt_str += "Min Height: {}\n".format(min(heights))
        fmt_str += "Avg Width: {}\n".format(sum(widths) / count)
        fmt_str += "Avg Height: {}\n".format(sum(heights) / count)
        fmt_str += "Avg Aspect: {}\n".format(sum([w / h for w, h in zip(widths, heights)]) / count)
        return fmt_str

    def mean_and_std(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Estimate per-channel mean and standard deviation over the 'train' subset.

        The values are averages of the per-image statistics.

        Returns:
            tuple: (mean, std), each a tensor with three entries.
        """
        # assumes the 'train' transform yields 3-channel image tensors of equal size - TODO confirm
        loader = DataLoader(
            self.subsets['train'],
            batch_size=10,
            num_workers=1,
            shuffle=False
        )
        mean = torch.full((3,), 0.0)
        std = torch.full((3,), 0.0)
        nb_samples = 0.
        for batch, gt in loader:
            batch_samples = batch.size(0)
            # Flatten everything after dimension 1 so the statistics are taken per image and channel.
            batch = batch.view(batch_samples, batch.size(1), -1)
            mean += batch.mean(2).sum(0)
            std += batch.std(2).sum(0)
            nb_samples += batch_samples
        mean /= nb_samples
        std /= nb_samples
        return mean, std

dataset.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
import os
2+
import os.path
3+
from typing import Tuple, List, Dict
4+
from warnings import warn
5+
6+
import numpy as np
7+
import torch
8+
import torch.utils.data as data
9+
from PIL import Image
10+
from sklearn.model_selection import train_test_split
11+
from torch.utils.data import Subset, Dataset, DataLoader
12+
13+
14+
def train_val_datasets(dataset: Dataset, val_split: float = 0.5, shuffle: bool = True) -> Tuple[Dataset, Dataset]:
    """
    Partition a dataset into a training part and a validation part.

    :param dataset: the dataset to partition
    :param val_split: fraction of the samples that goes into the first (training) subset;
                      the remaining ``1 - val_split`` goes into the validation subset
    :param shuffle: whether the indices are shuffled before the split
    :return: the tuple (train subset, validation subset)
    """
    all_indices = np.arange(len(dataset))
    # Both sizes are passed explicitly; they are complementary by construction.
    train_idx, valid_idx = train_test_split(
        all_indices,
        train_size=val_split,
        test_size=1 - val_split,
        shuffle=shuffle,
    )
    return Subset(dataset, train_idx), Subset(dataset, valid_idx)
28+
29+
30+
def pil_loader(path):
    """Load the image stored at ``path`` with PIL and return it converted to RGB."""
    # The file is opened here (instead of handing the path to PIL) to avoid a
    # ResourceWarning, see https://github.com/python-pillow/Pillow/issues/835
    with open(path, 'rb') as handle:
        return Image.open(handle).convert('RGB')
35+
36+
37+
def accimage_loader(path):
    """Load the image at ``path`` with accimage, using the PIL loader if that raises IOError."""
    import accimage
    try:
        image = accimage.Image(path)
    except IOError:
        # Possibly a decoding problem: retry with PIL.Image instead.
        image = pil_loader(path)
    return image
44+
45+
46+
def default_loader(path):
    """Load the image at ``path`` with the loader matching torchvision's image backend."""
    from torchvision import get_image_backend
    # 'accimage' selects the accimage loader; every other backend uses PIL.
    loader = accimage_loader if get_image_backend() == 'accimage' else pil_loader
    return loader(path)
52+
53+
54+
def map_subset_name(subset, subset_name_map):
    """
    Translate a raw subset (folder) name into the subset name to use.

    :param subset: the raw subset name
    :param subset_name_map: a falsy value (name is returned unchanged), the string 'auto'
                            (returns "train" or "test" if that word is contained in the name,
                            checked in this order), or a mapping from raw names to subset names
    :return: the mapped name, or ``subset`` itself if no rule applies
    """
    if subset_name_map:
        if subset_name_map == 'auto':
            for candidate in ("train", "test"):
                if candidate in subset:
                    return candidate
        elif subset in subset_name_map:
            return subset_name_map[subset]
    # Nothing matched (or no mapping was given): keep the original name.
    return subset
65+
66+
67+
class TransformSubset(data.Subset):
    """
    A Subset that can additionally transform the samples and targets it returns.

    Arguments:
        dataset (Dataset): The whole Dataset
        indices (sequence): Indices in the whole set selected for subset
        transform (callable): Applied to the sample if not None
        target_transform (callable): Applied to the target if not None
    """

    def __init__(self, dataset, indices, transform=None, target_transform=None):
        super().__init__(dataset, indices)
        self.transform = transform
        self.target_transform = target_transform

    @staticmethod
    def _apply(fn, value):
        """Return ``fn(value)``, or ``value`` unchanged when no function is set."""
        return value if fn is None else fn(value)

    def __getitem__(self, idx):
        # The underlying dataset is expected to yield (sample, target) pairs.
        sample, target = super().__getitem__(idx)
        sample = self._apply(self.transform, sample)
        target = self._apply(self.target_transform, target)
        return sample, target

train.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from torchvision.transforms import transforms
1414

1515
from car_dataset import CAR
16+
from cvl_dataset import CVL
1617
from model import StringNet
1718
from timer import Timer
1819
from util import concat, length_tensor, format_status_line, write_to_csv
@@ -95,7 +96,10 @@ def create_dataloader(data_path, target_size, train_val_split, batch_size,
9596
}
9697

9798
# Load dataset
98-
dataset = CAR(data_path, transform=data_transforms, train_val_split=train_val_split, verbose=verbose)
99+
if "car" in data_path.lower():
100+
dataset = CAR(data_path, transform=data_transforms, train_val_split=train_val_split, verbose=verbose)
101+
else:
102+
dataset = CVL(data_path, transform=data_transforms, train_val_split=train_val_split, verbose=verbose)
99103
if verbose:
100104
print(dataset)
101105

0 commit comments

Comments
 (0)