from functools import partial
from pathlib import Path
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional, Tuple

from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, Demultiplexer, IterKeyZipper, JsonParser
from torchvision.prototype.datasets.utils import (
    Dataset,
    DatasetConfig,
    DatasetInfo,
    ManualDownloadResource,
    OnlineResource,
)
-from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
+from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, hint_sharding, hint_shuffling
from torchvision.prototype.features import EncodedImage


@@ -55,7 +55,6 @@ def _make_info(self) -> DatasetInfo:
            valid_options=dict(
                split=("train", "val", "test", "train_extra"),
                mode=("fine", "coarse"),
-                # target_type=("instance", "semantic", "polygon", "color")
            ),
        )

@@ -67,8 +66,9 @@ def _make_info(self) -> DatasetInfo:
    }

    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
+        resources: List[OnlineResource] = []
        if config.mode == "fine":
-            resources = [
+            resources += [
                CityscapesResource(
                    file_name="leftImg8bit_trainvaltest.zip",
                    sha256=self._FILES_CHECKSUMS["leftImg8bit_trainvaltest.zip"],
@@ -78,20 +78,21 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]:
                ),
            ]
        else:
-            resources = [
+            split_label = "trainextra" if config.split == "train_extra" else "trainvaltest"
+            resources += [
                CityscapesResource(
-                    file_name="leftImg8bit_trainextra.zip", sha256=self._FILES_CHECKSUMS["leftImg8bit_trainextra.zip"]
+                    file_name=f"leftImg8bit_{split_label}.zip", sha256=self._FILES_CHECKSUMS[f"leftImg8bit_{split_label}.zip"]
                ),
                CityscapesResource(file_name="gtCoarse.zip", sha256=self._FILES_CHECKSUMS["gtCoarse.zip"]),
            ]
        return resources

-    def _filter_split_images(self, data, *, req_split: str):
+    def _filter_split_images(self, data: Tuple[str, Any], *, req_split: str) -> bool:
        path = Path(data[0])
        split = path.parent.parts[-2]
        return split == req_split and ".png" == path.suffix

-    def _filter_classify_targets(self, data, *, req_split: str):
+    def _filter_classify_targets(self, data: Tuple[str, Any], *, req_split: str) -> Optional[int]:
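+        # Classifier for the targets Demultiplexer: returns the branch index for this
+        # ground-truth file (instance / label / polygon / color), or None to drop
+        # entries that do not belong to the requested split.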
        path = Path(data[0])
        name = path.name
        split = path.parent.parts[-2]
@@ -103,7 +104,7 @@ def _filter_classify_targets(self, data, *, req_split: str):
                return i
        return None

-    def _prepare_sample(self, data):
+    def _prepare_sample(self, data: Tuple[Tuple[str, Any], Any]) -> Dict[str, Any]:
        (img_path, img_data), target_data = data

        color_path, color_data = target_data[1]
@@ -112,7 +113,7 @@ def _prepare_sample(self, data):
        target_data = target_data[0]
        label_path, label_data = target_data[1]
        target_data = target_data[0]
-        instance_path, instance_data = target_data
+        instances_path, instance_data = target_data

        return dict(
            image_path=img_path,
@@ -123,7 +124,7 @@ def _prepare_sample(self, data):
            polygon=polygon_data,
            segmentation_path=label_path,
            segmentation=EncodedImage.from_file(label_data),
-            instances_path=color_path,
+            instances_path=instances_path,
            instances=EncodedImage.from_file(instance_data),
        )

@@ -148,18 +149,20 @@ def _make_datapipe(
        # targets_dps[2] is for json polygon, we have to decode them
        targets_dps[2] = JsonParser(targets_dps[2])

-        def img_key_fn(data):
+        def img_key_fn(data: Tuple[str, Any]) -> str:
            stem = Path(data[0]).stem
            stem = stem[: -len("_leftImg8bit")]
+            print("img_key stem:", stem, "<-", Path(data[0]).name)
            return stem

-        def target_key_fn(data, level=0):
+        def target_key_fn(data: Tuple[Any, Any], level: int = 0) -> str:
            path = data[0]
            for _ in range(level):
                path = path[0]
            stem = Path(path).stem
            i = stem.rfind("_gt")
            stem = stem[:i]
+            print("target_key stem:", stem, level, "<-", Path(path).name)
            return stem

        zipped_targets_dp = targets_dps[0]
@@ -179,4 +182,6 @@ def target_key_fn(data, level=0):
            ref_key_fn=partial(target_key_fn, level=len(targets_dps) - 1),
            buffer_size=INFINITE_BUFFER_SIZE,
        )
+        samples = hint_sharding(samples)
+        samples = hint_shuffling(samples)
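+        # The two hints annotate the pipeline so that sharding across workers and
+        # shuffling of the sample stream can be applied at this point.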
        return Mapper(samples, fn=self._prepare_sample)
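
As an aside (not part of the commit): the IterKeyZipper joins above rely on both key functions reducing a left image and its ground-truth files to the same stem. A minimal sketch of that matching, assuming the standard Cityscapes naming scheme (the aachen_000000_000019 file stem is only an example):

    from pathlib import Path

    def img_key(path: str) -> str:
        # "aachen_000000_000019_leftImg8bit" -> "aachen_000000_000019"
        stem = Path(path).stem
        return stem[: -len("_leftImg8bit")]

    def target_key(path: str) -> str:
        # "aachen_000000_000019_gtFine_color" -> "aachen_000000_000019"
        stem = Path(path).stem
        return stem[: stem.rfind("_gt")]

    assert (
        img_key("leftImg8bit/train/aachen/aachen_000000_000019_leftImg8bit.png")
        == target_key("gtFine/train/aachen/aachen_000000_000019_gtFine_color.png")
        == "aachen_000000_000019"
    )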