1
1
from functools import partial
2
2
from pathlib import Path
3
- from typing import Any , Dict , List
3
+ from typing import Any , Dict , List , Optional , Tuple
4
4
5
5
from torchdata .datapipes .iter import IterDataPipe , Mapper , Filter , Demultiplexer , IterKeyZipper , JsonParser
6
6
from torchvision .prototype .datasets .utils import (
10
10
ManualDownloadResource ,
11
11
OnlineResource ,
12
12
)
13
- from torchvision .prototype .datasets .utils ._internal import INFINITE_BUFFER_SIZE
13
+ from torchvision .prototype .datasets .utils ._internal import INFINITE_BUFFER_SIZE , hint_sharding , hint_shuffling
14
14
from torchvision .prototype .features import EncodedImage
15
15
16
16
@@ -55,7 +55,6 @@ def _make_info(self) -> DatasetInfo:
55
55
valid_options = dict (
56
56
split = ("train" , "val" , "test" , "train_extra" ),
57
57
mode = ("fine" , "coarse" ),
58
- # target_type=("instance", "semantic", "polygon", "color")
59
58
),
60
59
)
61
60
@@ -67,8 +66,9 @@ def _make_info(self) -> DatasetInfo:
67
66
}
68
67
69
68
def resources (self , config : DatasetConfig ) -> List [OnlineResource ]:
69
+ resources : List [OnlineResource ] = []
70
70
if config .mode == "fine" :
71
- resources = [
71
+ resources + = [
72
72
CityscapesResource (
73
73
file_name = "leftImg8bit_trainvaltest.zip" ,
74
74
sha256 = self ._FILES_CHECKSUMS ["leftImg8bit_trainvaltest.zip" ],
@@ -78,20 +78,21 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]:
78
78
),
79
79
]
80
80
else :
81
- resources = [
81
+ split_label = "trainextra" if config .split == "train_extra" else "trainvaltest"
82
+ resources += [
82
83
CityscapesResource (
83
- file_name = "leftImg8bit_trainextra .zip" , sha256 = self ._FILES_CHECKSUMS ["leftImg8bit_trainextra .zip" ]
84
+ file_name = f"leftImg8bit_ { split_label } .zip" , sha256 = self ._FILES_CHECKSUMS [f"leftImg8bit_ { split_label } .zip" ]
84
85
),
85
86
CityscapesResource (file_name = "gtCoarse.zip" , sha256 = self ._FILES_CHECKSUMS ["gtCoarse.zip" ]),
86
87
]
87
88
return resources
88
89
89
- def _filter_split_images (self , data , * , req_split : str ):
90
+ def _filter_split_images (self , data : Tuple [ str , Any ], * , req_split : str ) -> bool :
90
91
path = Path (data [0 ])
91
92
split = path .parent .parts [- 2 ]
92
93
return split == req_split and ".png" == path .suffix
93
94
94
- def _filter_classify_targets (self , data , * , req_split : str ):
95
+ def _filter_classify_targets (self , data : Tuple [ str , Any ], * , req_split : str ) -> Optional [ int ] :
95
96
path = Path (data [0 ])
96
97
name = path .name
97
98
split = path .parent .parts [- 2 ]
@@ -103,7 +104,7 @@ def _filter_classify_targets(self, data, *, req_split: str):
103
104
return i
104
105
return None
105
106
106
- def _prepare_sample (self , data ) :
107
+ def _prepare_sample (self , data : Tuple [ Tuple [ str , Any ], Any ]) -> Dict [ str , Any ] :
107
108
(img_path , img_data ), target_data = data
108
109
109
110
color_path , color_data = target_data [1 ]
@@ -112,7 +113,7 @@ def _prepare_sample(self, data):
112
113
target_data = target_data [0 ]
113
114
label_path , label_data = target_data [1 ]
114
115
target_data = target_data [0 ]
115
- instance_path , instance_data = target_data
116
+ instances_path , instance_data = target_data
116
117
117
118
return dict (
118
119
image_path = img_path ,
@@ -123,7 +124,7 @@ def _prepare_sample(self, data):
123
124
polygon = polygon_data ,
124
125
segmentation_path = label_path ,
125
126
segmentation = EncodedImage .from_file (label_data ),
126
- instances_path = color_path ,
127
+ instances_path = instances_path ,
127
128
instances = EncodedImage .from_file (instance_data ),
128
129
)
129
130
@@ -148,18 +149,20 @@ def _make_datapipe(
148
149
# targets_dps[2] is for json polygon, we have to decode them
149
150
targets_dps [2 ] = JsonParser (targets_dps [2 ])
150
151
151
- def img_key_fn (data ) :
152
+ def img_key_fn (data : Tuple [ str , Any ]) -> str :
152
153
stem = Path (data [0 ]).stem
153
154
stem = stem [: - len ("_leftImg8bit" )]
155
+ print ("img_key stem:" , stem , "<-" , Path (data [0 ]).name )
154
156
return stem
155
157
156
- def target_key_fn (data , level = 0 ) :
158
+ def target_key_fn (data : Tuple [ Any , Any ], level : int = 0 ) -> str :
157
159
path = data [0 ]
158
160
for _ in range (level ):
159
161
path = path [0 ]
160
162
stem = Path (path ).stem
161
163
i = stem .rfind ("_gt" )
162
164
stem = stem [:i ]
165
+ print ("target_key stem:" , stem , level , "<-" , Path (path ).name )
163
166
return stem
164
167
165
168
zipped_targets_dp = targets_dps [0 ]
@@ -179,4 +182,6 @@ def target_key_fn(data, level=0):
179
182
ref_key_fn = partial (target_key_fn , level = len (targets_dps ) - 1 ),
180
183
buffer_size = INFINITE_BUFFER_SIZE ,
181
184
)
185
+ samples = hint_sharding (samples )
186
+ samples = hint_shuffling (samples )
182
187
return Mapper (samples , fn = self ._prepare_sample )
0 commit comments