Skip to content

Commit

Permalink
[data] add dataloader for lance datasource
Browse files Browse the repository at this point in the history
Signed-off-by: jukejian <jukejian@bytedance.com>
  • Loading branch information
Jay-ju committed Jan 2, 2025
1 parent 93aab20 commit 05b0e09
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 0 deletions.
29 changes: 29 additions & 0 deletions python/ray/data/_internal/datasource/lance_datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional

import numpy as np
from torch.utils.data import Dataset

from ray.data._internal.util import _check_import, call_with_retry
from ray.data.block import BlockMetadata
Expand Down Expand Up @@ -97,6 +98,34 @@ def estimate_inmemory_data_size(self) -> Optional[int]:
# TODO(chengsu): Add memory size estimation to improve auto-tune of parallelism.
return None

class LanceDataset(Dataset):
    """Map-style torch ``Dataset`` that serves single rows of a Lance dataset.

    The wrapper keeps the originating datasource so that the column
    projection and filter stored in its ``scanner_options`` are forwarded
    on every row lookup.
    """

    def __init__(self, ds: "LanceDatasource"):
        """Remember the datasource and its underlying Lance dataset handle.

        Args:
            ds: Datasource exposing ``lance_ds`` and ``scanner_options``.
        """
        self.ds = ds
        self.lance_ds = ds.lance_ds

    def __len__(self):
        """Return ``lance_ds.count_rows()``.

        NOTE(review): no filter is passed here, so this is the row count of
        the whole dataset, not of the filtered view -- confirm this is what
        callers expect.
        """
        return self.lance_ds.count_rows()

    def __getitem__(self, idx):
        """Fetch the row at ``idx`` and return it via ``to_pydict()``.

        Args:
            idx: Row index handed to ``lance_ds.take`` as a one-element list.

        Returns:
            Whatever ``to_pydict()`` yields for the single-row result.
        """
        options = self.ds.scanner_options
        # ``.get`` instead of ``[...]``: an option that was never set is
        # forwarded as ``None`` rather than raising ``KeyError``.
        return self.lance_ds.take(
            [idx],
            columns=options.get("columns"),
            filter=options.get("filter"),
        ).to_pydict()

def to_torch_dataset(self) -> Dataset:
    """Wrap this datasource in the ``LanceDataset`` class reachable via ``self``.

    Returns:
        A ``LanceDataset`` constructed with ``ds=self``.
    """
    dataset_cls = self.LanceDataset
    return dataset_cls(ds=self)

def count(self) -> int:
    """Count the rows of ``lance_ds`` that match the configured filter.

    Returns:
        The value of ``lance_ds.count_rows`` called with the ``"filter"``
        entry of ``scanner_options``.
    """
    # ``.get`` so that a missing "filter" entry is passed as ``None``
    # instead of raising ``KeyError``.
    row_filter = self.scanner_options.get("filter")
    return self.lance_ds.count_rows(filter=row_filter)


def _read_fragments_with_retry(
fragment_ids,
Expand Down
33 changes: 33 additions & 0 deletions python/ray/data/tests/test_lance.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
import pytest
from pkg_resources import parse_version
from pytest_lazyfixture import lazy_fixture
from torch.utils.data import DataLoader

import ray
from ray._private.test_utils import wait_for_condition
from ray._private.utils import _get_pyarrow_version
from ray.data import Schema
from ray.data._internal.datasource.lance_datasource import LanceDatasource
from ray.data.datasource.path_util import _unwrap_protocol


Expand Down Expand Up @@ -114,6 +116,37 @@ def test_lance():
wait_for_condition(test_lance, timeout=10)


@pytest.mark.parametrize("data_path", [lazy_fixture("local_path")])
def test_torch_dataset(data_path):
    """Write a small Lance table, then read it through a torch DataLoader."""
    num_rows = 1024
    table = pa.table(
        {
            "id": pa.array(range(num_rows)),
            "name": pa.array([f"test_{i}" for i in range(num_rows)]),
        }
    )
    dataset_path = os.path.join(_unwrap_protocol(data_path), "test.lance")
    lance.write_dataset(table, dataset_path, max_rows_per_file=1)

    datasource = LanceDatasource(
        dataset_path, columns=["name"], filter="id < 50 and id > 10"
    )
    # ids 11 through 49 inclusive satisfy the filter.
    assert datasource.count() == 39
    torch_ds = datasource.to_torch_dataset()

    def custom_collate_fn(batch):
        # Dict samples collapse to their "name" entry; any other sample type
        # collapses to its first element.
        if not isinstance(batch[0], dict):
            return [sample[0] for sample in batch]
        return [sample["name"] for sample in batch]

    loader = DataLoader(
        torch_ds, batch_size=10, shuffle=False, collate_fn=custom_collate_fn
    )
    # Only the first batch is inspected.
    for first_batch in loader:
        assert len(first_batch) == 10
        break


if __name__ == "__main__":
import sys

Expand Down
1 change: 1 addition & 0 deletions python/requirements/ml/data-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ pandas==1.5.3; python_version < '3.12'
modin==0.31.0; python_version >= '3.12'
pandas==2.2.2; python_version >= '3.12'
responses==0.13.4
torch==2.3.0
pymars>=0.8.3; python_version < "3.12"
1 change: 1 addition & 0 deletions python/requirements/ml/data-test-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ google-api-core==1.34.0
webdataset
raydp==1.7.0b20231020.dev0
pylance
torch==2.3.0
delta-sharing
pytest-mock
decord
Expand Down
2 changes: 2 additions & 0 deletions python/requirements_compiled.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1598,6 +1598,8 @@ pyjwt==2.8.0
# snowflake-connector-python
pylance==0.10.18
# via -r /ray/ci/../python/requirements/ml/data-test-requirements.txt
torch==2.3.0
# via -r /ray/ci/../python/requirements/ml/data-test-requirements.txt
pymars==0.10.0 ; python_version < "3.12"
# via -r /ray/ci/../python/requirements/ml/data-requirements.txt
pymongo==4.3.2
Expand Down

0 comments on commit 05b0e09

Please sign in to comment.