Skip to content

Commit e003b71

Browse files
committed
Fixed mypy issue and added tests and category info
1 parent 1ccf7b0 commit e003b71

File tree

2 files changed

+161
-15
lines changed

2 files changed

+161
-15
lines changed

test/builtin_dataset_mocks.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1344,3 +1344,96 @@ def pcam(info, root, config):
13441344
compressed_file.write(compressed_data)
13451345

13461346
return num_images
1347+
1348+
1349+
class CityScapesMockData:
    """Write a miniature Cityscapes-style directory tree under ``root`` and zip it.

    The tree mirrors the layout the dataset code reads:
    ``gt<Mode>/<split>/<city>/`` holds the ground-truth files and
    ``leftImg8bit/<split>/<city>/`` holds the input images.
    """

    # (capitalized mode, split) -> archives required for that configuration, each
    # given as (archive file name, folder under ``root`` that is packed into it).
    _ARCHIVE_NAMES = {
        ("Coarse", "train"): [("gtCoarse.zip", "gtCoarse"), ("leftImg8bit_trainvaltest.zip", "leftImg8bit")],
        ("Coarse", "train_extra"): [("gtCoarse.zip", "gtCoarse"), ("leftImg8bit_trainextra.zip", "leftImg8bit")],
        ("Coarse", "val"): [("gtCoarse.zip", "gtCoarse"), ("leftImg8bit_trainvaltest.zip", "leftImg8bit")],
        ("Fine", "train"): [("gtFine_trainvaltest.zip", "gtFine"), ("leftImg8bit_trainvaltest.zip", "leftImg8bit")],
        ("Fine", "test"): [("gtFine_trainvaltest.zip", "gtFine"), ("leftImg8bit_trainvaltest.zip", "leftImg8bit")],
        ("Fine", "val"): [("gtFine_trainvaltest.zip", "gtFine"), ("leftImg8bit_trainvaltest.zip", "leftImg8bit")],
    }

    @staticmethod
    def _numbered_name(city, suffix):
        """Return a ``file_name_fn`` that yields ``{city}_000000_00000{idx}{suffix}``.

        Binding ``city`` and ``suffix`` through a factory keeps each callback
        independent of the loop variables at the call site.
        """

        def file_name_fn(idx):
            return f"{city}_000000_00000{idx}{suffix}"

        return file_name_fn

    @classmethod
    def generate(cls, root, config):
        """Create the mock files for ``config`` and return the number of samples.

        Args:
            root: Directory (path-like supporting ``/``) the tree is written to.
            config: Object exposing ``mode`` and ``split`` attributes.

        Returns:
            ``len(cities) * num_samples``, i.e. how many samples were generated.

        Raises:
            KeyError: If ``(mode, split)`` has no entry in ``_ARCHIVE_NAMES``.
        """
        mode = config.mode.capitalize()
        split = config.split

        # Resolve the archives first so an unsupported (mode, split) pair fails
        # before anything is written to disk.
        archives = cls._ARCHIVE_NAMES[(mode, split)]

        if split in ("train", "train_extra"):
            cities = ["bochum", "bremen"]
            num_samples = 3
        else:
            cities = ["bochum"]
            num_samples = 2

        # Content of every ``*_polygons.json`` annotation file.
        polygon_target = {
            "imgHeight": 1024,
            "imgWidth": 2048,
            "objects": [
                {
                    "label": "sky",
                    "polygon": [
                        [1241, 0],
                        [1234, 156],
                        [1478, 197],
                        [1611, 172],
                        [1606, 0],
                    ],
                },
                {
                    "label": "road",
                    "polygon": [
                        [0, 448],
                        [1331, 274],
                        [1473, 265],
                        [2047, 605],
                        [2047, 1023],
                        [0, 1023],
                    ],
                },
            ],
        }

        gt_dir = root / f"gt{mode}"

        # (file-name suffix, ``size`` handed to ``create_image_folder``) for each
        # ground-truth image kind.
        # NOTE(review): the color target presumably gets 4 channels (RGBA) - confirm.
        gt_image_kinds = (
            (f"_gt{mode}_instanceIds.png", 10),
            (f"_gt{mode}_labelIds.png", 10),
            (f"_gt{mode}_color.png", (4, 10, 10)),
        )

        for city in cities:
            for suffix, size in gt_image_kinds:
                create_image_folder(
                    root=gt_dir / split,
                    name=city,
                    file_name_fn=cls._numbered_name(city, suffix),
                    size=size,
                    num_examples=num_samples,
                )

            # Written after the images, into the same ``<split>/<city>`` folder.
            for idx in range(num_samples):
                polygon_target_name = gt_dir / split / city / f"{city}_000000_00000{idx}_gt{mode}_polygons.json"
                with open(polygon_target_name, "w", encoding="utf-8") as outfile:
                    json.dump(polygon_target, outfile)

        # Create leftImg8bit folder
        for city in cities:
            create_image_folder(
                root=root / "leftImg8bit" / split,
                name=city,
                file_name_fn=cls._numbered_name(city, "_leftImg8bit.png"),
                num_examples=num_samples,
            )

        for zip_name, folder_name in archives:
            make_zip(root, zip_name, folder_name)
        return len(cities) * num_samples
1435+
1436+
1437+
@register_mock
def cityscapes(info, root, config):
    """Generate the Cityscapes mock data under ``root`` for ``config``.

    ``info`` is accepted as part of the mock signature but is not used here.
    Returns whatever ``CityScapesMockData.generate`` reports as the number of
    generated samples.
    """
    num_samples = CityScapesMockData.generate(root, config)
    return num_samples

torchvision/prototype/datasets/_builtin/cityscapes.py

Lines changed: 68 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
from collections import namedtuple
12
from functools import partial
23
from pathlib import Path
3-
from typing import Any, Dict, List
4+
from typing import Any, Dict, List, Optional, Tuple
45

56
from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, Demultiplexer, IterKeyZipper, JsonParser
67
from torchvision.prototype.datasets.utils import (
@@ -10,8 +11,9 @@
1011
ManualDownloadResource,
1112
OnlineResource,
1213
)
13-
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
14+
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, hint_sharding, hint_shuffling
1415
from torchvision.prototype.features import EncodedImage
16+
from torchvision.prototype.utils._internal import FrozenMapping
1517

1618

1719
class CityscapesDatasetInfo(DatasetInfo):
@@ -43,20 +45,66 @@ def __init__(self, **kwargs: Any) -> None:
4345
)
4446

4547

48+
# Record describing one Cityscapes label: its name, ids, parent category,
# two boolean flags and a color tuple (one field per column below).
CityscapesClass = namedtuple(
    "CityscapesClass",
    "name id train_id category category_id has_instances ignore_in_eval color",
)
52+
53+
4654
class Cityscapes(Dataset):
55+
56+
# Label name -> CityscapesClass record. Each record is keyed by its own ``name``
# field, so the mapping is derived from the records; insertion order follows
# the table below.
categories_to_details: FrozenMapping = FrozenMapping(
    {
        details.name: details
        for details in (
            CityscapesClass("unlabeled", 0, 255, "void", 0, False, True, (0, 0, 0)),
            CityscapesClass("ego vehicle", 1, 255, "void", 0, False, True, (0, 0, 0)),
            CityscapesClass("rectification border", 2, 255, "void", 0, False, True, (0, 0, 0)),
            CityscapesClass("out of roi", 3, 255, "void", 0, False, True, (0, 0, 0)),
            CityscapesClass("static", 4, 255, "void", 0, False, True, (0, 0, 0)),
            CityscapesClass("dynamic", 5, 255, "void", 0, False, True, (111, 74, 0)),
            CityscapesClass("ground", 6, 255, "void", 0, False, True, (81, 0, 81)),
            CityscapesClass("road", 7, 0, "flat", 1, False, False, (128, 64, 128)),
            CityscapesClass("sidewalk", 8, 1, "flat", 1, False, False, (244, 35, 232)),
            CityscapesClass("parking", 9, 255, "flat", 1, False, True, (250, 170, 160)),
            CityscapesClass("rail track", 10, 255, "flat", 1, False, True, (230, 150, 140)),
            CityscapesClass("building", 11, 2, "construction", 2, False, False, (70, 70, 70)),
            CityscapesClass("wall", 12, 3, "construction", 2, False, False, (102, 102, 156)),
            CityscapesClass("fence", 13, 4, "construction", 2, False, False, (190, 153, 153)),
            CityscapesClass("guard rail", 14, 255, "construction", 2, False, True, (180, 165, 180)),
            CityscapesClass("bridge", 15, 255, "construction", 2, False, True, (150, 100, 100)),
            CityscapesClass("tunnel", 16, 255, "construction", 2, False, True, (150, 120, 90)),
            CityscapesClass("pole", 17, 5, "object", 3, False, False, (153, 153, 153)),
            CityscapesClass("polegroup", 18, 255, "object", 3, False, True, (153, 153, 153)),
            CityscapesClass("traffic light", 19, 6, "object", 3, False, False, (250, 170, 30)),
            CityscapesClass("traffic sign", 20, 7, "object", 3, False, False, (220, 220, 0)),
            CityscapesClass("vegetation", 21, 8, "nature", 4, False, False, (107, 142, 35)),
            CityscapesClass("terrain", 22, 9, "nature", 4, False, False, (152, 251, 152)),
            CityscapesClass("sky", 23, 10, "sky", 5, False, False, (70, 130, 180)),
            CityscapesClass("person", 24, 11, "human", 6, True, False, (220, 20, 60)),
            CityscapesClass("rider", 25, 12, "human", 6, True, False, (255, 0, 0)),
            CityscapesClass("car", 26, 13, "vehicle", 7, True, False, (0, 0, 142)),
            CityscapesClass("truck", 27, 14, "vehicle", 7, True, False, (0, 0, 70)),
            CityscapesClass("bus", 28, 15, "vehicle", 7, True, False, (0, 60, 100)),
            CityscapesClass("caravan", 29, 255, "vehicle", 7, True, True, (0, 0, 90)),
            CityscapesClass("trailer", 30, 255, "vehicle", 7, True, True, (0, 0, 110)),
            CityscapesClass("train", 31, 16, "vehicle", 7, True, False, (0, 80, 100)),
            CityscapesClass("motorcycle", 32, 17, "vehicle", 7, True, False, (0, 0, 230)),
            CityscapesClass("bicycle", 33, 18, "vehicle", 7, True, False, (119, 11, 32)),
            CityscapesClass("license plate", -1, -1, "vehicle", 7, False, True, (0, 0, 142)),
        )
    }
)
95+
4796
def _make_info(self) -> DatasetInfo:
    """Assemble the ``CityscapesDatasetInfo`` for this dataset.

    Categories are the keys of ``categories_to_details``; the full mapping is
    also passed along under ``extra["classname_to_details"]``.
    """
    class_details = self.categories_to_details
    options = dict(
        split=("train", "val", "test", "train_extra"),
        mode=("fine", "coarse"),
    )
    return CityscapesDatasetInfo(
        "cityscapes",
        categories=list(class_details.keys()),
        homepage="http://www.cityscapes-dataset.com/",
        valid_options=options,
        extra=dict(classname_to_details=class_details),
    )
61109

62110
_FILES_CHECKSUMS = {
@@ -67,8 +115,9 @@ def _make_info(self) -> DatasetInfo:
67115
}
68116

69117
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
118+
resources: List[OnlineResource] = []
70119
if config.mode == "fine":
71-
resources = [
120+
resources += [
72121
CityscapesResource(
73122
file_name="leftImg8bit_trainvaltest.zip",
74123
sha256=self._FILES_CHECKSUMS["leftImg8bit_trainvaltest.zip"],
@@ -78,20 +127,22 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]:
78127
),
79128
]
80129
else:
81-
resources = [
130+
split_label = "trainextra" if config.split == "train_extra" else "trainvaltest"
131+
resources += [
82132
CityscapesResource(
83-
file_name="leftImg8bit_trainextra.zip", sha256=self._FILES_CHECKSUMS["leftImg8bit_trainextra.zip"]
133+
file_name=f"leftImg8bit_{split_label}.zip",
134+
sha256=self._FILES_CHECKSUMS[f"leftImg8bit_{split_label}.zip"],
84135
),
85136
CityscapesResource(file_name="gtCoarse.zip", sha256=self._FILES_CHECKSUMS["gtCoarse.zip"]),
86137
]
87138
return resources
88139

89-
def _filter_split_images(self, data, *, req_split: str):
140+
def _filter_split_images(self, data: Tuple[str, Any], *, req_split: str) -> bool:
90141
path = Path(data[0])
91142
split = path.parent.parts[-2]
92143
return split == req_split and ".png" == path.suffix
93144

94-
def _filter_classify_targets(self, data, *, req_split: str):
145+
def _filter_classify_targets(self, data: Tuple[str, Any], *, req_split: str) -> Optional[int]:
95146
path = Path(data[0])
96147
name = path.name
97148
split = path.parent.parts[-2]
@@ -103,7 +154,7 @@ def _filter_classify_targets(self, data, *, req_split: str):
103154
return i
104155
return None
105156

106-
def _prepare_sample(self, data):
157+
def _prepare_sample(self, data: Tuple[Tuple[str, Any], Any]) -> Dict[str, Any]:
107158
(img_path, img_data), target_data = data
108159

109160
color_path, color_data = target_data[1]
@@ -112,7 +163,7 @@ def _prepare_sample(self, data):
112163
target_data = target_data[0]
113164
label_path, label_data = target_data[1]
114165
target_data = target_data[0]
115-
instance_path, instance_data = target_data
166+
instances_path, instance_data = target_data
116167

117168
return dict(
118169
image_path=img_path,
@@ -123,7 +174,7 @@ def _prepare_sample(self, data):
123174
polygon=polygon_data,
124175
segmentation_path=label_path,
125176
segmentation=EncodedImage.from_file(label_data),
126-
instances_path=color_path,
177+
instances_path=instances_path,
127178
instances=EncodedImage.from_file(instance_data),
128179
)
129180

@@ -148,12 +199,12 @@ def _make_datapipe(
148199
# targets_dps[2] is for json polygon, we have to decode them
149200
targets_dps[2] = JsonParser(targets_dps[2])
150201

151-
def img_key_fn(data: Tuple[str, Any]) -> str:
    """Return the image's file stem with the trailing ``_leftImg8bit`` cut off.

    The cut is positional: the last ``len("_leftImg8bit")`` characters of the
    stem are dropped without checking what they are.
    """
    suffix_length = len("_leftImg8bit")
    return Path(data[0]).stem[:-suffix_length]
155206

156-
def target_key_fn(data, level=0):
207+
def target_key_fn(data: Tuple[Any, Any], level: int = 0) -> str:
157208
path = data[0]
158209
for _ in range(level):
159210
path = path[0]
@@ -179,4 +230,6 @@ def target_key_fn(data, level=0):
179230
ref_key_fn=partial(target_key_fn, level=len(targets_dps) - 1),
180231
buffer_size=INFINITE_BUFFER_SIZE,
181232
)
233+
samples = hint_sharding(samples)
234+
samples = hint_shuffling(samples)
182235
return Mapper(samples, fn=self._prepare_sample)

0 commit comments

Comments
 (0)