Skip to content

Migrate Oxford Pets prototype dataset #5764

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions test/builtin_dataset_mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1230,9 +1230,9 @@ def generate(self, root):
return num_samples_map


# @register_mock
def oxford_iiit_pet(info, root, config):
return OxfordIIITPetMockData.generate(root)[config.split]
@register_mock(name="oxford-iiit-pet", configs=combinations_grid(split=("trainval", "test")))
def oxford_iiit_pet(root, config):
    """Generate mock Oxford-IIIT Pet data under *root* and return the sample count for the requested split."""
    num_samples_map = OxfordIIITPetMockData.generate(root)
    return num_samples_map[config["split"]]


class _CUB200MockData:
Expand Down
2 changes: 1 addition & 1 deletion torchvision/prototype/datasets/_builtin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from .gtsrb import GTSRB
from .imagenet import ImageNet
from .mnist import MNIST, FashionMNIST, KMNIST, EMNIST, QMNIST
from .oxford_iiit_pet import OxfordIITPet
from .oxford_iiit_pet import OxfordIIITPet
from .pcam import PCAM
from .sbd import SBD
from .semeion import SEMEION
Expand Down
66 changes: 40 additions & 26 deletions torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import enum
import pathlib
from typing import Any, Dict, List, Optional, Tuple, BinaryIO
from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Union

from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, IterKeyZipper, Demultiplexer, CSVDictParser
from torchvision.prototype.datasets.utils import (
Dataset,
DatasetConfig,
Dataset2,
DatasetInfo,
HttpResource,
OnlineResource,
Expand All @@ -14,29 +13,45 @@
INFINITE_BUFFER_SIZE,
hint_sharding,
hint_shuffling,
BUILTIN_DIR,
getitem,
path_accessor,
path_comparator,
)
from torchvision.prototype.features import Label, EncodedImage

from .._api import register_dataset, register_info

class OxfordIITPetDemux(enum.IntEnum):

NAME = "oxford-iiit-pet"


class OxfordIIITPetDemux(enum.IntEnum):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch!

SPLIT_AND_CLASSIFICATION = 0
SEGMENTATIONS = 1


class OxfordIITPet(Dataset):
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"oxford-iiit-pet",
homepage="https://www.robots.ox.ac.uk/~vgg/data/pets/",
valid_options=dict(
split=("trainval", "test"),
),
)
@register_info(NAME)
def _info() -> Dict[str, Any]:
    """Return the static metadata for the Oxford-IIIT Pet dataset: its category names."""
    # Each row of the bundled .categories file starts with the category name.
    raw_rows = DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories")
    return dict(categories=[row[0] for row in raw_rows])


def resources(self, config: DatasetConfig) -> List[OnlineResource]:
@register_dataset(NAME)
class OxfordIIITPet(Dataset2):
"""Oxford IIIT Pet Dataset
homepage="https://www.robots.ox.ac.uk/~vgg/data/pets/",
"""

# Construct the dataset for one official split ("trainval" or "test").
# NOTE(review): instance attributes are assigned BEFORE super().__init__ —
# presumably the Dataset2 base constructor triggers resource/datapipe setup
# that reads them, so keep this order; confirm against Dataset2.
def __init__(
self, root: Union[str, pathlib.Path], *, split: str = "trainval", skip_integrity_check: bool = False
) -> None:
# Reject anything other than the two official Oxford-IIIT Pet splits.
self._split = self._verify_str_arg(split, "split", {"trainval", "test"})
# Category names come from the bundled categories file via _info().
self._categories = _info()["categories"]
super().__init__(root, skip_integrity_check=skip_integrity_check)

def _resources(self) -> List[OnlineResource]:
images = HttpResource(
"https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz",
sha256="67195c5e1c01f1ab5f9b6a5d22b8c27a580d896ece458917e61d459337fa318d",
Expand All @@ -51,8 +66,8 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]:

def _classify_anns(self, data: Tuple[str, Any]) -> Optional[int]:
return {
"annotations": OxfordIITPetDemux.SPLIT_AND_CLASSIFICATION,
"trimaps": OxfordIITPetDemux.SEGMENTATIONS,
"annotations": OxfordIIITPetDemux.SPLIT_AND_CLASSIFICATION,
"trimaps": OxfordIIITPetDemux.SEGMENTATIONS,
}.get(pathlib.Path(data[0]).parent.name)

def _filter_images(self, data: Tuple[str, Any]) -> bool:
Expand All @@ -70,17 +85,15 @@ def _prepare_sample(
image_path, image_buffer = image_data

return dict(
label=Label(int(classification_data["label"]) - 1, categories=self.categories),
label=Label(int(classification_data["label"]) - 1, categories=self._categories),
species="cat" if classification_data["species"] == "1" else "dog",
segmentation_path=segmentation_path,
segmentation=EncodedImage.from_file(segmentation_buffer),
image_path=image_path,
image=EncodedImage.from_file(image_buffer),
)

def _make_datapipe(
self, resource_dps: List[IterDataPipe], *, config: DatasetConfig
) -> IterDataPipe[Dict[str, Any]]:
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
images_dp, anns_dp = resource_dps

images_dp = Filter(images_dp, self._filter_images)
Expand All @@ -93,9 +106,7 @@ def _make_datapipe(
buffer_size=INFINITE_BUFFER_SIZE,
)

split_and_classification_dp = Filter(
split_and_classification_dp, path_comparator("name", f"{config.split}.txt")
)
split_and_classification_dp = Filter(split_and_classification_dp, path_comparator("name", f"{self._split}.txt"))
split_and_classification_dp = CSVDictParser(
split_and_classification_dp, fieldnames=("image_id", "label", "species"), delimiter=" "
)
Expand All @@ -122,13 +133,13 @@ def _make_datapipe(
return Mapper(dp, self._prepare_sample)

def _filter_split_and_classification_anns(self, data: Tuple[str, Any]) -> bool:
return self._classify_anns(data) == OxfordIITPetDemux.SPLIT_AND_CLASSIFICATION
return self._classify_anns(data) == OxfordIIITPetDemux.SPLIT_AND_CLASSIFICATION

def _generate_categories(self, root: pathlib.Path) -> List[str]:
def _generate_categories(self) -> List[str]:
config = self.default_config
resources = self.resources(config)

dp = resources[1].load(root)
dp = resources[1].load(self._root)
dp = Filter(dp, self._filter_split_and_classification_anns)
dp = Filter(dp, path_comparator("name", f"{config.split}.txt"))
dp = CSVDictParser(dp, fieldnames=("image_id", "label"), delimiter=" ")
Expand All @@ -138,3 +149,6 @@ def _generate_categories(self, root: pathlib.Path) -> List[str]:
*sorted(raw_categories_and_labels, key=lambda raw_category_and_label: int(raw_category_and_label[1]))
)
return [" ".join(part.title() for part in raw_category.split("_")) for raw_category in raw_categories]

def __len__(self) -> int:
    """Number of samples in the selected split (fixed, known dataset sizes)."""
    # 3,680 images in "trainval"; 3,669 in "test".
    split_sizes = {"trainval": 3_680}
    return split_sizes.get(self._split, 3_669)