[Datasets] Add ImageFolderDatasource #24641

Merged
merged 28 commits into from
Jul 16, 2022
Changes from 14 commits
Commits
f6eb66e
Add files
bveeramani May 10, 2022
8fb8d75
Add files
bveeramani May 11, 2022
083c9bc
Fix stuff
bveeramani May 11, 2022
4f0b8ce
Rename file
bveeramani May 11, 2022
cd6e25e
Rename file
bveeramani May 11, 2022
98b07f7
Update image_folder_datasource.py
bveeramani May 11, 2022
6f7b2eb
Update docs
bveeramani May 11, 2022
db05d35
Update file_meta_provider.py
bveeramani May 11, 2022
52ae3c8
Update image_folder_datasource.py
bveeramani May 11, 2022
813f5de
Update Makefile
bveeramani May 11, 2022
d68fe1d
Update python/ray/data/datasource/image_folder_datasource.py
bveeramani May 18, 2022
45989c2
Re-add warning
bveeramani May 18, 2022
cd079d9
Merge branch 'image-datasource' of https://github.com/bveeramani/ray …
bveeramani May 18, 2022
e4a6222
Merge branch 'master' into pr/24641
bveeramani Jun 7, 2022
d7dcd0e
Merge branch 'master' into image-datasource
bveeramani Jul 14, 2022
2f95235
Update implementation
bveeramani Jul 14, 2022
3c9e3c0
Add read API
bveeramani Jul 15, 2022
e312403
Merge branch 'master' into image-datasource
bveeramani Jul 15, 2022
cdaca59
Fix error in documentation
bveeramani Jul 15, 2022
7c0a317
Update documentation and add test
bveeramani Jul 15, 2022
e07d1f8
Remove `target` column
bveeramani Jul 15, 2022
298c144
Change error type from `ValueError` to `ImportError`
bveeramani Jul 15, 2022
dc9c6e2
Remove `read_image_folder`
bveeramani Jul 15, 2022
3d8e005
Merge branch 'master' into image-datasource
bveeramani Jul 15, 2022
1d4d920
Add API annotation
bveeramani Jul 15, 2022
0f5e089
Add missing `DeveloperAPI` import
bveeramani Jul 15, 2022
18dd566
Skip doctests
bveeramani Jul 15, 2022
70a47d7
update
richardliaw Jul 16, 2022
2 changes: 2 additions & 0 deletions doc/source/data/package-ref.rst
@@ -152,6 +152,8 @@ Built-in Datasources
 .. autoclass:: ray.data.datasource.FileBasedDatasource
     :members:

+.. autoclass:: ray.data.datasource.ImageFolderDatasource

 .. autoclass:: ray.data.datasource.JSONDatasource
     :members:

9 changes: 6 additions & 3 deletions python/ray/data/_internal/pandas_block.py
@@ -49,9 +49,12 @@ def __getitem__(self, key: str) -> Any:
             return None
         item = col.iloc[0]
         try:
-            # Try to interpret this as a numpy-type value.
-            # See https://stackoverflow.com/questions/9452775/converting-numpy-dtypes-to-native-python-types.  # noqa: E501
-            return item.item()
+            if item.size == 1:
+                # Try to interpret this as a numpy-type value.
+                # See https://stackoverflow.com/questions/9452775/converting-numpy-dtypes-to-native-python-types.  # noqa: E501
+                return item.item()
+            else:
+                return item
         except AttributeError:
             # Fallback to the original form.
             return item
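The size check matters because `ndarray.item()` only works on single-element arrays; on an image-sized array it raises `ValueError`, which the `except AttributeError` clause would not catch. A minimal standalone sketch of the same logic follows. The function name `unwrap_cell` is hypothetical and not part of Ray:

```python
import numpy as np


def unwrap_cell(item):
    """Mirror the patched lookup: unwrap scalars, pass larger arrays through.

    Hypothetical helper for illustration; not part of Ray.
    """
    try:
        if item.size == 1:
            # NumPy scalars and one-element arrays become native Python values.
            return item.item()
        # Larger arrays (for example an H x W x C image) are returned unchanged.
        return item
    except AttributeError:
        # Plain Python objects such as str have no `.size`; return them as-is.
        return item
```

This is what lets an `'image'` column hold whole `ndarray` objects while scalar columns still come back as native Python values.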
2 changes: 2 additions & 0 deletions python/ray/data/datasource/__init__.py
@@ -22,6 +22,7 @@
     FileMetadataProvider,
     ParquetMetadataProvider,
 )
+from ray.data.datasource.image_folder_datasource import ImageFolderDatasource
 from ray.data.datasource.json_datasource import JSONDatasource
 from ray.data.datasource.numpy_datasource import NumpyDatasource
 from ray.data.datasource.parquet_base_datasource import ParquetBaseDatasource
@@ -49,6 +50,7 @@
     "FastFileMetadataProvider",
     "FileBasedDatasource",
     "FileMetadataProvider",
+    "ImageFolderDatasource",
     "JSONDatasource",
     "NumpyDatasource",
     "ParquetBaseDatasource",
135 changes: 135 additions & 0 deletions python/ray/data/datasource/image_folder_datasource.py
@@ -0,0 +1,135 @@
+import pathlib
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
+
+import imageio as iio
+import numpy as np
+from ray.data.block import Block
+from ray.data.datasource.binary_datasource import BinaryDatasource
+from ray.data.datasource.datasource import ReadTask
+from ray.data.datasource.file_based_datasource import _resolve_paths_and_filesystem
+from ray.data.datasource.file_meta_provider import (
+    BaseFileMetadataProvider,
+    DefaultFileMetadataProvider,
+    FastFileMetadataProvider,
+)
+from ray.data.datasource.partitioning import PathPartitionFilter
+
+if TYPE_CHECKING:
+    import pyarrow
+
+IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".tiff", ".bmp", ".gif"]
+
+
+class ImageFolderDatasource(BinaryDatasource):
+    """A datasource that lets you read datasets like `ImageNet <https://www.image-net.org/>`_.
+
+    This datasource works with any dataset where images are arranged in this way:
+
+    .. code-block::
+
+        root/dog/xxx.png
+        root/dog/xxy.png
+        root/dog/[...]/xxz.png
+
+        root/cat/123.png
+        root/cat/nsdf3.png
+        root/cat/[...]/asd932_.png

+    Datasets read with ``ImageFolderDatasource`` contain two columns: ``'image'`` and
+    ``'label'``. The ``'image'`` column contains ``ndarray`` objects of shape
+    :math:`(H, W, C)`, and the ``label`` column contains strings corresponding to
+    labels.
+
+    Examples:
+        >>> import ray
+        >>> from ray.data.datasource import ImageFolderDatasource
+        >>>
+        >>> ds = ray.data.read_datasource(  # doctest: +SKIP
+        ...     ImageFolderDatasource(),
+        ...     paths=["/data/imagenet/train"]
+        ... )
+        >>> sample = ds.take(1)[0]  # doctest: +SKIP
+        >>> sample["image"].shape  # doctest: +SKIP
+        (469, 387, 3)
+        >>> sample["label"]  # doctest: +SKIP
+        'n01443537'
+
+    Raises:
+        ValueError: if more than one path is provided. You should only provide the path
+            to the dataset root.
+    """
+
+    def prepare_read(
+        self,
+        parallelism: int,
+        paths: Union[str, List[str]],
+        filesystem: Optional["pyarrow.fs.FileSystem"] = None,
+        schema: Optional[Union[type, "pyarrow.lib.Schema"]] = None,
+        open_stream_args: Optional[Dict[str, Any]] = None,
+        meta_provider: BaseFileMetadataProvider = DefaultFileMetadataProvider(),
+        partition_filter: PathPartitionFilter = None,
+        # TODO(ekl) deprecate this once read fusion is available.
+        _block_udf: Optional[Callable[[Block], Block]] = None,
+        **reader_args,
+    ) -> List[ReadTask]:
+        if len(paths) != 1:
+            raise ValueError(
+                "`ImageFolderDatasource` expects 1 path representing the dataset "
+                f"root, but it got {len(paths)} paths instead. To fix this error, "
+                "pass in a single-element list containing the dataset root (for "
+                'example, `paths=["s3://imagenet/train"]`)'
+            )
+
+        try:
+            import imageio  # noqa: F401
+        except ImportError:
+            raise ValueError(
+                "`ImageFolderDatasource` depends on 'imageio', but 'imageio' couldn't "
+                "be imported. You can install 'imageio' by running "
+                "`pip install imageio`."
+            )
+
+        # We call `_resolve_paths_and_filesystem` so that the dataset root is formatted
+        # in the same way as the paths passed to `_get_class_from_path`.
+        paths, filesystem = _resolve_paths_and_filesystem(paths, filesystem)
+        self.root = paths[0]
+
+        paths, _ = meta_provider.expand_paths(paths, filesystem)
+        paths = [path for path in paths if _is_image_file(path)]

+        return super().prepare_read(
+            parallelism=parallelism,
+            paths=paths,
+            filesystem=filesystem,
+            schema=schema,
+            open_stream_args=open_stream_args,
+            meta_provider=FastFileMetadataProvider(),
+            partition_filter=partition_filter,
+            _block_udf=_block_udf,
+        )

+    def _read_file(self, f: "pyarrow.NativeFile", path: str, **reader_args):
+        import pandas as pd
+
+        records = super()._read_file(f, path, include_paths=True)
+        assert len(records) == 1
+        path, data = records[0]
+
+        image = iio.imread(data)
+        label = _get_class_from_path(path, self.root)
Member

Do we have any docs or past discussion about this part? Basically, we're assuming the label is derived from the user's file path, which has to be structured in a certain way to yield the correct label, with no knobs for passing in a custom label file or a join?

For example, if I read an S3 bucket with filenames "dog.jpg" and "dog_2.jpg", my dataloader will end up getting these string values as labels by default.

Member Author

> Basically, we're assuming the label is derived from the user's file path, which has to be structured in a certain way to yield the correct label, with no knobs for passing in a custom label file or a join?

Yeah, that's right. The datasource assumes that the layout is structured in the same way as ImageNet. Its functionality is based on that of TorchVision's `ImageFolder`.

> For example, if I read an S3 bucket with filenames "dog.jpg" and "dog_2.jpg", my dataloader will end up getting these string values as labels by default.

Yeah, you're right. We don't validate that the label corresponds to a directory. In this case, we could raise an error stating that the folder isn't structured correctly.

Alternatively, if images aren't stored in a directory, we could set the label to None.

If images aren't stored in a sub-directory, then the image's label will be set to `None`.

.. code-block::

    root/dog/xxx.png  # Label is 'dog'
    root/123.jpg      # Label is `None`
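
A minimal sketch of that `None` fallback, using only `pathlib`. The name `get_label_or_none` is hypothetical; the PR's own helper is `_get_class_from_path`, which does not special-case files stored directly under the root. `PurePosixPath` is an assumption made here so the result doesn't depend on the host OS:

```python
import pathlib
from typing import Optional


def get_label_or_none(path: str, root: str) -> Optional[str]:
    """Return the first directory under `root`, or None for files directly in `root`.

    Hypothetical variant for illustration; not the code in this PR.
    """
    path_, root_ = pathlib.PurePosixPath(path), pathlib.PurePosixPath(root)
    if root_ not in path_.parents:
        raise ValueError(f"{path!r} is not located under {root!r}")
    relative_parts = path_.parts[len(root_.parts):]
    # A single remaining part is the filename itself, so there is no class directory.
    if len(relative_parts) < 2:
        return None
    return relative_parts[0]
```

With this variant, `root/dog/xxx.png` maps to `'dog'` and `root/123.jpg` maps to `None`, matching the two cases listed above.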


+        return pd.DataFrame({"image": [np.array(image)], "label": [label]})


+def _is_image_file(path: str) -> bool:
+    return any(path.lower().endswith(extension) for extension in IMAGE_EXTENSIONS)
+
+
+def _get_class_from_path(path: str, root: str) -> str:
+    # The class is the name of the first directory after the root. For example, if
+    # the root is "/data/imagenet/train" and the path is
+    # "/data/imagenet/train/n01443537/images/n01443537_0.JPEG", then the class is
+    # "n01443537".
+    path, root = pathlib.PurePath(path), pathlib.PurePath(root)
+    assert root in path.parents
+    return path.parts[len(root.parts) :][0]
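
To make the two helpers concrete, the sketch below re-implements them with the standard library only and applies them to a few made-up paths. The function names without the leading underscore and the sample paths are illustrative assumptions, and `PurePosixPath` replaces `PurePath` so the result doesn't depend on the host OS:

```python
import pathlib

IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".tiff", ".bmp", ".gif"]


def is_image_file(path: str) -> bool:
    # Case-insensitive match, so "xxx.PNG" and "a.JPEG" are both accepted.
    return any(path.lower().endswith(ext) for ext in IMAGE_EXTENSIONS)


def get_class_from_path(path: str, root: str) -> str:
    # The class is the first path component after the root.
    path_, root_ = pathlib.PurePosixPath(path), pathlib.PurePosixPath(root)
    assert root_ in path_.parents
    return path_.parts[len(root_.parts):][0]


root = "/data/train"
paths = [
    "/data/train/dog/xxx.PNG",
    "/data/train/cat/nested/123.png",
    "/data/train/cat/notes.txt",  # skipped: not an image extension
    "/data/train/dog_2.jpg",  # stored directly under the root
]
# Keep only image files, then derive one label per remaining path.
labels = {p: get_class_from_path(p, root) for p in paths if is_image_file(p)}
```

With these inputs, the file stored directly under the root is labeled with its own filename (`'dog_2.jpg'`), which is the behavior questioned in the review thread above.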
Binary file added python/ray/data/tests/image-folder/cat/123.png
Binary file added python/ray/data/tests/image-folder/dog/xxx.PNG
21 changes: 21 additions & 0 deletions python/ray/data/tests/test_dataset_formats.py
@@ -25,6 +25,7 @@
     DefaultFileMetadataProvider,
     DefaultParquetMetadataProvider,
     FastFileMetadataProvider,
+    ImageFolderDatasource,
     PathPartitionFilter,
     PathPartitionEncoder,
     PartitionStyle,
@@ -2534,6 +2535,26 @@ def get_node_id():
     assert node_ids == {bar_node_id}


+def test_image_folder_datasource(ray_start_regular_shared):
+    root = os.path.join(os.path.dirname(__file__), "image-folder")
+    ds = ray.data.read_datasource(ImageFolderDatasource(), paths=[root])
+
+    assert ds.count() == 2
+
+    df = ds.to_pandas()
+    assert set(df["label"]) == {"cat", "dog"}
+    assert all(isinstance(array, np.ndarray) for array in df["image"])
+    assert all(array.shape == (32, 32, 3) for array in df["image"])
+
+
+def test_image_folder_datasource_raises_value_error(ray_start_regular_shared):
+    # `ImageFolderDatasource` should raise an error if more than one path is passed.
+    with pytest.raises(ValueError):
+        ray.data.read_datasource(
+            ImageFolderDatasource(), paths=["imagenet/train", "imagenet/test"]
+        )


 if __name__ == "__main__":
     import sys
