Skip to content

Commit

Permalink
Fixed mypy issue and added tests
Browse files Browse the repository at this point in the history
  • Loading branch information
vfdev-5 committed Feb 21, 2022
1 parent 1ccf7b0 commit b3f4c0a
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 13 deletions.
93 changes: 93 additions & 0 deletions test/builtin_dataset_mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1344,3 +1344,96 @@ def pcam(info, root, config):
compressed_file.write(compressed_data)

return num_images


class CityScapesMockData:

_ARCHIVE_NAMES = {
("Coarse", "train"): [("gtCoarse.zip", "gtCoarse"), ("leftImg8bit_trainvaltest.zip", "leftImg8bit")],
("Coarse", "train_extra"): [("gtCoarse.zip", "gtCoarse"), ("leftImg8bit_trainextra.zip", "leftImg8bit")],
("Coarse", "val"): [("gtCoarse.zip", "gtCoarse"), ("leftImg8bit_trainvaltest.zip", "leftImg8bit")],
("Fine", "train"): [("gtFine_trainvaltest.zip", "gtFine"), ("leftImg8bit_trainvaltest.zip", "leftImg8bit")],
("Fine", "test"): [("gtFine_trainvaltest.zip", "gtFine"), ("leftImg8bit_trainvaltest.zip", "leftImg8bit")],
("Fine", "val"): [("gtFine_trainvaltest.zip", "gtFine"), ("leftImg8bit_trainvaltest.zip", "leftImg8bit")],
}

@classmethod
def generate(cls, root, config):

mode = config.mode.capitalize()
split = config.split

if split in ["train", "train_extra"]:
cities = ["bochum", "bremen"]
num_samples = 3
else:
cities = ["bochum"]
num_samples = 2

polygon_target = {
"imgHeight": 1024,
"imgWidth": 2048,
"objects": [
{
"label": "sky",
"polygon": [
[1241, 0],
[1234, 156],
[1478, 197],
[1611, 172],
[1606, 0],
],
},
{
"label": "road",
"polygon": [
[0, 448],
[1331, 274],
[1473, 265],
[2047, 605],
[2047, 1023],
[0, 1023],
],
},
],
}

gt_dir = root / f"gt{mode}"

for city in cities:

def make_image(name, size=10):
create_image_folder(
root=gt_dir / split,
name=city,
file_name_fn=lambda idx: name.format(idx=idx),
size=size,
num_examples=num_samples,
)

make_image(f"{city}_000000_00000" + "{idx}" + f"_gt{mode}_instanceIds.png")
make_image(f"{city}_000000_00000" + "{idx}" + f"_gt{mode}_labelIds.png")
make_image(f"{city}_000000_00000" + "{idx}" + f"_gt{mode}_color.png", size=(4, 10, 10))

for idx in range(num_samples):
polygon_target_name = gt_dir / split / city / f"{city}_000000_00000{idx}_gt{mode}_polygons.json"
with open(polygon_target_name, "w") as outfile:
json.dump(polygon_target, outfile)

# Create leftImg8bit folder
for city in cities:
create_image_folder(
root=root / "leftImg8bit" / split,
name=city,
file_name_fn=lambda idx: f"{city}_000000_00000{idx}_leftImg8bit.png",
num_examples=num_samples,
)

for zip_name, folder_name in cls._ARCHIVE_NAMES[(mode, split)]:
make_zip(root, zip_name, folder_name)
return len(cities) * num_samples


@register_mock
def cityscapes(info, root, config):
return CityScapesMockData.generate(root, config)
31 changes: 18 additions & 13 deletions torchvision/prototype/datasets/_builtin/cityscapes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from functools import partial
from pathlib import Path
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional, Tuple

from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, Demultiplexer, IterKeyZipper, JsonParser
from torchvision.prototype.datasets.utils import (
Expand All @@ -10,7 +10,7 @@
ManualDownloadResource,
OnlineResource,
)
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, hint_sharding, hint_shuffling
from torchvision.prototype.features import EncodedImage


Expand Down Expand Up @@ -55,7 +55,6 @@ def _make_info(self) -> DatasetInfo:
valid_options=dict(
split=("train", "val", "test", "train_extra"),
mode=("fine", "coarse"),
# target_type=("instance", "semantic", "polygon", "color")
),
)

Expand All @@ -67,8 +66,9 @@ def _make_info(self) -> DatasetInfo:
}

def resources(self, config: DatasetConfig) -> List[OnlineResource]:
resources: List[OnlineResource] = []
if config.mode == "fine":
resources = [
resources += [
CityscapesResource(
file_name="leftImg8bit_trainvaltest.zip",
sha256=self._FILES_CHECKSUMS["leftImg8bit_trainvaltest.zip"],
Expand All @@ -78,20 +78,21 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]:
),
]
else:
resources = [
split_label = "trainextra" if config.split == "train_extra" else "trainvaltest"
resources += [
CityscapesResource(
file_name="leftImg8bit_trainextra.zip", sha256=self._FILES_CHECKSUMS["leftImg8bit_trainextra.zip"]
file_name=f"leftImg8bit_{split_label}.zip", sha256=self._FILES_CHECKSUMS[f"leftImg8bit_{split_label}.zip"]
),
CityscapesResource(file_name="gtCoarse.zip", sha256=self._FILES_CHECKSUMS["gtCoarse.zip"]),
]
return resources

def _filter_split_images(self, data, *, req_split: str):
def _filter_split_images(self, data: Tuple[str, Any], *, req_split: str) -> bool:
path = Path(data[0])
split = path.parent.parts[-2]
return split == req_split and ".png" == path.suffix

def _filter_classify_targets(self, data, *, req_split: str):
def _filter_classify_targets(self, data: Tuple[str, Any], *, req_split: str) -> Optional[int]:
path = Path(data[0])
name = path.name
split = path.parent.parts[-2]
Expand All @@ -103,7 +104,7 @@ def _filter_classify_targets(self, data, *, req_split: str):
return i
return None

def _prepare_sample(self, data):
def _prepare_sample(self, data: Tuple[Tuple[str, Any], Any]) -> Dict[str, Any]:
(img_path, img_data), target_data = data

color_path, color_data = target_data[1]
Expand All @@ -112,7 +113,7 @@ def _prepare_sample(self, data):
target_data = target_data[0]
label_path, label_data = target_data[1]
target_data = target_data[0]
instance_path, instance_data = target_data
instances_path, instance_data = target_data

return dict(
image_path=img_path,
Expand All @@ -123,7 +124,7 @@ def _prepare_sample(self, data):
polygon=polygon_data,
segmentation_path=label_path,
segmentation=EncodedImage.from_file(label_data),
instances_path=color_path,
instances_path=instances_path,
instances=EncodedImage.from_file(instance_data),
)

Expand All @@ -148,18 +149,20 @@ def _make_datapipe(
# targets_dps[2] is for json polygon, we have to decode them
targets_dps[2] = JsonParser(targets_dps[2])

def img_key_fn(data):
def img_key_fn(data: Tuple[str, Any]) -> str:
stem = Path(data[0]).stem
stem = stem[: -len("_leftImg8bit")]
print("img_key stem:", stem, "<-", Path(data[0]).name)
return stem

def target_key_fn(data, level=0):
def target_key_fn(data: Tuple[Any, Any], level: int = 0) -> str:
path = data[0]
for _ in range(level):
path = path[0]
stem = Path(path).stem
i = stem.rfind("_gt")
stem = stem[:i]
print("target_key stem:", stem, level, "<-", Path(path).name)
return stem

zipped_targets_dp = targets_dps[0]
Expand All @@ -179,4 +182,6 @@ def target_key_fn(data, level=0):
ref_key_fn=partial(target_key_fn, level=len(targets_dps) - 1),
buffer_size=INFINITE_BUFFER_SIZE,
)
samples = hint_sharding(samples)
samples = hint_shuffling(samples)
return Mapper(samples, fn=self._prepare_sample)

0 comments on commit b3f4c0a

Please sign in to comment.