
[data] add dataloader for lance datasource #49459

Open · wants to merge 1 commit into base: master
29 changes: 29 additions & 0 deletions python/ray/data/_internal/datasource/lance_datasource.py
@@ -2,6 +2,7 @@
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional

import numpy as np
from torch.utils.data import Dataset

from ray.data._internal.util import _check_import, call_with_retry
from ray.data.block import BlockMetadata
@@ -97,6 +98,34 @@ def estimate_inmemory_data_size(self) -> Optional[int]:
# TODO(chengsu): Add memory size estimation to improve auto-tune of parallelism.
return None

class LanceDataset(Dataset):
"""Torch map-style Dataset that wraps a LanceDatasource."""

def __init__(self, ds: "LanceDatasource"):
self.ds = ds
self.lance_ds = ds.lance_ds

def __len__(self):
return self.lance_ds.count_rows()

def __getitem__(self, idx):
# Fetch a single row, honoring the configured columns and filter
return self.lance_ds.take(
[idx],
columns=self.ds.scanner_options["columns"],
filter=self.ds.scanner_options["filter"],
).to_pydict()

def to_torch_dataset(
self,
) -> Dataset:
return self.LanceDataset(ds=self)
Contributor:
In Ray Data, a Datasource is used to create a Ray Dataset via ray.data.read_datasource(...).
Users are not supposed to use the Datasource class directly for training ingestion.
If you want to do this, I think you can create a torch Dataset directly on top of LanceDB without using Ray Data.
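A minimal sketch of what the reviewer suggests: a map-style dataset built directly on the storage layer, with no Ray Data in between. Since `torch.utils.data.Dataset` only requires `__len__` and `__getitem__`, a plain class illustrates the contract; an in-memory list of row dicts stands in for a Lance table here, and the `rows`, `columns`, and `row_filter` names are illustrative, not part of any real API.

```python
# Stand-in for a torch map-style Dataset over a Lance table: only __len__
# and __getitem__ are needed for the protocol. A list of dicts replaces the
# actual Lance dataset so the sketch runs without lance or torch installed.

class RowDataset:
    """Map-style dataset over a list of row dicts (stand-in for a Lance table)."""

    def __init__(self, rows, columns=None, row_filter=None):
        # Apply the filter eagerly, mirroring a pushed-down scan filter.
        if row_filter is not None:
            rows = [r for r in rows if row_filter(r)]
        self.columns = columns
        self.rows = rows

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        # Project to the requested columns, like a column selection in a scan.
        row = self.rows[idx]
        if self.columns is not None:
            row = {c: row[c] for c in self.columns}
        return row


rows = [{"id": i, "name": f"test_{i}"} for i in range(100)]
ds = RowDataset(rows, columns=["name"], row_filter=lambda r: 10 < r["id"] < 50)
print(len(ds))  # 39, matching the PR's filter="id < 50 and id > 10"
print(ds[0])    # {'name': 'test_11'}
```

An instance like this could be handed straight to a `torch.utils.data.DataLoader`, which is the point of the suggestion: the training-ingestion path does not need a Datasource in the middle.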

Contributor (author):

From a theoretical perspective, directly creating a torch Dataset is also possible, since this class inherits from the torch Dataset. The reason for placing it here is mainly that the datasource can then be converted directly into a torch dataset, which makes the Ray Train + Ray Data pattern easier to implement.
WDYT?

Contributor:

Sorry, I still don't think this makes sense, because "users are not supposed to directly use the Datasource class for training ingestion."


def count(
self,
) -> int:
return self.lance_ds.count_rows(filter=self.scanner_options["filter"])


def _read_fragments_with_retry(
fragment_ids,
33 changes: 33 additions & 0 deletions python/ray/data/tests/test_lance.py
@@ -5,11 +5,13 @@
import pytest
from pkg_resources import parse_version
from pytest_lazyfixture import lazy_fixture
from torch.utils.data import DataLoader

import ray
from ray._private.test_utils import wait_for_condition
from ray._private.utils import _get_pyarrow_version
from ray.data import Schema
from ray.data._internal.datasource.lance_datasource import LanceDatasource
from ray.data.datasource.path_util import _unwrap_protocol


@@ -114,6 +116,37 @@ def test_lance():
wait_for_condition(test_lance, timeout=10)


@pytest.mark.parametrize("data_path", [lazy_fixture("local_path")])
def test_torch_dataset(data_path):
setup_data_path = _unwrap_protocol(data_path)
path = os.path.join(setup_data_path, "test.lance")
num_rows = 1024
data = pa.table(
{
"id": pa.array(range(num_rows)),
"name": pa.array([f"test_{i}" for i in range(num_rows)]),
}
)
lance.write_dataset(data, path, max_rows_per_file=1)

ds = LanceDatasource(path, columns=["name"], filter="id < 50 and id > 10")
assert ds.count() == 39
train_ds = ds.to_torch_dataset()

def custom_collate_fn(batch):
if isinstance(batch[0], dict):
return [item["name"] for item in batch]
else:
return [item[0] for item in batch]

dataloader = DataLoader(
Contributor:

OOC, can you show an example of how this works with ray.train.torch.TorchTrainer? Currently it takes a Ray Dataset as input, not a Datasource.

Contributor (author):

It is indeed using Dataset.

train_ds, batch_size=10, shuffle=False, collate_fn=custom_collate_fn
)
for batch in dataloader:
assert len(batch) == 10
break
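The `custom_collate_fn` in the test above can be exercised on its own. Since `LanceDataset.__getitem__` returns `to_pydict()` output, each item in a batch is a dict mapping column name to a one-element list, which is what the dict branch handles; the sample batches below are illustrative stand-ins for DataLoader output.

```python
# Standalone check of the test's collate logic: the dict branch extracts the
# "name" column from pydict-shaped items; the fallback branch takes the first
# element of tuple-shaped items.

def custom_collate_fn(batch):
    if isinstance(batch[0], dict):
        return [item["name"] for item in batch]
    else:
        return [item[0] for item in batch]


# Each dict item mirrors one lance_ds.take([idx], ...).to_pydict() result.
dict_batch = [{"name": ["test_11"]}, {"name": ["test_12"]}]
tuple_batch = [("test_11",), ("test_12",)]

print(custom_collate_fn(dict_batch))   # [['test_11'], ['test_12']]
print(custom_collate_fn(tuple_batch))  # ['test_11', 'test_12']
```

Note that the dict branch yields a list of one-element lists, not flat strings, because `to_pydict()` wraps each column value in a list even for a single row.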


if __name__ == "__main__":
import sys

1 change: 1 addition & 0 deletions python/requirements/ml/data-requirements.txt
@@ -13,4 +13,5 @@ pandas==1.5.3; python_version < '3.12'
modin==0.31.0; python_version >= '3.12'
pandas==2.2.2; python_version >= '3.12'
responses==0.13.4
torch==2.3.0
pymars>=0.8.3; python_version < "3.12"
1 change: 1 addition & 0 deletions python/requirements/ml/data-test-requirements.txt
@@ -14,6 +14,7 @@ google-api-core==1.34.0
webdataset
raydp==1.7.0b20231020.dev0
pylance
torch==2.3.0
delta-sharing
pytest-mock
decord
2 changes: 2 additions & 0 deletions python/requirements_compiled.txt
@@ -1598,6 +1598,8 @@ pyjwt==2.8.0
# snowflake-connector-python
pylance==0.10.18
# via -r /ray/ci/../python/requirements/ml/data-test-requirements.txt
torch==2.3.0
# via -r /ray/ci/../python/requirements/ml/data-test-requirements.txt
pymars==0.10.0 ; python_version < "3.12"
# via -r /ray/ci/../python/requirements/ml/data-requirements.txt
pymongo==4.3.2