Commit 27104fe

NicolasHug and pmeier authored
Migrate PCAM prototype dataset (#5745)
* Port PCAM
* skip_integrity_check
* Update torchvision/prototype/datasets/_builtin/pcam.py
* Address comments

Co-authored-by: Philip Meier <github.pmeier@posteo.de>
1 parent: 42bc682 · commit: 27104fe

2 files changed (+39, -23 lines)

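This PR moves PCAM off the old Dataset/DatasetInfo/DatasetConfig machinery and onto the decorator-based registration used by the reworked prototype datasets: register_info(NAME) registers a function that returns the dataset's metadata (here just its two categories), and register_dataset(NAME) registers the dataset class itself, as shown in the pcam.py diff below. For orientation only, here is a toy sketch of how such decorator registries typically work; it is an illustration under assumptions, not torchvision's actual _api implementation.

# Toy sketch of decorator-based registration, for illustration only.
# The real register_info / register_dataset live in
# torchvision/prototype/datasets/_api.py and differ in detail.
from typing import Any, Callable, Dict

_INFO_FNS: Dict[str, Callable[[], Dict[str, Any]]] = {}
_DATASET_CLASSES: Dict[str, type] = {}


def register_info(name: str):
    def wrapper(fn: Callable[[], Dict[str, Any]]):
        _INFO_FNS[name] = fn  # remember how to build the metadata dict for this name
        return fn

    return wrapper


def register_dataset(name: str):
    def wrapper(cls: type):
        _DATASET_CLASSES[name] = cls  # remember which class implements this name
        return cls

    return wrapper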

test/builtin_dataset_mocks.py

Lines changed: 4 additions & 4 deletions
@@ -1430,13 +1430,13 @@ def svhn(info, root, config):
     return num_samples
 
 
-# @register_mock
-def pcam(info, root, config):
+@register_mock(configs=combinations_grid(split=("train", "val", "test")))
+def pcam(root, config):
     import h5py
 
-    num_images = {"train": 2, "test": 3, "val": 4}[config.split]
+    num_images = {"train": 2, "test": 3, "val": 4}[config["split"]]
 
-    split = "valid" if config.split == "val" else config.split
+    split = "valid" if config["split"] == "val" else config["split"]
 
     images_io = io.BytesIO()
     with h5py.File(images_io, "w") as f:
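With @register_mock(configs=combinations_grid(split=("train", "val", "test"))) the mock is now registered once per split, and each config is a plain dict rather than a DatasetConfig object, which is why the lookups change from config.split to config["split"]. The hunk above is cut off inside the h5py.File context manager; the sketch below shows how such an in-memory HDF5 mock can be written. The dataset keys "x"/"y" and the 96x96 patch size are assumptions modelled on the upstream PCAM archives, not taken from this diff.

# Sketch of writing tiny in-memory PCAM-style HDF5 files with h5py.
# Keys ("x" for images, "y" for targets) and array shapes are assumptions,
# not shown in this diff.
import io

import h5py
import numpy as np

num_images = 2  # the mock uses 2/3/4 images for train/test/val

images_io = io.BytesIO()
with h5py.File(images_io, "w") as f:
    f["x"] = np.random.randint(0, 256, size=(num_images, 96, 96, 3), dtype=np.uint8)

targets_io = io.BytesIO()
with h5py.File(targets_io, "w") as f:
    f["y"] = np.random.randint(0, 2, size=(num_images, 1, 1, 1), dtype=np.uint8)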

torchvision/prototype/datasets/_builtin/pcam.py

Lines changed: 35 additions & 19 deletions
@@ -1,13 +1,13 @@
 import io
+import pathlib
 from collections import namedtuple
-from typing import Any, Dict, List, Optional, Tuple, Iterator
+from typing import Any, Dict, List, Optional, Tuple, Iterator, Union
+from unicodedata import category
 
 from torchdata.datapipes.iter import IterDataPipe, Mapper, Zipper
 from torchvision.prototype import features
 from torchvision.prototype.datasets.utils import (
-    Dataset,
-    DatasetConfig,
-    DatasetInfo,
+    Dataset2,
     OnlineResource,
     GDriveResource,
 )
@@ -17,6 +17,11 @@
 )
 from torchvision.prototype.features import Label
 
+from .._api import register_dataset, register_info
+
+
+NAME = "pcam"
+
 
 class PCAMH5Reader(IterDataPipe[Tuple[str, io.IOBase]]):
     def __init__(
@@ -40,15 +45,25 @@ def __iter__(self) -> Iterator[Tuple[str, io.IOBase]]:
 _Resource = namedtuple("_Resource", ("file_name", "gdrive_id", "sha256"))
 
 
-class PCAM(Dataset):
-    def _make_info(self) -> DatasetInfo:
-        return DatasetInfo(
-            "pcam",
-            homepage="https://github.com/basveeling/pcam",
-            categories=2,
-            valid_options=dict(split=("train", "test", "val")),
-            dependencies=["h5py"],
-        )
+@register_info(NAME)
+def _info() -> Dict[str, Any]:
+    return dict(categories=["0", "1"])
+
+
+@register_dataset(NAME)
+class PCAM(Dataset2):
+    # TODO write proper docstring
+    """PCAM Dataset
+
+    homepage="https://github.com/basveeling/pcam"
+    """
+
+    def __init__(
+        self, root: Union[str, pathlib.Path], split: str = "train", *, skip_integrity_check: bool = False
+    ) -> None:
+        self._split = self._verify_str_arg(split, "split", {"train", "val", "test"})
+        self._categories = _info()["categories"]
+        super().__init__(root, skip_integrity_check=skip_integrity_check, dependencies=("h5py",))
 
     _RESOURCES = {
         "train": (
@@ -89,23 +104,21 @@ def _make_info(self) -> DatasetInfo:
         ),
     }
 
-    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
+    def _resources(self) -> List[OnlineResource]:
         return [  # = [images resource, targets resource]
             GDriveResource(file_name=file_name, id=gdrive_id, sha256=sha256, preprocess="decompress")
-            for file_name, gdrive_id, sha256 in self._RESOURCES[config.split]
+            for file_name, gdrive_id, sha256 in self._RESOURCES[self._split]
         ]
 
     def _prepare_sample(self, data: Tuple[Any, Any]) -> Dict[str, Any]:
         image, target = data  # They're both numpy arrays at this point
 
         return {
             "image": features.Image(image.transpose(2, 0, 1)),
-            "label": Label(target.item()),
+            "label": Label(target.item(), categories=self._categories),
         }
 
-    def _make_datapipe(
-        self, resource_dps: List[IterDataPipe], *, config: DatasetConfig
-    ) -> IterDataPipe[Dict[str, Any]]:
+    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
 
         images_dp, targets_dp = resource_dps
 
@@ -116,3 +129,6 @@ def _make_datapipe(
         dp = hint_shuffling(dp)
         dp = hint_sharding(dp)
         return Mapper(dp, self._prepare_sample)
+
+    def __len__(self):
+        return 262_144 if self._split == "train" else 32_768
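After the migration, PCAM is constructed directly from a root directory and a split (plus the skip_integrity_check flag added in this PR) instead of going through a DatasetConfig, and it yields samples as dicts holding a features.Image and a Label carrying the registered categories. A minimal usage sketch follows; the root path is hypothetical, and the GDrive archives are expected (or fetched and verified) under it.

# Usage sketch based on the constructor added in this diff; the root path is
# hypothetical, and the dataset can equally be reached through the prototype
# registry that register_dataset feeds into.
from torchvision.prototype.datasets._builtin.pcam import PCAM

dataset = PCAM("/data/pcam", split="val", skip_integrity_check=True)
print(len(dataset))  # 32_768 for "val" and "test", 262_144 for "train"

sample = next(iter(dataset))  # a dict with "image" and "label" entries
print(sample["image"].shape, sample["label"])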
