add CLEVR dataset #5130


Merged (9 commits) on Jan 6, 2022
Changes from all commits
32 changes: 32 additions & 0 deletions test/test_datasets.py
@@ -2325,5 +2325,37 @@ def inject_fake_data(self, tmpdir: str, config):
return total_number_of_examples


class CLEVRClassificationTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.CLEVRClassification
FEATURE_TYPES = (PIL.Image.Image, (int, type(None)))

ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "val", "test"))

def inject_fake_data(self, tmpdir, config):
data_folder = pathlib.Path(tmpdir) / "clevr" / "CLEVR_v1.0"

images_folder = data_folder / "images"
image_files = datasets_utils.create_image_folder(
images_folder, config["split"], lambda idx: f"CLEVR_{config['split']}_{idx:06d}.png", num_examples=5
)

scenes_folder = data_folder / "scenes"
scenes_folder.mkdir()
if config["split"] != "test":
with open(scenes_folder / f"CLEVR_{config['split']}_scenes.json", "w") as file:
json.dump(
dict(
info=dict(),
scenes=[
dict(image_filename=image_file.name, objects=[dict()] * int(torch.randint(10, ())))
for image_file in image_files
],
),
file,
)

return len(image_files)


if __name__ == "__main__":
unittest.main()
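
For reference, the fake scenes file written by inject_fake_data above has roughly the following shape; the filenames and object counts here are illustrative (the test draws a random count between 0 and 9 per image):

    # Illustrative structure of the generated CLEVR_<split>_scenes.json
    fake_scenes = {
        "info": {},
        "scenes": [
            {"image_filename": "CLEVR_train_000000.png", "objects": [{}, {}, {}]},
            {"image_filename": "CLEVR_train_000001.png", "objects": [{}]},
        ],
    }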
2 changes: 2 additions & 0 deletions torchvision/datasets/__init__.py
@@ -3,6 +3,7 @@
from .celeba import CelebA
from .cifar import CIFAR10, CIFAR100
from .cityscapes import Cityscapes
from .clevr import CLEVRClassification
from .coco import CocoCaptions, CocoDetection
from .dtd import DTD
from .fakedata import FakeData
@@ -85,4 +86,5 @@
"DTD",
"FER2013",
"GTSRB",
"CLEVRClassification",
)
88 changes: 88 additions & 0 deletions torchvision/datasets/clevr.py
@@ -0,0 +1,88 @@
import json
import pathlib
from typing import Any, Callable, Optional, Tuple, List
from urllib.parse import urlparse

from PIL import Image

from .utils import download_and_extract_archive, verify_str_arg
from .vision import VisionDataset


class CLEVRClassification(VisionDataset):
"""`CLEVR <https://cs.stanford.edu/people/jcjohns/clevr/>`_ classification dataset.

    The number of objects in a scene is used as the label.

Args:
root (string): Root directory of dataset where directory ``root/clevr`` exists or will be saved to if download is
set to True.
split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
            version. E.g., ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and puts it in the root
            directory. If the dataset is already downloaded, it is not downloaded again.
"""

_URL = "https://dl.fbaipublicfiles.com/clevr/CLEVR_v1.0.zip"
_MD5 = "b11922020e72d0cd9154779b2d3d07d2"

def __init__(
self,
root: str,
split: str = "train",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = True,
) -> None:
self._split = verify_str_arg(split, "split", ("train", "val", "test"))
super().__init__(root, transform=transform, target_transform=target_transform)
self._base_folder = pathlib.Path(self.root) / "clevr"
self._data_folder = self._base_folder / pathlib.Path(urlparse(self._URL).path).stem

if download:
self._download()

if not self._check_exists():
raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

self._image_files = sorted(self._data_folder.joinpath("images", self._split).glob("*"))

self._labels: List[Optional[int]]
if self._split != "test":
with open(self._data_folder / "scenes" / f"CLEVR_{self._split}_scenes.json") as file:
content = json.load(file)
num_objects = {scene["image_filename"]: len(scene["objects"]) for scene in content["scenes"]}
self._labels = [num_objects[image_file.name] for image_file in self._image_files]
else:
self._labels = [None] * len(self._image_files)

def __len__(self) -> int:
return len(self._image_files)

def __getitem__(self, idx: int) -> Tuple[Any, Any]:
image_file = self._image_files[idx]
label = self._labels[idx]

image = Image.open(image_file).convert("RGB")

if self.transform:
image = self.transform(image)

if self.target_transform:
label = self.target_transform(label)

return image, label

def _check_exists(self) -> bool:
return self._data_folder.exists() and self._data_folder.is_dir()

def _download(self) -> None:
if self._check_exists():
return

download_and_extract_archive(self._URL, str(self._base_folder), md5=self._MD5)

def extra_repr(self) -> str:
return f"split={self._split}"
1 change: 1 addition & 0 deletions torchvision/prototype/datasets/_builtin/__init__.py
@@ -1,6 +1,7 @@
from .caltech import Caltech101, Caltech256
from .celeba import CelebA
from .cifar import Cifar10, Cifar100
from .clevr import CLEVR
from .coco import Coco
from .dtd import DTD
from .fer2013 import FER2013
110 changes: 110 additions & 0 deletions torchvision/prototype/datasets/_builtin/clevr.py
@@ -0,0 +1,110 @@
import functools
import io
import pathlib
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, IterKeyZipper, Demultiplexer, JsonParser, UnBatcher
from torchvision.prototype.datasets.utils import (
Dataset,
DatasetConfig,
DatasetInfo,
HttpResource,
OnlineResource,
DatasetType,
)
from torchvision.prototype.datasets.utils._internal import (
INFINITE_BUFFER_SIZE,
hint_sharding,
hint_shuffling,
path_comparator,
path_accessor,
getitem,
)
from torchvision.prototype.features import Label


class CLEVR(Dataset):
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"clevr",
type=DatasetType.IMAGE,
homepage="https://cs.stanford.edu/people/jcjohns/clevr/",
valid_options=dict(split=("train", "val", "test")),
)

def resources(self, config: DatasetConfig) -> List[OnlineResource]:
archive = HttpResource(
"https://dl.fbaipublicfiles.com/clevr/CLEVR_v1.0.zip",
sha256="5cd61cf1096ed20944df93c9adb31e74d189b8459a94f54ba00090e5c59936d1",
)
return [archive]

def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
path = pathlib.Path(data[0])
if path.parents[1].name == "images":
return 0
elif path.parent.name == "scenes":
return 1
else:
return None

def _filter_scene_anns(self, data: Tuple[str, Any]) -> bool:
key, _ = data
return key == "scenes"

def _add_empty_anns(self, data: Tuple[str, io.IOBase]) -> Tuple[Tuple[str, io.IOBase], None]:
return data, None

def _collate_and_decode_sample(
self,
data: Tuple[Tuple[str, io.IOBase], Optional[Dict[str, Any]]],
*,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> Dict[str, Any]:
image_data, scenes_data = data
path, buffer = image_data

return dict(
path=path,
image=decoder(buffer) if decoder else buffer,
label=Label(len(scenes_data["objects"])) if scenes_data else None,
)

def _make_datapipe(
self,
resource_dps: List[IterDataPipe],
*,
config: DatasetConfig,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> IterDataPipe[Dict[str, Any]]:
archive_dp = resource_dps[0]
images_dp, scenes_dp = Demultiplexer(
archive_dp,
2,
self._classify_archive,
drop_none=True,
buffer_size=INFINITE_BUFFER_SIZE,
)

images_dp = Filter(images_dp, path_comparator("parent.name", config.split))
images_dp = hint_sharding(images_dp)
images_dp = hint_shuffling(images_dp)

if config.split != "test":
scenes_dp = Filter(scenes_dp, path_comparator("name", f"CLEVR_{config.split}_scenes.json"))
scenes_dp = JsonParser(scenes_dp)
scenes_dp = Mapper(scenes_dp, getitem(1, "scenes"))
scenes_dp = UnBatcher(scenes_dp)

dp = IterKeyZipper(
images_dp,
scenes_dp,
key_fn=path_accessor("name"),
ref_key_fn=getitem("image_filename"),
buffer_size=INFINITE_BUFFER_SIZE,
)
else:
dp = Mapper(images_dp, self._add_empty_anns)

return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
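
For orientation, a sketch of how this builtin could be exercised end to end; the load entry point and the exact sample keys are assumptions based on the prototype datasets API at the time:

    from torchvision.prototype import datasets

    # Assumed entry point that resolves the "clevr" builtin and builds the
    # datapipe returned by _make_datapipe above.
    dp = datasets.load("clevr", split="train")
    for sample in dp:
        # Each sample is the dict from _collate_and_decode_sample: path, image, label.
        print(sample["path"], sample["label"])
        break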
10 changes: 8 additions & 2 deletions torchvision/prototype/datasets/utils/_internal.py
@@ -108,7 +108,7 @@ def __iter__(self) -> Iterator[Tuple[int, D]]:
yield from enumerate(self.datapipe, self.start)


-def _getitem_closure(obj: Any, *, items: Tuple[Any, ...]) -> Any:
+def _getitem_closure(obj: Any, *, items: Sequence[Any]) -> Any:
for item in items:
obj = obj[item]
return obj
@@ -118,8 +118,14 @@ def getitem(*items: Any) -> Callable[[Any], Any]:
return functools.partial(_getitem_closure, items=items)


def _getattr_closure(obj: Any, *, attrs: Sequence[str]) -> Any:
for attr in attrs:
obj = getattr(obj, attr)
return obj
Comment on lines +121 to +124
Member:

Could you help me understand why this change is needed?

Collaborator Author:

This enables chained getattr calls. In turn, this allows me to use this filter function

images_dp = Filter(images_dp, path_comparator("parent.name", config.split))

rather than writing a custom function that extracts the name of the parent folder.

Member:

I'll leave it up to you to decide, as the entire logic of path_comparator and path_accessor is honestly too complex for me to handle ATM.

But as a rule of thumb, I tend to avoid changes to helper functions when they only address a single use case.

Collaborator Author (@pmeier, Jan 5, 2022):

> as the entire logic of path_comparator and path_accessor is honestly too complex for me to handle ATM.

The complexity stems from the fact that we cannot use lambdas or local functions together with the datapipes, since they cannot be serialized safely. Quick explanation:

  • path_accessor: When using datapipes, you often have path-handle tuples for files. path_accessor takes such a tuple, turns the first item into a pathlib.Path, and accesses it according to the input. So path_accessor("parent.name") is equivalent to

    def extract_folder_name(data):
        path = pathlib.Path(data[0])
        return path.parent.name

    This is useful when you want to merge two datapipes and need a function that generates merge keys.

  • path_comparator: This is one layer on top of path_accessor, providing an equality comparison for the extracted path information. For example, Filter(scenes_dp, path_comparator("name", f"CLEVR_{config.split}_scenes.json")) will select all files whose name matches the second argument. (This is what I used to refactor your suggestion about the regex.)

> But as a rule of thumb, I tend to avoid changes to helper functions when they only address a single use case.

It is true that this currently only addresses a single use case, but I already needed to make the same change to getitem a while back. So I'm guessing that if I don't do it now, I'll have to do it in the future anyway, and then I'd also have to remember to come back here and fix this.
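
To make the chained lookup concrete, a small self-contained check; the path below is illustrative:

    import pathlib
    from typing import Any, Sequence

    def _getattr_closure(obj: Any, *, attrs: Sequence[str]) -> Any:
        # Walk a chain of attribute names, e.g. ("parent", "name").
        for attr in attrs:
            obj = getattr(obj, attr)
        return obj

    path = pathlib.Path("CLEVR_v1.0/images/train/CLEVR_train_000000.png")
    # Equivalent to path.parent.name, which is what
    # path_comparator("parent.name", config.split) compares against.
    assert _getattr_closure(path, attrs="parent.name".split(".")) == "train"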



 def _path_attribute_accessor(path: pathlib.Path, *, name: str) -> D:
-    return cast(D, getattr(path, name))
+    return cast(D, _getattr_closure(path, attrs=name.split(".")))


def _path_accessor_closure(data: Tuple[str, Any], *, getter: Callable[[pathlib.Path], D]) -> D: