Skip to content

Commit b3f4c0a

Browse files
committed
Fixed mypy issue and added tests
1 parent 1ccf7b0 commit b3f4c0a

File tree

2 files changed

+111
-13
lines changed

2 files changed

+111
-13
lines changed

test/builtin_dataset_mocks.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1344,3 +1344,96 @@ def pcam(info, root, config):
13441344
compressed_file.write(compressed_data)
13451345

13461346
return num_images
1347+
1348+
1349+
class CityScapesMockData:
    """Generates a small fake Cityscapes directory tree and zips it into archives.

    The tree consists of a ``gt<Mode>/<split>/<city>/`` folder holding the target
    files and a ``leftImg8bit/<split>/<city>/`` folder holding the input images.
    """

    # Maps (capitalized mode, split) to the (zip file name, folder to archive)
    # pairs that have to be produced for that configuration.
    _ARCHIVE_NAMES = {
        ("Coarse", "train"): [
            ("gtCoarse.zip", "gtCoarse"),
            ("leftImg8bit_trainvaltest.zip", "leftImg8bit"),
        ],
        ("Coarse", "train_extra"): [
            ("gtCoarse.zip", "gtCoarse"),
            ("leftImg8bit_trainextra.zip", "leftImg8bit"),
        ],
        ("Coarse", "val"): [
            ("gtCoarse.zip", "gtCoarse"),
            ("leftImg8bit_trainvaltest.zip", "leftImg8bit"),
        ],
        ("Fine", "train"): [
            ("gtFine_trainvaltest.zip", "gtFine"),
            ("leftImg8bit_trainvaltest.zip", "leftImg8bit"),
        ],
        ("Fine", "test"): [
            ("gtFine_trainvaltest.zip", "gtFine"),
            ("leftImg8bit_trainvaltest.zip", "leftImg8bit"),
        ],
        ("Fine", "val"): [
            ("gtFine_trainvaltest.zip", "gtFine"),
            ("leftImg8bit_trainvaltest.zip", "leftImg8bit"),
        ],
    }

    @classmethod
    def generate(cls, root, config):
        """Create the mock files below ``root`` for ``config`` and archive them.

        Args:
            root: Directory (path-like supporting ``/``) to populate.
            config: Object exposing ``mode`` and ``split`` attributes.

        Returns:
            The number of generated samples, i.e. ``len(cities) * num_samples``.

        Raises:
            KeyError: If ``(mode, split)`` has no entry in ``_ARCHIVE_NAMES``.
        """
        mode = config.mode.capitalize()
        split = config.split

        # Training splits get two cities with three samples each, the others one
        # city with two samples.
        is_training_split = split in ["train", "train_extra"]
        cities = ["bochum", "bremen"] if is_training_split else ["bochum"]
        num_samples = 3 if is_training_split else 2

        # Same polygon annotation is written for every sample.
        sky = {
            "label": "sky",
            "polygon": [[1241, 0], [1234, 156], [1478, 197], [1611, 172], [1606, 0]],
        }
        road = {
            "label": "road",
            "polygon": [[0, 448], [1331, 274], [1473, 265], [2047, 605], [2047, 1023], [0, 1023]],
        }
        polygon_target = {"imgHeight": 1024, "imgWidth": 2048, "objects": [sky, road]}

        gt_split_dir = root / f"gt{mode}" / split
        # (file name suffix, ``size`` argument forwarded to ``create_image_folder``)
        image_targets = (
            ("instanceIds", 10),
            ("labelIds", 10),
            ("color", (4, 10, 10)),
        )

        for city in cities:
            prefix = f"{city}_000000_00000"

            for kind, size in image_targets:
                template = prefix + "{idx}" + f"_gt{mode}_{kind}.png"
                create_image_folder(
                    root=gt_split_dir,
                    name=city,
                    # Bind the template as a default so the callable does not
                    # depend on the loop variable.
                    file_name_fn=lambda idx, template=template: template.format(idx=idx),
                    size=size,
                    num_examples=num_samples,
                )

            for idx in range(num_samples):
                polygon_file = gt_split_dir / city / f"{prefix}{idx}_gt{mode}_polygons.json"
                with open(polygon_file, "w") as outfile:
                    json.dump(polygon_target, outfile)

        # Create leftImg8bit folder
        images_split_dir = root / "leftImg8bit" / split
        for city in cities:
            create_image_folder(
                root=images_split_dir,
                name=city,
                file_name_fn=lambda idx: f"{city}_000000_00000{idx}_leftImg8bit.png",
                num_examples=num_samples,
            )

        for zip_name, folder_name in cls._ARCHIVE_NAMES[(mode, split)]:
            make_zip(root, zip_name, folder_name)

        return len(cities) * num_samples
1435+
1436+
1437+
@register_mock
def cityscapes(info, root, config):
    """Mock-data entry point for the ``cityscapes`` dataset.

    Delegates to ``CityScapesMockData.generate`` with ``root`` and ``config``
    and returns its result; ``info`` is not used.
    """
    num_samples = CityScapesMockData.generate(root, config)
    return num_samples

torchvision/prototype/datasets/_builtin/cityscapes.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from functools import partial
22
from pathlib import Path
3-
from typing import Any, Dict, List
3+
from typing import Any, Dict, List, Optional, Tuple
44

55
from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, Demultiplexer, IterKeyZipper, JsonParser
66
from torchvision.prototype.datasets.utils import (
@@ -10,7 +10,7 @@
1010
ManualDownloadResource,
1111
OnlineResource,
1212
)
13-
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
13+
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, hint_sharding, hint_shuffling
1414
from torchvision.prototype.features import EncodedImage
1515

1616

@@ -55,7 +55,6 @@ def _make_info(self) -> DatasetInfo:
5555
valid_options=dict(
5656
split=("train", "val", "test", "train_extra"),
5757
mode=("fine", "coarse"),
58-
# target_type=("instance", "semantic", "polygon", "color")
5958
),
6059
)
6160

@@ -67,8 +66,9 @@ def _make_info(self) -> DatasetInfo:
6766
}
6867

6968
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
69+
resources: List[OnlineResource] = []
7070
if config.mode == "fine":
71-
resources = [
71+
resources += [
7272
CityscapesResource(
7373
file_name="leftImg8bit_trainvaltest.zip",
7474
sha256=self._FILES_CHECKSUMS["leftImg8bit_trainvaltest.zip"],
@@ -78,20 +78,21 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]:
7878
),
7979
]
8080
else:
81-
resources = [
81+
split_label = "trainextra" if config.split == "train_extra" else "trainvaltest"
82+
resources += [
8283
CityscapesResource(
83-
file_name="leftImg8bit_trainextra.zip", sha256=self._FILES_CHECKSUMS["leftImg8bit_trainextra.zip"]
84+
file_name=f"leftImg8bit_{split_label}.zip", sha256=self._FILES_CHECKSUMS[f"leftImg8bit_{split_label}.zip"]
8485
),
8586
CityscapesResource(file_name="gtCoarse.zip", sha256=self._FILES_CHECKSUMS["gtCoarse.zip"]),
8687
]
8788
return resources
8889

89-
def _filter_split_images(self, data, *, req_split: str):
90+
def _filter_split_images(self, data: Tuple[str, Any], *, req_split: str) -> bool:
9091
path = Path(data[0])
9192
split = path.parent.parts[-2]
9293
return split == req_split and ".png" == path.suffix
9394

94-
def _filter_classify_targets(self, data, *, req_split: str):
95+
def _filter_classify_targets(self, data: Tuple[str, Any], *, req_split: str) -> Optional[int]:
9596
path = Path(data[0])
9697
name = path.name
9798
split = path.parent.parts[-2]
@@ -103,7 +104,7 @@ def _filter_classify_targets(self, data, *, req_split: str):
103104
return i
104105
return None
105106

106-
def _prepare_sample(self, data):
107+
def _prepare_sample(self, data: Tuple[Tuple[str, Any], Any]) -> Dict[str, Any]:
107108
(img_path, img_data), target_data = data
108109

109110
color_path, color_data = target_data[1]
@@ -112,7 +113,7 @@ def _prepare_sample(self, data):
112113
target_data = target_data[0]
113114
label_path, label_data = target_data[1]
114115
target_data = target_data[0]
115-
instance_path, instance_data = target_data
116+
instances_path, instance_data = target_data
116117

117118
return dict(
118119
image_path=img_path,
@@ -123,7 +124,7 @@ def _prepare_sample(self, data):
123124
polygon=polygon_data,
124125
segmentation_path=label_path,
125126
segmentation=EncodedImage.from_file(label_data),
126-
instances_path=color_path,
127+
instances_path=instances_path,
127128
instances=EncodedImage.from_file(instance_data),
128129
)
129130

@@ -148,18 +149,20 @@ def _make_datapipe(
148149
# targets_dps[2] is for json polygon, we have to decode them
149150
targets_dps[2] = JsonParser(targets_dps[2])
150151

151-
def img_key_fn(data):
152+
def img_key_fn(data: Tuple[str, Any]) -> str:
    """Return the join key of an image sample.

    Args:
        data: ``(path, payload)`` pair; only the path is inspected.

    Returns:
        The file stem with its last ``len("_leftImg8bit")`` characters removed.
    """
    stem = Path(data[0]).stem
    # Strips the trailing "_leftImg8bit" marker; the slice is unconditional, so
    # this assumes every incoming file name carries that suffix.
    return stem[: -len("_leftImg8bit")]
155157

156-
def target_key_fn(data, level=0):
158+
def target_key_fn(data: Tuple[Any, Any], level: int = 0) -> str:
    """Return the join key of a (possibly nested) target sample.

    Args:
        data: Tuple whose first element is either a path or, for ``level > 0``,
            another tuple nested ``level`` times whose first element is the path.
        level: How many times to descend into the first element before the
            path is reached.

    Returns:
        The file stem truncated at the last occurrence of ``"_gt"``. If the
        stem contains no ``"_gt"``, it is returned unchanged.
    """
    path = data[0]
    for _ in range(level):
        path = path[0]

    stem = Path(path).stem
    # rpartition reports a missing separator explicitly, so a stem without
    # "_gt" is kept whole rather than losing its last character.
    head, separator, _ = stem.rpartition("_gt")
    return head if separator else stem
164167

165168
zipped_targets_dp = targets_dps[0]
@@ -179,4 +182,6 @@ def target_key_fn(data, level=0):
179182
ref_key_fn=partial(target_key_fn, level=len(targets_dps) - 1),
180183
buffer_size=INFINITE_BUFFER_SIZE,
181184
)
185+
samples = hint_sharding(samples)
186+
samples = hint_shuffling(samples)
182187
return Mapper(samples, fn=self._prepare_sample)

0 commit comments

Comments
 (0)