Skip to content

Commit be67431

Browse files
authored
remove decoding from prototype datasets (#5287)
* remove decoder from prototype datasets * remove unused imports * cleanup * fix readme * use OneHotLabel in SEMEION * improve voc implementation * revert unrelated changes * fix semeion mock data
1 parent bfc8510 commit be67431

30 files changed

+463
-791
lines changed

test/builtin_dataset_mocks.py

Lines changed: 47 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -431,50 +431,52 @@ def caltech256(info, root, config):
431431

432432
@register_mock
433433
def imagenet(info, root, config):
434-
wnids = tuple(info.extra.wnid_to_category.keys())
435-
if config.split == "train":
436-
images_root = root / "ILSVRC2012_img_train"
434+
from scipy.io import savemat
437435

436+
categories = info.categories
437+
wnids = [info.extra.category_to_wnid[category] for category in categories]
438+
if config.split == "train":
438439
num_samples = len(wnids)
440+
archive_name = "ILSVRC2012_img_train.tar"
439441

442+
files = []
440443
for wnid in wnids:
441-
files = create_image_folder(
442-
root=images_root,
444+
create_image_folder(
445+
root=root,
443446
name=wnid,
444447
file_name_fn=lambda image_idx: f"{wnid}_{image_idx:04d}.JPEG",
445448
num_examples=1,
446449
)
447-
make_tar(images_root, f"{wnid}.tar", files[0].parent)
450+
files.append(make_tar(root, f"{wnid}.tar"))
448451
elif config.split == "val":
449452
num_samples = 3
450-
files = create_image_folder(
451-
root=root,
452-
name="ILSVRC2012_img_val",
453-
file_name_fn=lambda image_idx: f"ILSVRC2012_val_{image_idx + 1:08d}.JPEG",
454-
num_examples=num_samples,
455-
)
456-
images_root = files[0].parent
457-
else: # config.split == "test"
458-
images_root = root / "ILSVRC2012_img_test_v10102019"
453+
archive_name = "ILSVRC2012_img_val.tar"
454+
files = [create_image_file(root, f"ILSVRC2012_val_{idx + 1:08d}.JPEG") for idx in range(num_samples)]
459455

460-
num_samples = 3
456+
devkit_root = root / "ILSVRC2012_devkit_t12"
457+
data_root = devkit_root / "data"
458+
data_root.mkdir(parents=True)
461459

462-
create_image_folder(
463-
root=images_root,
464-
name="test",
465-
file_name_fn=lambda image_idx: f"ILSVRC2012_test_{image_idx + 1:08d}.JPEG",
466-
num_examples=num_samples,
467-
)
468-
make_tar(root, f"{images_root.name}.tar", images_root)
460+
with open(data_root / "ILSVRC2012_validation_ground_truth.txt", "w") as file:
461+
for label in torch.randint(0, len(wnids), (num_samples,)).tolist():
462+
file.write(f"{label}\n")
463+
464+
num_children = 0
465+
synsets = [
466+
(idx, wnid, category, "", num_children, [], 0, 0)
467+
for idx, (category, wnid) in enumerate(zip(categories, wnids), 1)
468+
]
469+
num_children = 1
470+
synsets.extend((0, "", "", "", num_children, [], 0, 0) for _ in range(5))
471+
savemat(data_root / "meta.mat", dict(synsets=synsets))
472+
473+
make_tar(root, devkit_root.with_suffix(".tar.gz").name, compression="gz")
474+
else: # config.split == "test"
475+
num_samples = 5
476+
archive_name = "ILSVRC2012_img_test_v10102019.tar"
477+
files = [create_image_file(root, f"ILSVRC2012_test_{idx + 1:08d}.JPEG") for idx in range(num_samples)]
469478

470-
devkit_root = root / "ILSVRC2012_devkit_t12"
471-
devkit_root.mkdir()
472-
data_root = devkit_root / "data"
473-
data_root.mkdir()
474-
with open(data_root / "ILSVRC2012_validation_ground_truth.txt", "w") as file:
475-
for label in torch.randint(0, len(wnids), (num_samples,)).tolist():
476-
file.write(f"{label}\n")
477-
make_tar(root, f"{devkit_root}.tar.gz", devkit_root, compression="gz")
479+
make_tar(root, archive_name, *files)
478480

479481
return num_samples
480482

@@ -666,14 +668,15 @@ def sbd(info, root, config):
666668
@register_mock
667669
def semeion(info, root, config):
668670
num_samples = 3
671+
num_categories = len(info.categories)
669672

670673
images = torch.rand(num_samples, 256)
671-
labels = one_hot(torch.randint(len(info.categories), size=(num_samples,)))
674+
labels = one_hot(torch.randint(num_categories, size=(num_samples,)), num_classes=num_categories)
672675
with open(root / "semeion.data", "w") as fh:
673676
for image, one_hot_label in zip(images, labels):
674677
image_columns = " ".join([f"{pixel.item():.4f}" for pixel in image])
675678
labels_columns = " ".join([str(label.item()) for label in one_hot_label])
676-
fh.write(f"{image_columns} {labels_columns}\n")
679+
fh.write(f"{image_columns} {labels_columns} \n")
677680

678681
return num_samples
679682

@@ -728,32 +731,33 @@ def _make_detection_anns_folder(cls, root, name, *, file_name_fn, num_examples):
728731
def _make_detection_ann_file(cls, root, name):
729732
def add_child(parent, name, text=None):
730733
child = ET.SubElement(parent, name)
731-
child.text = text
734+
child.text = str(text)
732735
return child
733736

734737
def add_name(obj, name="dog"):
735738
add_child(obj, "name", name)
736-
return name
737739

738-
def add_bndbox(obj, bndbox=None):
739-
if bndbox is None:
740-
bndbox = {"xmin": "1", "xmax": "2", "ymin": "3", "ymax": "4"}
740+
def add_size(obj):
741+
obj = add_child(obj, "size")
742+
size = {"width": 0, "height": 0, "depth": 3}
743+
for name, text in size.items():
744+
add_child(obj, name, text)
741745

746+
def add_bndbox(obj):
742747
obj = add_child(obj, "bndbox")
748+
bndbox = {"xmin": 1, "xmax": 2, "ymin": 3, "ymax": 4}
743749
for name, text in bndbox.items():
744750
add_child(obj, name, text)
745751

746-
return bndbox
747-
748752
annotation = ET.Element("annotation")
753+
add_size(annotation)
749754
obj = add_child(annotation, "object")
750-
data = dict(name=add_name(obj), bndbox=add_bndbox(obj))
755+
add_name(obj)
756+
add_bndbox(obj)
751757

752758
with open(root / name, "wb") as fh:
753759
fh.write(ET.tostring(annotation))
754760

755-
return data
756-
757761
@classmethod
758762
def generate(cls, root, *, year, trainval):
759763
archive_folder = root

test/test_prototype_builtin_datasets.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,23 @@
1+
import functools
12
import io
23
from pathlib import Path
34

45
import pytest
56
import torch
67
from builtin_dataset_mocks import parametrize_dataset_mocks, DATASET_MOCKS
8+
from torch.testing._comparison import assert_equal, TensorLikePair, ObjectPair
79
from torch.utils.data.datapipes.iter.grouping import ShardingFilterIterDataPipe as ShardingFilter
810
from torch.utils.data.graph import traverse
911
from torchdata.datapipes.iter import IterDataPipe, Shuffler
1012
from torchvision.prototype import transforms, datasets
1113
from torchvision.prototype.utils._internal import sequence_to_str
1214

1315

16+
assert_samples_equal = functools.partial(
17+
assert_equal, pair_types=(TensorLikePair, ObjectPair), rtol=0, atol=0, equal_nan=True
18+
)
19+
20+
1421
@pytest.fixture
1522
def test_home(mocker, tmp_path):
1623
mocker.patch("torchvision.prototype.datasets._api.home", return_value=str(tmp_path))
@@ -92,6 +99,7 @@ def test_no_vanilla_tensors(self, test_home, dataset_mock, config):
9299
f"{sequence_to_str(sorted(vanilla_tensors), separate_last='and ')} contained vanilla tensors."
93100
)
94101

102+
@pytest.mark.xfail
95103
@parametrize_dataset_mocks(DATASET_MOCKS)
96104
def test_transformable(self, test_home, dataset_mock, config):
97105
dataset_mock.prepare(test_home, config)
@@ -137,6 +145,17 @@ def scan(graph):
137145
if not any(type(dp) is annotation_dp_type for dp in scan(traverse(dataset))):
138146
raise AssertionError(f"The dataset doesn't contain a {annotation_dp_type.__name__}() datapipe.")
139147

148+
@parametrize_dataset_mocks(DATASET_MOCKS)
149+
def test_save_load(self, test_home, dataset_mock, config):
150+
dataset_mock.prepare(test_home, config)
151+
dataset = datasets.load(dataset_mock.name, **config)
152+
sample = next(iter(dataset))
153+
154+
with io.BytesIO() as buffer:
155+
torch.save(sample, buffer)
156+
buffer.seek(0)
157+
assert_samples_equal(torch.load(buffer), sample)
158+
140159

141160
@parametrize_dataset_mocks(DATASET_MOCKS["qmnist"])
142161
class TestQMNIST:
@@ -171,5 +190,5 @@ def test_label_matches_path(self, test_home, dataset_mock, config):
171190
dataset = datasets.load(dataset_mock.name, **config)
172191

173192
for sample in dataset:
174-
label_from_path = int(Path(sample["image_path"]).parent.name)
193+
label_from_path = int(Path(sample["path"]).parent.name)
175194
assert sample["label"] == label_from_path

test/test_prototype_datasets_api.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
from torchvision.prototype.utils._internal import FrozenMapping, FrozenBunch
66

77

8-
def make_minimal_dataset_info(name="name", type=datasets.utils.DatasetType.RAW, categories=None, **kwargs):
9-
return datasets.utils.DatasetInfo(name, type=type, categories=categories or [], **kwargs)
8+
def make_minimal_dataset_info(name="name", categories=None, **kwargs):
9+
return datasets.utils.DatasetInfo(name, categories=categories or [], **kwargs)
1010

1111

1212
class TestFrozenMapping:
@@ -176,7 +176,7 @@ def resources(self, config):
176176
# This method is just defined to appease the ABC, but will be overwritten at instantiation
177177
pass
178178

179-
def _make_datapipe(self, resource_dps, *, config, decoder):
179+
def _make_datapipe(self, resource_dps, *, config):
180180
# This method is just defined to appease the ABC, but will be overwritten at instantiation
181181
pass
182182

@@ -229,12 +229,3 @@ def test_resources(self, mocker):
229229

230230
(call_args, _) = dataset._make_datapipe.call_args
231231
assert call_args[0][0] is sentinel
232-
233-
def test_decoder(self):
234-
dataset = self.DatasetMock()
235-
236-
sentinel = object()
237-
dataset.load("", decoder=sentinel)
238-
239-
(_, call_kwargs) = dataset._make_datapipe.call_args
240-
assert call_kwargs["decoder"] is sentinel

test/test_prototype_transforms.py

Lines changed: 0 additions & 61 deletions
This file was deleted.

torchvision/prototype/datasets/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
"Note that you cannot install it with `pip install torchdata`, since this is another package."
88
) from error
99

10-
from . import decoder, utils
10+
from . import utils
1111
from ._home import home
1212

1313
# Load this last, since some parts depend on the above being loaded first

torchvision/prototype/datasets/_api.py

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,9 @@
1-
import io
21
import os
3-
from typing import Any, Callable, Dict, List, Optional
2+
from typing import Any, Dict, List
43

5-
import torch
64
from torch.utils.data import IterDataPipe
75
from torchvision.prototype.datasets import home
8-
from torchvision.prototype.datasets.decoder import raw, pil
9-
from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, DatasetType
6+
from torchvision.prototype.datasets.utils import Dataset, DatasetInfo
107
from torchvision.prototype.utils._internal import add_suggestion
118

129
from . import _builtin
@@ -49,27 +46,15 @@ def info(name: str) -> DatasetInfo:
4946
return find(name).info
5047

5148

52-
DEFAULT_DECODER = object()
53-
54-
DEFAULT_DECODER_MAP: Dict[DatasetType, Callable[[io.IOBase], torch.Tensor]] = {
55-
DatasetType.RAW: raw,
56-
DatasetType.IMAGE: pil,
57-
}
58-
59-
6049
def load(
6150
name: str,
6251
*,
63-
decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = DEFAULT_DECODER, # type: ignore[assignment]
6452
skip_integrity_check: bool = False,
6553
**options: Any,
6654
) -> IterDataPipe[Dict[str, Any]]:
6755
dataset = find(name)
6856

69-
if decoder is DEFAULT_DECODER:
70-
decoder = DEFAULT_DECODER_MAP.get(dataset.info.type)
71-
7257
config = dataset.info.make_config(**options)
7358
root = os.path.join(home(), dataset.name)
7459

75-
return dataset.load(root, config=config, decoder=decoder, skip_integrity_check=skip_integrity_check)
60+
return dataset.load(root, config=config, skip_integrity_check=skip_integrity_check)

0 commit comments

Comments
 (0)