This repository was archived by the owner on Jul 1, 2024. It is now read-only.

Move Dataloader Wrappers to OSS #455

Closed
10 changes: 10 additions & 0 deletions classy_vision/dataset/__init__.py
@@ -72,20 +72,30 @@ def register_dataset_cls(cls):
from .classy_imagenet import ImageNetDataset # isort:skip
from .classy_kinetics400 import Kinetics400Dataset # isort:skip
from .classy_synthetic_image import SyntheticImageDataset # isort:skip
from .classy_synthetic_image_streaming import ( # isort:skip
SyntheticImageStreamingDataset, # isort:skip
) # isort:skip
from .classy_synthetic_video import SyntheticVideoDataset # isort:skip
from .classy_ucf101 import UCF101Dataset # isort:skip
from .classy_video_dataset import ClassyVideoDataset # isort:skip
from .dataloader_limit_wrapper import DataloaderLimitWrapper # isort:skip
from .dataloader_skip_none_wrapper import DataloaderSkipNoneWrapper # isort:skip
from .dataloader_wrapper import DataloaderWrapper # isort:skip
from .image_path_dataset import ImagePathDataset # isort:skip

__all__ = [
"CIFARDataset",
"ClassyDataset",
"ClassyVideoDataset",
"DataloaderLimitWrapper",
"DataloaderSkipNoneWrapper",
"DataloaderWrapper",
"HMDB51Dataset",
"ImageNetDataset",
"ImagePathDataset",
"Kinetics400Dataset",
"SyntheticImageDataset",
"SyntheticImageStreamingDataset",
"SyntheticVideoDataset",
"UCF101Dataset",
"build_dataset",
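With these exports in place, the wrappers can be imported straight from the package root. A minimal sketch of how the two concrete wrappers compose; `my_loader` is a hypothetical stand-in for any iterable dataloader:

from classy_vision.dataset import DataloaderLimitWrapper, DataloaderSkipNoneWrapper

# drop None batches, then cap the stream at 10 batches
loader = DataloaderLimitWrapper(DataloaderSkipNoneWrapper(my_loader), limit=10)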
88 changes: 88 additions & 0 deletions classy_vision/dataset/classy_synthetic_image_streaming.py
@@ -0,0 +1,88 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import torchvision.transforms as transforms
from classy_vision.dataset import register_dataset
from classy_vision.dataset.classy_dataset import ClassyDataset
from classy_vision.dataset.core import RandomImageBinaryClassDataset
from classy_vision.dataset.dataloader_limit_wrapper import DataloaderLimitWrapper
from classy_vision.dataset.transforms.util import (
ImagenetConstants,
build_field_transform_default_imagenet,
)


@register_dataset("synthetic_image_streaming")
class SyntheticImageStreamingDataset(ClassyDataset):
"""
Synthetic image dataset that behaves like a streaming dataset.

    Requires a "num_samples" argument, which determines the number of samples
    returned per phase. Also takes an optional "length" argument, which sets the
    length of the underlying dataset and defaults to "num_samples".
"""

def __init__(
self,
batchsize_per_replica,
shuffle,
transform,
num_samples,
crop_size,
class_ratio,
seed,
length=None,
):
if length is None:
# If length not provided, set to be same as num_samples
length = num_samples

dataset = RandomImageBinaryClassDataset(crop_size, class_ratio, length, seed)
super().__init__(
dataset, batchsize_per_replica, shuffle, transform, num_samples
)

@classmethod
def from_config(cls, config):
assert all(key in config for key in ["crop_size", "class_ratio", "seed"])
length = config.get("length")
crop_size = config["crop_size"]
class_ratio = config["class_ratio"]
seed = config["seed"]
(
transform_config,
batchsize_per_replica,
shuffle,
num_samples,
) = cls.parse_config(config)
default_transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize(
mean=ImagenetConstants.MEAN, std=ImagenetConstants.STD
),
]
)
transform = build_field_transform_default_imagenet(
transform_config, default_transform=default_transform
)
return cls(
batchsize_per_replica,
shuffle,
transform,
num_samples,
crop_size,
class_ratio,
seed,
length=length,
)

def iterator(self, *args, **kwargs):
return DataloaderLimitWrapper(
super().iterator(*args, **kwargs),
self.num_samples // self.get_global_batchsize(),
)
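
The `iterator` override above caps the stream at `num_samples // global_batchsize` batches instead of one pass over the underlying dataset. A rough worked example, assuming a single replica (the numbers match the test added below):

num_samples = 2000
batchsize_per_replica = 32
num_replicas = 1  # assumption: single-process run

global_batchsize = batchsize_per_replica * num_replicas
expected_batches = num_samples // global_batchsize  # 2000 // 32 == 62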
77 changes: 77 additions & 0 deletions classy_vision/dataset/dataloader_limit_wrapper.py
@@ -0,0 +1,77 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import Any, Iterable, Iterator

from .dataloader_wrapper import DataloaderWrapper


class DataloaderLimitWrapper(DataloaderWrapper):
"""
Dataloader which wraps another dataloader and only returns a limited
number of items.

    This is useful for iterable datasets where the length of the dataset isn't
    known in advance. Such datasets can wrap the iterators they return with this
    class. See :func:`SyntheticImageStreamingDataset.iterator` for an example.

Attribute accesses are passed to the wrapped dataloader.
"""

def __init__(
self, dataloader: Iterable, limit: int, wrap_around: bool = True
) -> None:
"""Constructor for DataloaderLimitWrapper.

Args:
dataloader: The dataloader to wrap around
            limit: The maximum number of calls to the underlying dataloader. The
                wrapper raises a `StopIteration` after `limit` calls.
            wrap_around: Whether to wrap around the original dataloader if it is
                exhausted before `limit` calls.
Raises:
RuntimeError: If `wrap_around` is set to `False` and the underlying
dataloader is exhausted before `limit` calls.
"""
super().__init__(dataloader)
# we use self.__dict__ to set the attributes since the __setattr__ method
# is overridden
attributes = {"limit": limit, "wrap_around": wrap_around, "_count": None}
self.__dict__.update(attributes)

def __iter__(self) -> Iterator[Any]:
self._iter = iter(self.dataloader)
self._count = 0
return self

def __next__(self) -> Any:
if self._count >= self.limit:
raise StopIteration
self._count += 1
try:
return next(self._iter)
except StopIteration:
if self.wrap_around:
# create a new iterator to load data from the beginning
logging.info(
f"Wrapping around after {self._count} calls. Limit: {self.limit}"
)
try:
self._iter = iter(self.dataloader)
return next(self._iter)
except StopIteration:
raise RuntimeError(
"Looks like the dataset is empty, "
"have you configured it properly?"
)
else:
raise RuntimeError(
f"StopIteration raised before {self.limit} items were returned"
)

def __len__(self) -> int:
return self.limit
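
A minimal usage sketch, with a plain list standing in for a real dataloader (illustration only, not part of the change):

batches = [{"input": i} for i in range(5)]

# ask for more batches than the list holds; wrap_around=True restarts the iterator
wrapped = DataloaderLimitWrapper(batches, limit=8)
assert len(wrapped) == 8
assert sum(1 for _ in wrapped) == 8  # yields items 0-4, then 0-2 again

# with wrap_around=False the same setup raises RuntimeError on the sixth call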
33 changes: 33 additions & 0 deletions classy_vision/dataset/dataloader_skip_none_wrapper.py
@@ -0,0 +1,33 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Iterable, Iterator

from .dataloader_wrapper import DataloaderWrapper


class DataloaderSkipNoneWrapper(DataloaderWrapper):
"""
    Dataloader which wraps another dataloader and skips `None` batches.

Attribute accesses are passed to the wrapped dataloader.
"""

def __init__(self, dataloader: Iterable) -> None:
super().__init__(dataloader)

def __iter__(self) -> Iterator[Any]:
self._iter = iter(self.dataloader)
return self

def __next__(self) -> Any:
        # We may get a `None` batch when all the images / videos in the batch
        # are corrupted. In that case, keep fetching the next batch until a
        # valid one is found.
next_batch = None
while next_batch is None:
next_batch = next(self._iter)
return next_batch
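
A minimal sketch of the skip behaviour, again with a plain list as a stand-in dataloader:

batches = [{"input": 0}, None, {"input": 1}, None, {"input": 2}]

skipped = list(DataloaderSkipNoneWrapper(batches))
assert len(skipped) == 3  # the two None batches are silently dropped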
47 changes: 47 additions & 0 deletions classy_vision/dataset/dataloader_wrapper.py
@@ -0,0 +1,47 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from typing import Any, Iterable, Iterator


class DataloaderWrapper(ABC):
"""
Abstract class representing dataloader which wraps another dataloader.

Attribute accesses are passed to the wrapped dataloader.
"""

def __init__(self, dataloader: Iterable) -> None:
# we use self.__dict__ to set the attributes since the __setattr__ method
# is overridden
attributes = {"dataloader": dataloader, "_iter": None}
self.__dict__.update(attributes)

@abstractmethod
def __iter__(self) -> Iterator[Any]:
pass

@abstractmethod
def __next__(self) -> Any:
pass

def __getattr__(self, attr) -> Any:
"""
Pass the getattr call to the wrapped dataloader
"""
if attr in self.__dict__:
return self.__dict__[attr]
return getattr(self.dataloader, attr)

def __setattr__(self, attr, value) -> None:
"""
Pass the setattr call to the wrapped dataloader
"""
if attr in self.__dict__:
self.__dict__[attr] = value
else:
setattr(self.dataloader, attr, value)
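
For illustration, a hypothetical pass-through subclass; note how attribute reads and writes not defined on the wrapper fall through to the wrapped dataloader:

class PassthroughWrapper(DataloaderWrapper):
    """Forwards iteration unchanged; shown only to illustrate the base class."""

    def __iter__(self):
        self._iter = iter(self.dataloader)  # `_iter` already lives in __dict__
        return self

    def __next__(self):
        return next(self._iter)

# wrapper = PassthroughWrapper(torch_dataloader)
# wrapper.batch_size         # read is forwarded to torch_dataloader.batch_size
# wrapper.some_flag = True   # write is forwarded as well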
53 changes: 53 additions & 0 deletions test/dataset_dataloader_limit_wrapper_test.py
@@ -0,0 +1,53 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from test.generic.config_utils import get_test_task_config

from classy_vision.tasks import build_task


class TestDataloaderLimitWrapper(unittest.TestCase):
def _test_number_of_batches(self, data_iterator, expected_batches):
num_batches = 0
for _ in data_iterator:
num_batches += 1
self.assertEqual(num_batches, expected_batches)

def test_streaming_dataset(self):
"""
Test that streaming datasets return the correct number of batches, and that
the length is also calculated correctly.
"""
config = get_test_task_config()
dataset_config = {
"name": "synthetic_image_streaming",
"split": "train",
"crop_size": 224,
"class_ratio": 0.5,
"num_samples": 2000,
"length": 4000,
"seed": 0,
"batchsize_per_replica": 32,
"use_shuffle": True,
}
expected_batches = 62
config["dataset"]["train"] = dataset_config
task = build_task(config)
task.prepare()
task.advance_phase()
# test that the number of batches expected is correct
self.assertEqual(task.num_batches_per_phase, expected_batches)

# test that the data iterator returns the expected number of batches
data_iterator = task.get_data_iterator()
self._test_number_of_batches(data_iterator, expected_batches)

# test that the dataloader can be rebuilt from the dataset inside it
task._recreate_data_loader_from_dataset()
task.create_data_iterator()
data_iterator = task.get_data_iterator()
self._test_number_of_batches(data_iterator, expected_batches)
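
The new test should be runnable with the stock unittest runner from the repository root, e.g. (invocation assumed, not part of the diff):

python -m unittest test.dataset_dataloader_limit_wrapper_test -v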