add CLEVR dataset #5130
@@ -0,0 +1,88 @@
import json
import pathlib
from typing import Any, Callable, Optional, Tuple, List
from urllib.parse import urlparse

from PIL import Image

from .utils import download_and_extract_archive, verify_str_arg
from .vision import VisionDataset


class CLEVRClassification(VisionDataset):
    """`CLEVR <https://cs.stanford.edu/people/jcjohns/clevr/>`_ classification dataset.

    The number of objects in a scene is used as the label.

    Args:
        root (string): Root directory of the dataset where directory ``root/clevr`` exists or will be saved to if
            download is set to True.
        split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
            version. E.g., ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and puts it in the root
            directory. If the dataset is already downloaded, it is not downloaded again.
    """

    _URL = "https://dl.fbaipublicfiles.com/clevr/CLEVR_v1.0.zip"
    _MD5 = "b11922020e72d0cd9154779b2d3d07d2"

    def __init__(
        self,
        root: str,
        split: str = "train",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = True,
    ) -> None:
        self._split = verify_str_arg(split, "split", ("train", "val", "test"))
        super().__init__(root, transform=transform, target_transform=target_transform)
        self._base_folder = pathlib.Path(self.root) / "clevr"
        self._data_folder = self._base_folder / pathlib.Path(urlparse(self._URL).path).stem

        if download:
            self._download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        self._image_files = sorted(self._data_folder.joinpath("images", self._split).glob("*"))

        self._labels: List[Optional[int]]
        if self._split != "test":
            with open(self._data_folder / "scenes" / f"CLEVR_{self._split}_scenes.json") as file:
                content = json.load(file)
            num_objects = {scene["image_filename"]: len(scene["objects"]) for scene in content["scenes"]}
            self._labels = [num_objects[image_file.name] for image_file in self._image_files]
        else:
            # The test split ships without scene annotations, so there are no labels.
            self._labels = [None] * len(self._image_files)

    def __len__(self) -> int:
        return len(self._image_files)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        image_file = self._image_files[idx]
        label = self._labels[idx]

        image = Image.open(image_file).convert("RGB")

        if self.transform:
            image = self.transform(image)

        if self.target_transform:
            label = self.target_transform(label)

        return image, label

    def _check_exists(self) -> bool:
        return self._data_folder.exists() and self._data_folder.is_dir()

    def _download(self) -> None:
        if self._check_exists():
            return

        download_and_extract_archive(self._URL, str(self._base_folder), md5=self._MD5)

    def extra_repr(self) -> str:
        return f"split={self._split}"
@@ -0,0 +1,110 @@
import functools
import io
import pathlib
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, IterKeyZipper, Demultiplexer, JsonParser, UnBatcher
from torchvision.prototype.datasets.utils import (
    Dataset,
    DatasetConfig,
    DatasetInfo,
    HttpResource,
    OnlineResource,
    DatasetType,
)
from torchvision.prototype.datasets.utils._internal import (
    INFINITE_BUFFER_SIZE,
    hint_sharding,
    hint_shuffling,
    path_comparator,
    path_accessor,
    getitem,
)
from torchvision.prototype.features import Label


class CLEVR(Dataset):
    def _make_info(self) -> DatasetInfo:
        return DatasetInfo(
            "clevr",
            type=DatasetType.IMAGE,
            homepage="https://cs.stanford.edu/people/jcjohns/clevr/",
            valid_options=dict(split=("train", "val", "test")),
        )

    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
        archive = HttpResource(
            "https://dl.fbaipublicfiles.com/clevr/CLEVR_v1.0.zip",
            sha256="5cd61cf1096ed20944df93c9adb31e74d189b8459a94f54ba00090e5c59936d1",
        )
        return [archive]

    def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
        # Route archive entries by path: 0 -> image files, 1 -> scene annotations, None -> dropped.
        path = pathlib.Path(data[0])
        if path.parents[1].name == "images":
            return 0
        elif path.parent.name == "scenes":
            return 1
        else:
            return None

    def _filter_scene_anns(self, data: Tuple[str, Any]) -> bool:
        key, _ = data
        return key == "scenes"

    def _add_empty_anns(self, data: Tuple[str, io.IOBase]) -> Tuple[Tuple[str, io.IOBase], None]:
        return data, None

    def _collate_and_decode_sample(
        self,
        data: Tuple[Tuple[str, io.IOBase], Optional[Dict[str, Any]]],
        *,
        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
    ) -> Dict[str, Any]:
        image_data, scenes_data = data
        path, buffer = image_data

        return dict(
            path=path,
            image=decoder(buffer) if decoder else buffer,
            label=Label(len(scenes_data["objects"])) if scenes_data else None,
        )

    def _make_datapipe(
        self,
        resource_dps: List[IterDataPipe],
        *,
        config: DatasetConfig,
        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
    ) -> IterDataPipe[Dict[str, Any]]:
        archive_dp = resource_dps[0]
        # Split the single archive stream into an image branch and a scene-annotation branch.
        images_dp, scenes_dp = Demultiplexer(
            archive_dp,
            2,
            self._classify_archive,
            drop_none=True,
            buffer_size=INFINITE_BUFFER_SIZE,
        )

        images_dp = Filter(images_dp, path_comparator("parent.name", config.split))
        images_dp = hint_sharding(images_dp)
        images_dp = hint_shuffling(images_dp)

        if config.split != "test":
            scenes_dp = Filter(scenes_dp, path_comparator("name", f"CLEVR_{config.split}_scenes.json"))
            scenes_dp = JsonParser(scenes_dp)
            scenes_dp = Mapper(scenes_dp, getitem(1, "scenes"))
            scenes_dp = UnBatcher(scenes_dp)

            dp = IterKeyZipper(
                images_dp,
                scenes_dp,
                key_fn=path_accessor("name"),
                ref_key_fn=getitem("image_filename"),
                buffer_size=INFINITE_BUFFER_SIZE,
            )
        else:
            # The test split has no scene annotations, so every image is paired with None.
            dp = Mapper(images_dp, self._add_empty_anns)

        return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
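To make the Demultiplexer routing above concrete, here is a standalone sketch of what _classify_archive does with individual archive entries. The paths are illustrative and assume the CLEVR_v1.0 archive layout implied by the filters above:

import pathlib

def classify(path_str):
    # Mirrors _classify_archive: 0 -> images branch, 1 -> scenes branch, None -> dropped (drop_none=True).
    path = pathlib.Path(path_str)
    if path.parents[1].name == "images":
        return 0
    elif path.parent.name == "scenes":
        return 1
    return None

print(classify("CLEVR_v1.0/images/train/CLEVR_train_000000.png"))  # 0
print(classify("CLEVR_v1.0/scenes/CLEVR_train_scenes.json"))       # 1
print(classify("CLEVR_v1.0/README.txt"))                           # None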
@@ -108,7 +108,7 @@ def __iter__(self) -> Iterator[Tuple[int, D]]:
         yield from enumerate(self.datapipe, self.start)


-def _getitem_closure(obj: Any, *, items: Tuple[Any, ...]) -> Any:
+def _getitem_closure(obj: Any, *, items: Sequence[Any]) -> Any:
     for item in items:
         obj = obj[item]
     return obj
@@ -118,8 +118,14 @@ def getitem(*items: Any) -> Callable[[Any], Any]:
     return functools.partial(_getitem_closure, items=items)


+def _getattr_closure(obj: Any, *, attrs: Sequence[str]) -> Any:
+    for attr in attrs:
+        obj = getattr(obj, attr)
+    return obj
Comment on lines +121 to +124

Reviewer: Could you help me understand why this change is needed?

Author: This enables chained attribute access such as

    images_dp = Filter(images_dp, path_comparator("parent.name", config.split))

rather than writing a custom function that extracts the name of the parent folder.

Reviewer: I'll leave it up to you to decide, as the entire logic of […]. But as a rule of thumb, I tend to avoid changes to helper functions when they only address one single use-case.

Author: The complexity stems from the fact that we cannot use lambdas or local functions together with the datapipes, since they cannot be serialized safely. Quick explanation: […]
It is true that this currently only addresses this, but I already needed to make the same changes to […]
 def _path_attribute_accessor(path: pathlib.Path, *, name: str) -> D:
-    return cast(D, getattr(path, name))
+    return cast(D, _getattr_closure(path, attrs=name.split(".")))


 def _path_accessor_closure(data: Tuple[str, Any], *, getter: Callable[[pathlib.Path], D]) -> D:
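As a rough illustration of the serialization constraint mentioned in the discussion above (all names here are illustrative, not part of the PR): callables attached to datapipes must be picklable so they can be shipped to worker processes, which rules out lambdas and locally defined functions. A module-level closure bound with functools.partial, the pattern that getitem and path_comparator follow, serializes fine:

import functools
import pickle

# A module-level closure in the style of _getattr_closure: picklable because it can be
# looked up by its qualified name when the partial object is serialized.
def _chained_getattr(obj, *, attrs):
    for attr in attrs:
        obj = getattr(obj, attr)
    return obj

fn = functools.partial(_chained_getattr, attrs=["parent", "name"])
pickle.dumps(fn)  # works

try:
    pickle.dumps(lambda p: p.parent.name)  # lambdas cannot be pickled
except (pickle.PicklingError, AttributeError) as exc:
    print(type(exc).__name__)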