1
1
import pathlib
2
2
import re
3
- from typing import Any , Dict , List , Tuple , BinaryIO
3
+ from typing import Any , Dict , List , Tuple , BinaryIO , Union
4
4
5
5
import numpy as np
6
6
from torchdata .datapipes .iter import (
9
9
Filter ,
10
10
IterKeyZipper ,
11
11
)
12
- from torchvision .prototype .datasets .utils import (
13
- Dataset ,
14
- DatasetConfig ,
15
- DatasetInfo ,
16
- HttpResource ,
17
- OnlineResource ,
12
+ from torchvision .prototype .datasets .utils import Dataset2 , DatasetInfo , HttpResource , OnlineResource
13
+ from torchvision .prototype .datasets .utils ._internal import (
14
+ INFINITE_BUFFER_SIZE ,
15
+ read_mat ,
16
+ hint_sharding ,
17
+ hint_shuffling ,
18
+ BUILTIN_DIR ,
18
19
)
19
- from torchvision .prototype .datasets .utils ._internal import INFINITE_BUFFER_SIZE , read_mat , hint_sharding , hint_shuffling
20
20
from torchvision .prototype .features import Label , BoundingBox , _Feature , EncodedImage
21
21
22
+ from .._api import register_dataset , register_info
22
23
23
- class Caltech101 (Dataset ):
24
- def _make_info (self ) -> DatasetInfo :
25
- return DatasetInfo (
26
- "caltech101" ,
24
+
25
+ CALTECH101_CATEGORIES , * _ = zip (* DatasetInfo .read_categories_file (BUILTIN_DIR / "caltech101.categories" ))
26
+
27
+
28
+ @register_info ("caltech101" )
29
+ def _caltech101_info () -> Dict [str , Any ]:
30
+ return dict (categories = CALTECH101_CATEGORIES )
31
+
32
+
33
+ @register_dataset ("caltech101" )
34
+ class Caltech101 (Dataset2 ):
35
+ """
36
+ - **homepage**: http://www.vision.caltech.edu/Image_Datasets/Caltech101
37
+ - **dependencies**:
38
+ - <scipy `https://scipy.org/`>_
39
+ """
40
+
41
+ def __init__ (
42
+ self ,
43
+ root : Union [str , pathlib .Path ],
44
+ skip_integrity_check : bool = False ,
45
+ ) -> None :
46
+ self ._categories = _caltech101_info ()["categories" ]
47
+
48
+ super ().__init__ (
49
+ root ,
27
50
dependencies = ("scipy" ,),
28
- homepage = "http://www.vision.caltech.edu/Image_Datasets/Caltech101" ,
51
+ skip_integrity_check = skip_integrity_check ,
29
52
)
30
53
31
- def resources (self , config : DatasetConfig ) -> List [OnlineResource ]:
54
+ def _resources (self ) -> List [OnlineResource ]:
32
55
images = HttpResource (
33
56
"http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz" ,
34
57
sha256 = "af6ece2f339791ca20f855943d8b55dd60892c0a25105fcd631ee3d6430f9926" ,
@@ -88,7 +111,7 @@ def _prepare_sample(
88
111
ann = read_mat (ann_buffer )
89
112
90
113
return dict (
91
- label = Label .from_category (category , categories = self .categories ),
114
+ label = Label .from_category (category , categories = self ._categories ),
92
115
image_path = image_path ,
93
116
image = image ,
94
117
ann_path = ann_path ,
@@ -98,12 +121,7 @@ def _prepare_sample(
98
121
contour = _Feature (ann ["obj_contour" ].T ),
99
122
)
100
123
101
- def _make_datapipe (
102
- self ,
103
- resource_dps : List [IterDataPipe ],
104
- * ,
105
- config : DatasetConfig ,
106
- ) -> IterDataPipe [Dict [str , Any ]]:
124
+ def _datapipe (self , resource_dps : List [IterDataPipe ]) -> IterDataPipe [Dict [str , Any ]]:
107
125
images_dp , anns_dp = resource_dps
108
126
109
127
images_dp = Filter (images_dp , self ._is_not_background_image )
@@ -122,23 +140,42 @@ def _make_datapipe(
122
140
)
123
141
return Mapper (dp , self ._prepare_sample )
124
142
125
- def _generate_categories (self , root : pathlib .Path ) -> List [str ]:
126
- resources = self .resources (self .default_config )
143
+ def __len__ (self ) -> int :
144
+ return 8677
145
+
146
+ def _generate_categories (self ) -> List [str ]:
147
+ resources = self ._resources ()
127
148
128
- dp = resources [0 ].load (root )
149
+ dp = resources [0 ].load (self . _root )
129
150
dp = Filter (dp , self ._is_not_background_image )
130
151
131
152
return sorted ({pathlib .Path (path ).parent .name for path , _ in dp })
132
153
133
154
134
- class Caltech256 (Dataset ):
135
- def _make_info (self ) -> DatasetInfo :
136
- return DatasetInfo (
137
- "caltech256" ,
138
- homepage = "http://www.vision.caltech.edu/Image_Datasets/Caltech256" ,
139
- )
155
+ CALTECH256_CATEGORIES , * _ = zip (* DatasetInfo .read_categories_file (BUILTIN_DIR / "caltech256.categories" ))
156
+
157
+
158
+ @register_info ("caltech256" )
159
+ def _caltech256_info () -> Dict [str , Any ]:
160
+ return dict (categories = CALTECH256_CATEGORIES )
161
+
162
+
163
+ @register_dataset ("caltech256" )
164
+ class Caltech256 (Dataset2 ):
165
+ """
166
+ - **homepage**: http://www.vision.caltech.edu/Image_Datasets/Caltech256
167
+ """
140
168
141
- def resources (self , config : DatasetConfig ) -> List [OnlineResource ]:
169
+ def __init__ (
170
+ self ,
171
+ root : Union [str , pathlib .Path ],
172
+ skip_integrity_check : bool = False ,
173
+ ) -> None :
174
+ self ._categories = _caltech256_info ()["categories" ]
175
+
176
+ super ().__init__ (root , skip_integrity_check = skip_integrity_check )
177
+
178
+ def _resources (self ) -> List [OnlineResource ]:
142
179
return [
143
180
HttpResource (
144
181
"http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar" ,
@@ -156,25 +193,23 @@ def _prepare_sample(self, data: Tuple[str, BinaryIO]) -> Dict[str, Any]:
156
193
return dict (
157
194
path = path ,
158
195
image = EncodedImage .from_file (buffer ),
159
- label = Label (int (pathlib .Path (path ).parent .name .split ("." , 1 )[0 ]) - 1 , categories = self .categories ),
196
+ label = Label (int (pathlib .Path (path ).parent .name .split ("." , 1 )[0 ]) - 1 , categories = self ._categories ),
160
197
)
161
198
162
- def _make_datapipe (
163
- self ,
164
- resource_dps : List [IterDataPipe ],
165
- * ,
166
- config : DatasetConfig ,
167
- ) -> IterDataPipe [Dict [str , Any ]]:
199
+ def _datapipe (self , resource_dps : List [IterDataPipe ]) -> IterDataPipe [Dict [str , Any ]]:
168
200
dp = resource_dps [0 ]
169
201
dp = Filter (dp , self ._is_not_rogue_file )
170
202
dp = hint_shuffling (dp )
171
203
dp = hint_sharding (dp )
172
204
return Mapper (dp , self ._prepare_sample )
173
205
174
- def _generate_categories (self , root : pathlib .Path ) -> List [str ]:
175
- resources = self .resources (self .default_config )
206
+ def __len__ (self ) -> int :
207
+ return 30607
208
+
209
+ def _generate_categories (self ) -> List [str ]:
210
+ resources = self ._resources ()
176
211
177
- dp = resources [0 ].load (root )
212
+ dp = resources [0 ].load (self . _root )
178
213
dir_names = {pathlib .Path (path ).parent .name for path , _ in dp }
179
214
180
215
return [name .split ("." )[1 ] for name in sorted (dir_names )]
0 commit comments