1
+ from collections import namedtuple
1
2
from functools import partial
2
3
from pathlib import Path
3
- from typing import Any , Dict , List
4
+ from typing import Any , Dict , List , Optional , Tuple
4
5
5
6
from torchdata .datapipes .iter import IterDataPipe , Mapper , Filter , Demultiplexer , IterKeyZipper , JsonParser
6
7
from torchvision .prototype .datasets .utils import (
10
11
ManualDownloadResource ,
11
12
OnlineResource ,
12
13
)
13
- from torchvision .prototype .datasets .utils ._internal import INFINITE_BUFFER_SIZE
14
+ from torchvision .prototype .datasets .utils ._internal import INFINITE_BUFFER_SIZE , hint_sharding , hint_shuffling
14
15
from torchvision .prototype .features import EncodedImage
16
+ from torchvision .prototype .utils ._internal import FrozenMapping
15
17
16
18
17
19
class CityscapesDatasetInfo (DatasetInfo ):
@@ -43,20 +45,66 @@ def __init__(self, **kwargs: Any) -> None:
43
45
)
44
46
45
47
48
# Immutable record describing a single Cityscapes label class.
# The field order is significant: records are constructed positionally.
CityscapesClass = namedtuple(
    "CityscapesClass",
    [
        "name",
        "id",
        "train_id",
        "category",
        "category_id",
        "has_instances",
        "ignore_in_eval",
        "color",
    ],
)
52
+
53
+
46
54
class Cityscapes (Dataset ):
55
+
56
# Read-only lookup table from class name to its CityscapesClass record.
# Each record is keyed by its own ``name`` field, and the order of the
# records below defines the iteration order of the mapping.
categories_to_details: FrozenMapping = FrozenMapping(
    {
        details.name: details
        for details in (
            CityscapesClass("unlabeled", 0, 255, "void", 0, False, True, (0, 0, 0)),
            CityscapesClass("ego vehicle", 1, 255, "void", 0, False, True, (0, 0, 0)),
            CityscapesClass("rectification border", 2, 255, "void", 0, False, True, (0, 0, 0)),
            CityscapesClass("out of roi", 3, 255, "void", 0, False, True, (0, 0, 0)),
            CityscapesClass("static", 4, 255, "void", 0, False, True, (0, 0, 0)),
            CityscapesClass("dynamic", 5, 255, "void", 0, False, True, (111, 74, 0)),
            CityscapesClass("ground", 6, 255, "void", 0, False, True, (81, 0, 81)),
            CityscapesClass("road", 7, 0, "flat", 1, False, False, (128, 64, 128)),
            CityscapesClass("sidewalk", 8, 1, "flat", 1, False, False, (244, 35, 232)),
            CityscapesClass("parking", 9, 255, "flat", 1, False, True, (250, 170, 160)),
            CityscapesClass("rail track", 10, 255, "flat", 1, False, True, (230, 150, 140)),
            CityscapesClass("building", 11, 2, "construction", 2, False, False, (70, 70, 70)),
            CityscapesClass("wall", 12, 3, "construction", 2, False, False, (102, 102, 156)),
            CityscapesClass("fence", 13, 4, "construction", 2, False, False, (190, 153, 153)),
            CityscapesClass("guard rail", 14, 255, "construction", 2, False, True, (180, 165, 180)),
            CityscapesClass("bridge", 15, 255, "construction", 2, False, True, (150, 100, 100)),
            CityscapesClass("tunnel", 16, 255, "construction", 2, False, True, (150, 120, 90)),
            CityscapesClass("pole", 17, 5, "object", 3, False, False, (153, 153, 153)),
            CityscapesClass("polegroup", 18, 255, "object", 3, False, True, (153, 153, 153)),
            CityscapesClass("traffic light", 19, 6, "object", 3, False, False, (250, 170, 30)),
            CityscapesClass("traffic sign", 20, 7, "object", 3, False, False, (220, 220, 0)),
            CityscapesClass("vegetation", 21, 8, "nature", 4, False, False, (107, 142, 35)),
            CityscapesClass("terrain", 22, 9, "nature", 4, False, False, (152, 251, 152)),
            CityscapesClass("sky", 23, 10, "sky", 5, False, False, (70, 130, 180)),
            CityscapesClass("person", 24, 11, "human", 6, True, False, (220, 20, 60)),
            CityscapesClass("rider", 25, 12, "human", 6, True, False, (255, 0, 0)),
            CityscapesClass("car", 26, 13, "vehicle", 7, True, False, (0, 0, 142)),
            CityscapesClass("truck", 27, 14, "vehicle", 7, True, False, (0, 0, 70)),
            CityscapesClass("bus", 28, 15, "vehicle", 7, True, False, (0, 60, 100)),
            CityscapesClass("caravan", 29, 255, "vehicle", 7, True, True, (0, 0, 90)),
            CityscapesClass("trailer", 30, 255, "vehicle", 7, True, True, (0, 0, 110)),
            CityscapesClass("train", 31, 16, "vehicle", 7, True, False, (0, 80, 100)),
            CityscapesClass("motorcycle", 32, 17, "vehicle", 7, True, False, (0, 0, 230)),
            CityscapesClass("bicycle", 33, 18, "vehicle", 7, True, False, (119, 11, 32)),
            CityscapesClass("license plate", -1, -1, "vehicle", 7, False, True, (0, 0, 142)),
        )
    }
)
95
+
47
96
def _make_info(self) -> DatasetInfo:
    """Assemble the static description of the Cityscapes dataset.

    Returns:
        A ``CityscapesDatasetInfo`` named ``"cityscapes"`` whose categories
        are the keys of ``categories_to_details``, whose valid options are
        the supported ``split`` and ``mode`` values, and whose ``extra``
        entry exposes the full per-class table as ``classname_to_details``.
    """
    class_details = self.categories_to_details
    split_choices = ("train", "val", "test", "train_extra")
    mode_choices = ("fine", "coarse")

    return CityscapesDatasetInfo(
        "cityscapes",
        categories=list(class_details.keys()),
        homepage="http://www.cityscapes-dataset.com/",
        valid_options={"split": split_choices, "mode": mode_choices},
        extra={"classname_to_details": class_details},
    )
61
109
62
110
_FILES_CHECKSUMS = {
@@ -67,8 +115,9 @@ def _make_info(self) -> DatasetInfo:
67
115
}
68
116
69
117
def resources (self , config : DatasetConfig ) -> List [OnlineResource ]:
118
+ resources : List [OnlineResource ] = []
70
119
if config .mode == "fine" :
71
- resources = [
120
+ resources + = [
72
121
CityscapesResource (
73
122
file_name = "leftImg8bit_trainvaltest.zip" ,
74
123
sha256 = self ._FILES_CHECKSUMS ["leftImg8bit_trainvaltest.zip" ],
@@ -78,20 +127,22 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]:
78
127
),
79
128
]
80
129
else :
81
- resources = [
130
+ split_label = "trainextra" if config .split == "train_extra" else "trainvaltest"
131
+ resources += [
82
132
CityscapesResource (
83
- file_name = "leftImg8bit_trainextra.zip" , sha256 = self ._FILES_CHECKSUMS ["leftImg8bit_trainextra.zip" ]
133
+ file_name = f"leftImg8bit_{ split_label } .zip" ,
134
+ sha256 = self ._FILES_CHECKSUMS [f"leftImg8bit_{ split_label } .zip" ],
84
135
),
85
136
CityscapesResource (file_name = "gtCoarse.zip" , sha256 = self ._FILES_CHECKSUMS ["gtCoarse.zip" ]),
86
137
]
87
138
return resources
88
139
89
- def _filter_split_images (self , data , * , req_split : str ):
140
+ def _filter_split_images (self , data : Tuple [ str , Any ], * , req_split : str ) -> bool :
90
141
path = Path (data [0 ])
91
142
split = path .parent .parts [- 2 ]
92
143
return split == req_split and ".png" == path .suffix
93
144
94
- def _filter_classify_targets (self , data , * , req_split : str ):
145
+ def _filter_classify_targets (self , data : Tuple [ str , Any ], * , req_split : str ) -> Optional [ int ] :
95
146
path = Path (data [0 ])
96
147
name = path .name
97
148
split = path .parent .parts [- 2 ]
@@ -103,7 +154,7 @@ def _filter_classify_targets(self, data, *, req_split: str):
103
154
return i
104
155
return None
105
156
106
- def _prepare_sample (self , data ) :
157
+ def _prepare_sample (self , data : Tuple [ Tuple [ str , Any ], Any ]) -> Dict [ str , Any ] :
107
158
(img_path , img_data ), target_data = data
108
159
109
160
color_path , color_data = target_data [1 ]
@@ -112,7 +163,7 @@ def _prepare_sample(self, data):
112
163
target_data = target_data [0 ]
113
164
label_path , label_data = target_data [1 ]
114
165
target_data = target_data [0 ]
115
- instance_path , instance_data = target_data
166
+ instances_path , instance_data = target_data
116
167
117
168
return dict (
118
169
image_path = img_path ,
@@ -123,7 +174,7 @@ def _prepare_sample(self, data):
123
174
polygon = polygon_data ,
124
175
segmentation_path = label_path ,
125
176
segmentation = EncodedImage .from_file (label_data ),
126
- instances_path = color_path ,
177
+ instances_path = instances_path ,
127
178
instances = EncodedImage .from_file (instance_data ),
128
179
)
129
180
@@ -148,12 +199,12 @@ def _make_datapipe(
148
199
# targets_dps[2] is for json polygon, we have to decode them
149
200
targets_dps [2 ] = JsonParser (targets_dps [2 ])
150
201
151
def img_key_fn(data: Tuple[str, Any]) -> str:
    """Derive the matching key for an image entry.

    Takes the file stem of ``data[0]`` and drops its trailing
    ``len("_leftImg8bit")`` characters. The cut is purely by length: the
    stem is not checked to actually end with that suffix.
    """
    suffix = "_leftImg8bit"
    return Path(data[0]).stem[: -len(suffix)]
155
206
156
- def target_key_fn (data , level = 0 ) :
207
+ def target_key_fn (data : Tuple [ Any , Any ], level : int = 0 ) -> str :
157
208
path = data [0 ]
158
209
for _ in range (level ):
159
210
path = path [0 ]
@@ -179,4 +230,6 @@ def target_key_fn(data, level=0):
179
230
ref_key_fn = partial (target_key_fn , level = len (targets_dps ) - 1 ),
180
231
buffer_size = INFINITE_BUFFER_SIZE ,
181
232
)
233
+ samples = hint_sharding (samples )
234
+ samples = hint_shuffling (samples )
182
235
return Mapper (samples , fn = self ._prepare_sample )
0 commit comments