Skip to content

Commit

Permalink
[Datasets] Add ImageFolderDatasource (ray-project#24641)
Browse files Browse the repository at this point in the history
Co-authored-by: matthewdeng <matthew.j.deng@gmail.com>
Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
Signed-off-by: Xiaowei Jiang <xwjiang2010@gmail.com>
  • Loading branch information
3 people authored and xwjiang2010 committed Jul 19, 2022
1 parent ae67ac5 commit 2005c47
Show file tree
Hide file tree
Showing 9 changed files with 205 additions and 0 deletions.
2 changes: 2 additions & 0 deletions doc/source/data/package-ref.rst
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,8 @@ Built-in Datasources
.. autoclass:: ray.data.datasource.FileBasedDatasource
:members:

.. autoclass:: ray.data.datasource.ImageFolderDatasource

.. autoclass:: ray.data.datasource.JSONDatasource
:members:

Expand Down
2 changes: 2 additions & 0 deletions python/ray/data/datasource/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
FileMetadataProvider,
ParquetMetadataProvider,
)
from ray.data.datasource.image_folder_datasource import ImageFolderDatasource
from ray.data.datasource.json_datasource import JSONDatasource
from ray.data.datasource.numpy_datasource import NumpyDatasource
from ray.data.datasource.parquet_base_datasource import ParquetBaseDatasource
Expand Down Expand Up @@ -52,6 +53,7 @@
"FileBasedDatasource",
"FileExtensionFilter",
"FileMetadataProvider",
"ImageFolderDatasource",
"JSONDatasource",
"NumpyDatasource",
"ParquetBaseDatasource",
Expand Down
144 changes: 144 additions & 0 deletions python/ray/data/datasource/image_folder_datasource.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
import pathlib
from typing import TYPE_CHECKING, List, Optional, Union

import numpy as np
from ray.data.datasource.binary_datasource import BinaryDatasource
from ray.data.datasource.datasource import Reader
from ray.data.datasource.file_based_datasource import (
_resolve_paths_and_filesystem,
FileExtensionFilter,
)
from ray.data.datasource.partitioning import PathPartitionFilter
from ray.util.annotations import DeveloperAPI

if TYPE_CHECKING:
import pyarrow
from ray.data.block import T

# File extensions (written without the leading dot) handed to `FileExtensionFilter`
# when the caller of `ImageFolderDatasource.create_reader` supplies no
# `partition_filter` of their own.
IMAGE_EXTENSIONS = ["png", "jpg", "jpeg", "tiff", "bmp", "gif"]


@DeveloperAPI
class ImageFolderDatasource(BinaryDatasource):
    """A datasource that lets you read datasets like `ImageNet <https://www.image-net.org/>`_.

    This datasource works with any dataset where images are arranged in this way:

    .. code-block::

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/[...]/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/[...]/asd932_.png

    Datasets read with this datasource contain two columns: ``'image'`` and ``'label'``.

    * The ``'image'`` column is of type
      :py:class:`~ray.air.util.tensor_extensions.pandas.TensorDtype` and contains
      tensors of shape :math:`(H, W, C)`.
    * The ``'label'`` column contains strings representing class names (e.g., 'cat').

    Examples:
        >>> import ray
        >>> from ray.data.datasource import ImageFolderDatasource
        >>>
        >>> ds = ray.data.read_datasource(  # doctest: +SKIP
        ...     ImageFolderDatasource(),
        ...     paths=["/data/imagenet/train"]
        ... )
        >>>
        >>> sample = ds.take(1)[0]  # doctest: +SKIP
        >>> sample["image"].to_numpy().shape  # doctest: +SKIP
        (469, 387, 3)
        >>> sample["label"]  # doctest: +SKIP
        'n01443537'

        To convert class labels to integer-valued targets, use
        :py:class:`~ray.data.preprocessors.OrdinalEncoder`.

        >>> import ray
        >>> from ray.data.preprocessors import OrdinalEncoder
        >>>
        >>> ds = ray.data.read_datasource(  # doctest: +SKIP
        ...     ImageFolderDatasource(),
        ...     paths=["/data/imagenet/train"]
        ... )
        >>> oe = OrdinalEncoder(columns=["label"])  # doctest: +SKIP
        >>>
        >>> ds = oe.fit_transform(ds)  # doctest: +SKIP
        >>>
        >>> sample = ds.take(1)[0]  # doctest: +SKIP
        >>> sample["label"]  # doctest: +SKIP
        71
    """  # noqa: E501

    def create_reader(
        self,
        paths: Union[str, List[str]],
        filesystem: Optional["pyarrow.fs.FileSystem"] = None,
        partition_filter: Optional[PathPartitionFilter] = None,
        **kwargs,
    ) -> "Reader[T]":
        """Validate the dataset root and build the reader for it.

        Args:
            paths: The dataset root, given either as a single string or as a
                one-element list of strings.
            filesystem: Optional filesystem used to resolve ``paths``.
            partition_filter: Filter selecting which files are read. When omitted,
                a ``FileExtensionFilter`` over ``IMAGE_EXTENSIONS`` is used.
            **kwargs: Forwarded unchanged to the parent ``create_reader``.

        Returns:
            The reader produced by the parent class for the resolved root.

        Raises:
            ValueError: If ``paths`` does not describe exactly one root.
            ImportError: If ``imageio`` cannot be imported.
        """
        # The annotation allows a bare string. Without this normalization,
        # `len(paths)` would count the characters of the string and a perfectly
        # valid root would be rejected with a misleading "got N paths" error.
        if isinstance(paths, str):
            paths = [paths]

        if len(paths) != 1:
            raise ValueError(
                "`ImageFolderDatasource` expects 1 path representing the dataset "
                f"root, but it got {len(paths)} paths instead. To fix this "
                "error, pass in a single-element list containing the dataset root "
                '(for example, `paths=["s3://imagenet/train"]`)'
            )

        # Fail early, at reader-creation time, rather than inside `_read_file`.
        try:
            import imageio  # noqa: F401
        except ImportError:
            raise ImportError(
                "`ImageFolderDatasource` depends on 'imageio', but 'imageio' couldn't "
                "be imported. You can install 'imageio' by running "
                "`pip install imageio`."
            )

        if partition_filter is None:
            partition_filter = FileExtensionFilter(file_extensions=IMAGE_EXTENSIONS)

        # We call `_resolve_paths_and_filesystem` so that the dataset root is formatted
        # in the same way as the paths passed to `_get_class_from_path`.
        paths, filesystem = _resolve_paths_and_filesystem(paths, filesystem)
        # Remembered on the instance because `_read_file` needs the root to derive
        # each image's label.
        self.root = paths[0]

        return super().create_reader(
            paths=paths,
            filesystem=filesystem,
            partition_filter=partition_filter,
            **kwargs,
        )

    def _read_file(self, f: "pyarrow.NativeFile", path: str, **reader_args):
        """Decode one image file into a single-row DataFrame.

        Args:
            f: Open handle of the file to read.
            path: Path of the file; the label is derived from it.
            **reader_args: Accepted for interface compatibility; not used here.

        Returns:
            A ``pandas.DataFrame`` with one row and the columns ``"image"`` (the
            decoded pixels wrapped in a ``TensorArray``) and ``"label"`` (the name
            of the first directory below ``self.root``).
        """
        import imageio as iio
        import pandas as pd
        from ray.data.extensions import TensorArray

        # The parent yields `(path, bytes)` records when `include_paths=True`.
        records = super()._read_file(f, path, include_paths=True)
        assert len(records) == 1
        path, data = records[0]

        image = iio.imread(data)
        label = _get_class_from_path(path, self.root)

        return pd.DataFrame(
            {
                "image": TensorArray([np.array(image)]),
                "label": [label],
            }
        )


def _get_class_from_path(path: str, root: str) -> str:
# The class is the name of the first directory after the root. For example, if
# the root is "/data/imagenet/train" and the path is
# "/data/imagenet/train/n01443537/images/n01443537_0.JPEG", then the class is
# "n01443537".
path, root = pathlib.PurePath(path), pathlib.PurePath(root)
assert root in path.parents
return path.parts[len(root.parts) :][0]
Binary file added python/ray/data/tests/image-folder/cat/123.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added python/ray/data/tests/image-folder/cat/foo.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file.
Binary file added python/ray/data/tests/image-folder/dog/xxx.PNG
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file.
57 changes: 57 additions & 0 deletions python/ray/data/tests/test_dataset_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
DefaultParquetMetadataProvider,
DummyOutputDatasource,
FastFileMetadataProvider,
ImageFolderDatasource,
PartitionStyle,
PathPartitionEncoder,
PathPartitionFilter,
Expand All @@ -41,6 +42,8 @@
_SerializedPiece,
_deserialize_pieces_with_retry,
)
from ray.data.extensions import TensorDtype
from ray.data.preprocessors import BatchMapper
from ray.data.tests.conftest import * # noqa
from ray.data.tests.mock_http_server import * # noqa
from ray.tests.conftest import * # noqa
Expand Down Expand Up @@ -2729,6 +2732,60 @@ def test_torch_datasource_value_error(ray_start_regular_shared, local_path):
)


def test_image_folder_datasource(ray_start_regular_shared):
    """Read the bundled ``image-folder`` fixture and check labels and images."""
    fixture_dir = os.path.dirname(__file__)
    dataset_root = os.path.join(fixture_dir, "image-folder")
    dataset = ray.data.read_datasource(ImageFolderDatasource(), paths=[dataset_root])

    # Empty placeholder files in the fixture must be filtered out.
    assert dataset.count() == 3

    frame = dataset.to_pandas()
    labels = sorted(frame["label"])
    assert labels == ["cat", "cat", "dog"]

    image_column = frame["image"]
    assert type(image_column.dtype) is TensorDtype
    for tensor in image_column:
        assert tensor.to_numpy().shape == (32, 32, 3)


def test_image_folder_datasource_raises_value_error(ray_start_regular_shared):
    """Passing more than one root path must be rejected with ``ValueError``."""
    too_many_roots = ["imagenet/train", "imagenet/test"]
    datasource = ImageFolderDatasource()

    with pytest.raises(ValueError):
        ray.data.read_datasource(datasource, paths=too_many_roots)


def test_image_folder_datasource_e2e(ray_start_regular_shared):
    """Smoke test: feed the image-folder dataset through a Torch batch predictor."""
    from ray.air.util.tensor_extensions.pandas import TensorArray
    from ray.train.torch import to_air_checkpoint, TorchPredictor
    from ray.train.batch_predictor import BatchPredictor

    from torchvision import transforms
    from torchvision.models import resnet18

    dataset_root = os.path.join(os.path.dirname(__file__), "image-folder")
    dataset = ray.data.read_datasource(ImageFolderDatasource(), paths=[dataset_root])

    def preprocess(df):
        # Each `TensorArrayElement` is first turned into a NumPy array, since
        # `ToTensor` only accepts NumPy arrays or PIL images. `ToTensor` itself is
        # needed because Torch wants images laid out as (C, H, W) and it converts
        # the data from (H, W, C) to (C, H, W).
        to_chw_tensor = transforms.Compose(
            [
                lambda ray_tensor: ray_tensor.to_numpy(),
                transforms.ToTensor(),
            ]
        )
        converted = [to_chw_tensor(image) for image in df["image"]]
        df["image"] = TensorArray(converted)
        return df

    checkpoint = to_air_checkpoint(
        model=resnet18(pretrained=True),
        preprocessor=BatchMapper(preprocess),
    )

    # Only checks that prediction runs end to end; outputs are not inspected.
    batch_predictor = BatchPredictor.from_checkpoint(checkpoint, TorchPredictor)
    batch_predictor.predict(dataset, feature_columns=["image"])


# NOTE: The last test using the shared ray_start_regular_shared cluster must use the
# shutdown_only fixture so the shared cluster is shut down, otherwise the below
# test_write_datasource_ray_remote_args test, which uses a cluster_utils cluster, will
Expand Down

0 comments on commit 2005c47

Please sign in to comment.