[data] add dataloader for lance datasource #49459
@@ -5,11 +5,13 @@
 import pytest
 from pkg_resources import parse_version
 from pytest_lazyfixture import lazy_fixture
+from torch.utils.data import DataLoader
 
 import ray
 from ray._private.test_utils import wait_for_condition
 from ray._private.utils import _get_pyarrow_version
 from ray.data import Schema
+from ray.data._internal.datasource.lance_datasource import LanceDatasource
 from ray.data.datasource.path_util import _unwrap_protocol
 
 
@@ -114,6 +116,37 @@ def test_lance():
     wait_for_condition(test_lance, timeout=10)
 
 
+@pytest.mark.parametrize("data_path", [lazy_fixture("local_path")])
+def test_torch_dataset(data_path):
+    setup_data_path = _unwrap_protocol(data_path)
+    path = os.path.join(setup_data_path, "test.lance")
+    num_rows = 1024
+    data = pa.table(
+        {
+            "id": pa.array(range(num_rows)),
+            "name": pa.array([f"test_{i}" for i in range(num_rows)]),
+        }
+    )
+    lance.write_dataset(data, path, max_rows_per_file=1)
+
+    ds = LanceDatasource(path, columns=["name"], filter="id < 50 and id > 10")
+    assert ds.count() == 39
+    train_ds = ds.to_torch_dataset()
+
+    def custom_collate_fn(batch):
+        if isinstance(batch[0], dict):
+            return [item["name"] for item in batch]
+        else:
+            return [item[0] for item in batch]
+
+    dataloader = DataLoader(
+        train_ds, batch_size=10, shuffle=False, collate_fn=custom_collate_fn
+    )
+    for batch in dataloader:
+        assert len(batch) == 10
+        break
+
+
 if __name__ == "__main__":
     import sys

Inline comment thread on the `dataloader = DataLoader(` line:

Reviewer: ooc, can you show an example of how this works with […]

Author: It is indeed using Dataset.
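The diff only shows the consumption side of the new `to_torch_dataset()` method, not its implementation. As a rough illustration of what such an adapter could look like, here is a hypothetical sketch (not the PR's actual code) of a torch `IterableDataset` that streams rows from a Lance scanner; the class name `LanceTorchDataset` and all of its internals are assumptions, and only `lance.dataset(...).scanner(...)` and pyarrow's `RecordBatch.to_pylist()` are real APIs:

import lance
from torch.utils.data import IterableDataset

class LanceTorchDataset(IterableDataset):
    """Hypothetical adapter; not the implementation added by this PR."""

    def __init__(self, uri, columns=None, filter=None):
        self.uri = uri
        self.columns = columns
        self.filter = filter

    def __iter__(self):
        # Scan the Lance dataset lazily and yield one dict per row,
        # matching the dict items the test's custom_collate_fn expects.
        scanner = lance.dataset(self.uri).scanner(
            columns=self.columns, filter=self.filter
        )
        for batch in scanner.to_batches():
            yield from batch.to_pylist()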
Reviewer: In Ray Data, a Datasource is used to create a Ray Dataset via ray.data.read_datasource(...). Users are not supposed to use the Datasource class directly for training ingestion. If you want to do this, I think you can directly create a torch Dataset on top of LanceDB without going through Ray Data.
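For context, a minimal sketch of the pattern the reviewer is describing, using the `LanceDatasource` constructor arguments shown in the PR's test; `ray.data.read_datasource` and `Dataset.iter_torch_batches` are existing Ray Data APIs, while the path "example.lance" is a placeholder and the numeric "id" column is chosen because it converts cleanly to tensors:

import ray
from ray.data._internal.datasource.lance_datasource import LanceDatasource

# Build a Ray Dataset from the datasource, then consume it through the
# Dataset API instead of touching the Datasource directly.
ds = ray.data.read_datasource(
    LanceDatasource("example.lance", columns=["id"], filter="id < 50 and id > 10")
)
for batch in ds.iter_torch_batches(batch_size=10):
    ...  # each batch is a dict of torch tensors, e.g. {"id": tensor([...])}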
Author: From a theoretical perspective, directly creating a torch Dataset is also possible, since the class inherits from the torch Dataset. The reason for placing it here is mainly that the datasource can be converted directly into a Ray Dataset, which facilitates the ray train + ray data mode. WDYT?
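To make the "ray train + ray data mode" concrete, here is a hedged sketch using existing Ray Train APIs (TorchTrainer, ScalingConfig, ray.train.get_dataset_shard); the dataset construction reuses the placeholder path and column from the sketch above, and the training step itself is elided:

import ray
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer
from ray.data._internal.datasource.lance_datasource import LanceDatasource

def train_loop_per_worker():
    # Each worker pulls its shard of the Ray Dataset and iterates
    # torch-formatted batches.
    shard = ray.train.get_dataset_shard("train")
    for batch in shard.iter_torch_batches(batch_size=10):
        ...  # training step goes here

train_ds = ray.data.read_datasource(
    LanceDatasource("example.lance", columns=["id"])
)
trainer = TorchTrainer(
    train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=2),
    datasets={"train": train_ds},
)
trainer.fit()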
Reviewer: Sorry, I still don't think this makes sense, because users are not supposed to use the Datasource class directly for training ingestion.