Skip to content

Commit 5cd5722

Browse files
authored
migrate caltech prototype datasets (#5749)
* migrate caltech prototype datasets * resolve third party dependencies
1 parent 4c9cbab commit 5cd5722

File tree

2 files changed

+101
-57
lines changed

2 files changed

+101
-57
lines changed

test/builtin_dataset_mocks.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -370,8 +370,8 @@ def cifar100(root, config):
370370
return len(train_files if config["split"] == "train" else test_files)
371371

372372

373-
# @register_mock
374-
def caltech101(info, root, config):
373+
@register_mock(configs=[dict()])
374+
def caltech101(root, config):
375375
def create_ann_file(root, name):
376376
import scipy.io
377377

@@ -390,15 +390,17 @@ def create_ann_folder(root, name, file_name_fn, num_examples):
390390
images_root = root / "101_ObjectCategories"
391391
anns_root = root / "Annotations"
392392

393-
ann_category_map = {
394-
"Faces_2": "Faces",
395-
"Faces_3": "Faces_easy",
396-
"Motorbikes_16": "Motorbikes",
397-
"Airplanes_Side_2": "airplanes",
393+
image_category_map = {
394+
"Faces": "Faces_2",
395+
"Faces_easy": "Faces_3",
396+
"Motorbikes": "Motorbikes_16",
397+
"airplanes": "Airplanes_Side_2",
398398
}
399399

400+
categories = ["Faces", "Faces_easy", "Motorbikes", "airplanes", "yin_yang"]
401+
400402
num_images_per_category = 2
401-
for category in info.categories:
403+
for category in categories:
402404
create_image_folder(
403405
root=images_root,
404406
name=category,
@@ -407,7 +409,7 @@ def create_ann_folder(root, name, file_name_fn, num_examples):
407409
)
408410
create_ann_folder(
409411
root=anns_root,
410-
name=ann_category_map.get(category, category),
412+
name=image_category_map.get(category, category),
411413
file_name_fn=lambda idx: f"annotation_{idx + 1:04d}.mat",
412414
num_examples=num_images_per_category,
413415
)
@@ -417,27 +419,34 @@ def create_ann_folder(root, name, file_name_fn, num_examples):
417419

418420
make_tar(root, f"{anns_root.name}.tar", anns_root)
419421

420-
return num_images_per_category * len(info.categories)
422+
return num_images_per_category * len(categories)
421423

422424

423-
# @register_mock
424-
def caltech256(info, root, config):
425+
@register_mock(configs=[dict()])
426+
def caltech256(root, config):
425427
dir = root / "256_ObjectCategories"
426428
num_images_per_category = 2
427429

428-
for idx, category in enumerate(info.categories, 1):
430+
categories = [
431+
(1, "ak47"),
432+
(127, "laptop-101"),
433+
(198, "spider"),
434+
(257, "clutter"),
435+
]
436+
437+
for category_idx, category in categories:
429438
files = create_image_folder(
430439
dir,
431-
name=f"{idx:03d}.{category}",
432-
file_name_fn=lambda image_idx: f"{idx:03d}_{image_idx + 1:04d}.jpg",
440+
name=f"{category_idx:03d}.{category}",
441+
file_name_fn=lambda image_idx: f"{category_idx:03d}_{image_idx + 1:04d}.jpg",
433442
num_examples=num_images_per_category,
434443
)
435444
if category == "spider":
436445
open(files[0].parent / "RENAME2", "w").close()
437446

438447
make_tar(root, f"{dir.name}.tar", dir)
439448

440-
return num_images_per_category * len(info.categories)
449+
return num_images_per_category * len(categories)
441450

442451

443452
@register_mock(configs=combinations_grid(split=("train", "val", "test")))

torchvision/prototype/datasets/_builtin/caltech.py

Lines changed: 76 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pathlib
22
import re
3-
from typing import Any, Dict, List, Tuple, BinaryIO
3+
from typing import Any, Dict, List, Tuple, BinaryIO, Union
44

55
import numpy as np
66
from torchdata.datapipes.iter import (
@@ -9,26 +9,49 @@
99
Filter,
1010
IterKeyZipper,
1111
)
12-
from torchvision.prototype.datasets.utils import (
13-
Dataset,
14-
DatasetConfig,
15-
DatasetInfo,
16-
HttpResource,
17-
OnlineResource,
12+
from torchvision.prototype.datasets.utils import Dataset2, DatasetInfo, HttpResource, OnlineResource
13+
from torchvision.prototype.datasets.utils._internal import (
14+
INFINITE_BUFFER_SIZE,
15+
read_mat,
16+
hint_sharding,
17+
hint_shuffling,
18+
BUILTIN_DIR,
1819
)
19-
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, read_mat, hint_sharding, hint_shuffling
2020
from torchvision.prototype.features import Label, BoundingBox, _Feature, EncodedImage
2121

22+
from .._api import register_dataset, register_info
2223

23-
class Caltech101(Dataset):
24-
def _make_info(self) -> DatasetInfo:
25-
return DatasetInfo(
26-
"caltech101",
24+
25+
CALTECH101_CATEGORIES, *_ = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / "caltech101.categories"))
26+
27+
28+
@register_info("caltech101")
29+
def _caltech101_info() -> Dict[str, Any]:
30+
return dict(categories=CALTECH101_CATEGORIES)
31+
32+
33+
@register_dataset("caltech101")
34+
class Caltech101(Dataset2):
35+
"""
36+
- **homepage**: http://www.vision.caltech.edu/Image_Datasets/Caltech101
37+
- **dependencies**:
38+
- <scipy `https://scipy.org/`>_
39+
"""
40+
41+
def __init__(
42+
self,
43+
root: Union[str, pathlib.Path],
44+
skip_integrity_check: bool = False,
45+
) -> None:
46+
self._categories = _caltech101_info()["categories"]
47+
48+
super().__init__(
49+
root,
2750
dependencies=("scipy",),
28-
homepage="http://www.vision.caltech.edu/Image_Datasets/Caltech101",
51+
skip_integrity_check=skip_integrity_check,
2952
)
3053

31-
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
54+
def _resources(self) -> List[OnlineResource]:
3255
images = HttpResource(
3356
"http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz",
3457
sha256="af6ece2f339791ca20f855943d8b55dd60892c0a25105fcd631ee3d6430f9926",
@@ -88,7 +111,7 @@ def _prepare_sample(
88111
ann = read_mat(ann_buffer)
89112

90113
return dict(
91-
label=Label.from_category(category, categories=self.categories),
114+
label=Label.from_category(category, categories=self._categories),
92115
image_path=image_path,
93116
image=image,
94117
ann_path=ann_path,
@@ -98,12 +121,7 @@ def _prepare_sample(
98121
contour=_Feature(ann["obj_contour"].T),
99122
)
100123

101-
def _make_datapipe(
102-
self,
103-
resource_dps: List[IterDataPipe],
104-
*,
105-
config: DatasetConfig,
106-
) -> IterDataPipe[Dict[str, Any]]:
124+
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
107125
images_dp, anns_dp = resource_dps
108126

109127
images_dp = Filter(images_dp, self._is_not_background_image)
@@ -122,23 +140,42 @@ def _make_datapipe(
122140
)
123141
return Mapper(dp, self._prepare_sample)
124142

125-
def _generate_categories(self, root: pathlib.Path) -> List[str]:
126-
resources = self.resources(self.default_config)
143+
def __len__(self) -> int:
144+
return 8677
145+
146+
def _generate_categories(self) -> List[str]:
147+
resources = self._resources()
127148

128-
dp = resources[0].load(root)
149+
dp = resources[0].load(self._root)
129150
dp = Filter(dp, self._is_not_background_image)
130151

131152
return sorted({pathlib.Path(path).parent.name for path, _ in dp})
132153

133154

134-
class Caltech256(Dataset):
135-
def _make_info(self) -> DatasetInfo:
136-
return DatasetInfo(
137-
"caltech256",
138-
homepage="http://www.vision.caltech.edu/Image_Datasets/Caltech256",
139-
)
155+
CALTECH256_CATEGORIES, *_ = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / "caltech256.categories"))
156+
157+
158+
@register_info("caltech256")
159+
def _caltech256_info() -> Dict[str, Any]:
160+
return dict(categories=CALTECH256_CATEGORIES)
161+
162+
163+
@register_dataset("caltech256")
164+
class Caltech256(Dataset2):
165+
"""
166+
- **homepage**: http://www.vision.caltech.edu/Image_Datasets/Caltech256
167+
"""
140168

141-
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
169+
def __init__(
170+
self,
171+
root: Union[str, pathlib.Path],
172+
skip_integrity_check: bool = False,
173+
) -> None:
174+
self._categories = _caltech256_info()["categories"]
175+
176+
super().__init__(root, skip_integrity_check=skip_integrity_check)
177+
178+
def _resources(self) -> List[OnlineResource]:
142179
return [
143180
HttpResource(
144181
"http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar",
@@ -156,25 +193,23 @@ def _prepare_sample(self, data: Tuple[str, BinaryIO]) -> Dict[str, Any]:
156193
return dict(
157194
path=path,
158195
image=EncodedImage.from_file(buffer),
159-
label=Label(int(pathlib.Path(path).parent.name.split(".", 1)[0]) - 1, categories=self.categories),
196+
label=Label(int(pathlib.Path(path).parent.name.split(".", 1)[0]) - 1, categories=self._categories),
160197
)
161198

162-
def _make_datapipe(
163-
self,
164-
resource_dps: List[IterDataPipe],
165-
*,
166-
config: DatasetConfig,
167-
) -> IterDataPipe[Dict[str, Any]]:
199+
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
168200
dp = resource_dps[0]
169201
dp = Filter(dp, self._is_not_rogue_file)
170202
dp = hint_shuffling(dp)
171203
dp = hint_sharding(dp)
172204
return Mapper(dp, self._prepare_sample)
173205

174-
def _generate_categories(self, root: pathlib.Path) -> List[str]:
175-
resources = self.resources(self.default_config)
206+
def __len__(self) -> int:
207+
return 30607
208+
209+
def _generate_categories(self) -> List[str]:
210+
resources = self._resources()
176211

177-
dp = resources[0].load(root)
212+
dp = resources[0].load(self._root)
178213
dir_names = {pathlib.Path(path).parent.name for path, _ in dp}
179214

180215
return [name.split(".")[1] for name in sorted(dir_names)]

0 commit comments

Comments
 (0)