Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Datasets] Add ImageFolderDatasource #24641

Merged
merged 28 commits into from
Jul 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
f6eb66e
Add files
bveeramani May 10, 2022
8fb8d75
Add files
bveeramani May 11, 2022
083c9bc
Fix stuff
bveeramani May 11, 2022
4f0b8ce
Rename file
bveeramani May 11, 2022
cd6e25e
Rename file
bveeramani May 11, 2022
98b07f7
Update image_folder_datasource.py
bveeramani May 11, 2022
6f7b2eb
Update docs
bveeramani May 11, 2022
db05d35
Update file_meta_provider.py
bveeramani May 11, 2022
52ae3c8
Update image_folder_datasource.py
bveeramani May 11, 2022
813f5de
Update Makefile
bveeramani May 11, 2022
d68fe1d
Update python/ray/data/datasource/image_folder_datasource.py
bveeramani May 18, 2022
45989c2
Re-add warning
bveeramani May 18, 2022
cd079d9
Merge branch 'image-datasource' of https://github.com/bveeramani/ray …
bveeramani May 18, 2022
e4a6222
Merge branch 'master' into pr/24641
bveeramani Jun 7, 2022
d7dcd0e
Merge branch 'master' into image-datasource
bveeramani Jul 14, 2022
2f95235
Update implementation
bveeramani Jul 14, 2022
3c9e3c0
Add read API
bveeramani Jul 15, 2022
e312403
Merge branch 'master' into image-datasource
bveeramani Jul 15, 2022
cdaca59
Fix error in documentation
bveeramani Jul 15, 2022
7c0a317
Update documentation and add test
bveeramani Jul 15, 2022
e07d1f8
Remove `target` column
bveeramani Jul 15, 2022
298c144
Change error type from `ValueError` to `ImportError`
bveeramani Jul 15, 2022
dc9c6e2
Remove `read_image_folder`
bveeramani Jul 15, 2022
3d8e005
Merge branch 'master' into image-datasource
bveeramani Jul 15, 2022
1d4d920
Add API annotation
bveeramani Jul 15, 2022
0f5e089
Add missing `DeveloperAPI` import
bveeramani Jul 15, 2022
18dd566
Skip doctests
bveeramani Jul 15, 2022
70a47d7
update
richardliaw Jul 16, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/source/data/package-ref.rst
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,8 @@ Built-in Datasources
.. autoclass:: ray.data.datasource.FileBasedDatasource
:members:

.. autoclass:: ray.data.datasource.ImageFolderDatasource
bveeramani marked this conversation as resolved.
Show resolved Hide resolved

.. autoclass:: ray.data.datasource.JSONDatasource
:members:

Expand Down
2 changes: 2 additions & 0 deletions python/ray/data/datasource/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
FileMetadataProvider,
ParquetMetadataProvider,
)
from ray.data.datasource.image_folder_datasource import ImageFolderDatasource
from ray.data.datasource.json_datasource import JSONDatasource
from ray.data.datasource.numpy_datasource import NumpyDatasource
from ray.data.datasource.parquet_base_datasource import ParquetBaseDatasource
Expand Down Expand Up @@ -52,6 +53,7 @@
"FileBasedDatasource",
"FileExtensionFilter",
"FileMetadataProvider",
"ImageFolderDatasource",
"JSONDatasource",
"NumpyDatasource",
"ParquetBaseDatasource",
Expand Down
144 changes: 144 additions & 0 deletions python/ray/data/datasource/image_folder_datasource.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
import pathlib
from typing import TYPE_CHECKING, List, Optional, Union

import numpy as np
from ray.data.datasource.binary_datasource import BinaryDatasource
from ray.data.datasource.datasource import Reader
from ray.data.datasource.file_based_datasource import (
_resolve_paths_and_filesystem,
FileExtensionFilter,
)
from ray.data.datasource.partitioning import PathPartitionFilter
from ray.util.annotations import DeveloperAPI

if TYPE_CHECKING:
import pyarrow
from ray.data.block import T

IMAGE_EXTENSIONS = ["png", "jpg", "jpeg", "tiff", "bmp", "gif"]


@DeveloperAPI
class ImageFolderDatasource(BinaryDatasource):
"""A datasource that lets you read datasets like `ImageNet <https://www.image-net.org/>`_.

This datasource works with any dataset where images are arranged in this way:

.. code-block::

root/dog/xxx.png
root/dog/xxy.png
root/dog/[...]/xxz.png

root/cat/123.png
root/cat/nsdf3.png
root/cat/[...]/asd932_.png

bveeramani marked this conversation as resolved.
Show resolved Hide resolved
Datasets read with this datasource contain two columns: ``'image'`` and ``'label'``.

* The ``'image'`` column is of type
:py:class:`~ray.air.util.tensor_extensions.pandas.TensorDtype` and contains
tensors of shape :math:`(H, W, C)`.
* The ``'label'`` column contains strings representing class names (e.g., 'cat').

Examples:
>>> import ray
>>> from ray.data.datasource import ImageFolderDatasource
>>>
>>> ds = ray.data.read_datasource( # doctest: +SKIP
... ImageFolderDatasource(),
... paths=["/data/imagenet/train"]
... )
>>>
>>> sample = ds.take(1)[0] # doctest: +SKIP
>>> sample["image"].to_numpy().shape # doctest: +SKIP
(469, 387, 3)
>>> sample["label"] # doctest: +SKIP
'n01443537'

To convert class labels to integer-valued targets, use
:py:class:`~ray.data.preprocessors.OrdinalEncoder`.

>>> import ray
>>> from ray.data.preprocessors import OrdinalEncoder
>>>
>>> ds = ray.data.read_datasource( # doctest: +SKIP
... ImageFolderDatasource(),
... paths=["/data/imagenet/train"]
... )
>>> oe = OrdinalEncoder(columns=["label"]) # doctest: +SKIP
>>>
>>> ds = oe.fit_transform(ds) # doctest: +SKIP
>>>
>>> sample = ds.take(1)[0] # doctest: +SKIP
>>> sample["label"] # doctest: +SKIP
71
""" # noqa: E501

def create_reader(
self,
paths: Union[str, List[str]],
filesystem: Optional["pyarrow.fs.FileSystem"] = None,
partition_filter: PathPartitionFilter = None,
**kwargs,
) -> "Reader[T]":
if len(paths) != 1:
raise ValueError(
"`ImageFolderDatasource` expects 1 path representing the dataset "
f"root, but it got {len(paths)} paths instead. To fix this "
"error, pass in a single-element list containing the dataset root "
'(for example, `paths=["s3://imagenet/train"]`)'
)

try:
import imageio # noqa: F401
except ImportError:
raise ImportError(
"`ImageFolderDatasource` depends on 'imageio', but 'imageio' couldn't "
"be imported. You can install 'imageio' by running "
"`pip install imageio`."
)

if partition_filter is None:
partition_filter = FileExtensionFilter(file_extensions=IMAGE_EXTENSIONS)

# We call `_resolve_paths_and_filesystem` so that the dataset root is formatted
# in the same way as the paths passed to `_get_class_from_path`.
paths, filesystem = _resolve_paths_and_filesystem(paths, filesystem)
self.root = paths[0]

return super().create_reader(
paths=paths,
filesystem=filesystem,
partition_filter=partition_filter,
**kwargs,
)

def _read_file(self, f: "pyarrow.NativeFile", path: str, **reader_args):
import imageio as iio
import pandas as pd
from ray.data.extensions import TensorArray

records = super()._read_file(f, path, include_paths=True)
assert len(records) == 1
path, data = records[0]

image = iio.imread(data)
label = _get_class_from_path(path, self.root)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we have any docs / past discussion about this part? Basically we're assuming we get the label based on user file path, which has to be structured in certain way in order to get the correct one without knobs needed to pass in custom label file or join ?

For example, if i read a s3 bucket with filenames of "dog.jpg", "dog_2.jpg" my dataloader will end up getting these string values by default.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Basically we're assuming we get the label based on user file path, which has to be structured in certain way in order to get the correct one without knobs needed to pass in custom label file or join ?

Yeah, that's right. The datasource assumes that the layout is structured in the same way as ImageNet. The functionality of the datasource is based on that of TorchVision's ImageFolder.

For example, if i read a s3 bucket with filenames of "dog.jpg", "dog_2.jpg" my dataloader will end up getting these string values by default.

Yeah, you're right. We don't validate that the label corresponds to a directory. In this case, we could raise an error stating that the folder isn't structured correctly.

Alternatively, if images aren't stored in a directory, we could set the label to None.

If images aren't stored in a sub-directory, then the image's label will be set to `None`.

.. code-block::

    root/dog/xxx.png  # Label is 'dog'
    root/123.jpg.     # Label is `None`


return pd.DataFrame(
{
"image": TensorArray([np.array(image)]),
"label": [label],
}
)


def _get_class_from_path(path: str, root: str) -> str:
# The class is the name of the first directory after the root. For example, if
# the root is "/data/imagenet/train" and the path is
# "/data/imagenet/train/n01443537/images/n01443537_0.JPEG", then the class is
# "n01443537".
path, root = pathlib.PurePath(path), pathlib.PurePath(root)
assert root in path.parents
return path.parts[len(root.parts) :][0]
Binary file added python/ray/data/tests/image-folder/cat/123.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added python/ray/data/tests/image-folder/cat/foo.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file.
Binary file added python/ray/data/tests/image-folder/dog/xxx.PNG
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file.
57 changes: 57 additions & 0 deletions python/ray/data/tests/test_dataset_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
DefaultParquetMetadataProvider,
DummyOutputDatasource,
FastFileMetadataProvider,
ImageFolderDatasource,
PartitionStyle,
PathPartitionEncoder,
PathPartitionFilter,
Expand All @@ -41,6 +42,8 @@
_SerializedPiece,
_deserialize_pieces_with_retry,
)
from ray.data.extensions import TensorDtype
from ray.data.preprocessors import BatchMapper
from ray.data.tests.conftest import * # noqa
from ray.data.tests.mock_http_server import * # noqa
from ray.tests.conftest import * # noqa
Expand Down Expand Up @@ -2729,6 +2732,60 @@ def test_torch_datasource_value_error(ray_start_regular_shared, local_path):
)


def test_image_folder_datasource(ray_start_regular_shared):
root = os.path.join(os.path.dirname(__file__), "image-folder")
ds = ray.data.read_datasource(ImageFolderDatasource(), paths=[root])

assert ds.count() == 3

df = ds.to_pandas()
assert sorted(df["label"]) == ["cat", "cat", "dog"]
assert type(df["image"].dtype) is TensorDtype
assert all(tensor.to_numpy().shape == (32, 32, 3) for tensor in df["image"])


def test_image_folder_datasource_raises_value_error(ray_start_regular_shared):
# `ImageFolderDatasource` should raise an error if more than one path is passed.
with pytest.raises(ValueError):
ray.data.read_datasource(
ImageFolderDatasource(), paths=["imagenet/train", "imagenet/test"]
)


def test_image_folder_datasource_e2e(ray_start_regular_shared):
from ray.air.util.tensor_extensions.pandas import TensorArray
from ray.train.torch import to_air_checkpoint, TorchPredictor
from ray.train.batch_predictor import BatchPredictor

from torchvision import transforms
from torchvision.models import resnet18

root = os.path.join(os.path.dirname(__file__), "image-folder")
dataset = ray.data.read_datasource(ImageFolderDatasource(), paths=[root])

def preprocess(df):
# We convert the `TensorArrayElement` to a NumPy array because `ToTensor`
# expects a NumPy array or PIL image. `ToTensor` is necessary because Torch
# expects images to have shape (C, H, W), and `ToTensor` changes the shape of
# the data from (H, W, C) to (C, H, W).
preprocess = transforms.Compose(
[
lambda ray_tensor: ray_tensor.to_numpy(),
transforms.ToTensor(),
]
)
df["image"] = TensorArray([preprocess(image) for image in df["image"]])
return df

preprocessor = BatchMapper(preprocess)

model = resnet18(pretrained=True)
checkpoint = to_air_checkpoint(model=model, preprocessor=preprocessor)

predictor = BatchPredictor.from_checkpoint(checkpoint, TorchPredictor)
predictor.predict(dataset, feature_columns=["image"])


# NOTE: The last test using the shared ray_start_regular_shared cluster must use the
# shutdown_only fixture so the shared cluster is shut down, otherwise the below
# test_write_datasource_ray_remote_args test, which uses a cluster_utils cluster, will
Expand Down