[Dataset] implement from_spark, to_spark and some optimizations #17340
Changes from 16 commits
```diff
@@ -472,7 +472,6 @@ def from_pandas(dfs: List[ObjectRef["pandas.DataFrame"]]) -> Dataset[ArrowRow]:
     return Dataset(BlockList(blocks, ray.get(list(metadata))))


-
 @PublicAPI(stability="beta")
 def from_numpy(ndarrays: List[ObjectRef[np.ndarray]]) -> Dataset[np.ndarray]:
     """Create a dataset from a set of NumPy ndarrays.
```
```diff
@@ -490,34 +489,41 @@ def from_numpy(ndarrays: List[ObjectRef[np.ndarray]]) -> Dataset[np.ndarray]:


 @PublicAPI(stability="beta")
-def from_arrow(tables: List[ObjectRef["pyarrow.Table"]]) -> Dataset[ArrowRow]:
+def from_arrow(tables: List[ObjectRef[Union["pyarrow.Table", bytes]]],
+               *,
+               parallelism: int = 200) -> Dataset[ArrowRow]:
```
Review comment: Remove? This doesn't seem to be used.
"""Create a dataset from a set of Arrow tables. | ||
|
||
Args: | ||
dfs: A list of Ray object references to Arrow tables. | ||
tables: A list of Ray object references to Arrow tables, | ||
or its streaming format in bytes. | ||
parallelism: The amount of parallelism to use for the dataset. | ||
Review comment: Remove?
```diff

     Returns:
         Dataset holding Arrow records from the tables.
     """
     get_metadata = cached_remote_fn(_get_metadata)
     metadata = [get_metadata.remote(t) for t in tables]
     return Dataset(BlockList(tables, ray.get(metadata)))
```
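For reference on the new `bytes` variant, here is a minimal sketch of what "its streaming format in bytes" refers to: a table serialized with Arrow's IPC streaming format. This is my reading of the docstring; the deserialization path inside `Dataset` is not shown in this diff.

```python
import pyarrow as pa

table = pa.table({"one": [1, 2, 3], "two": ["a", "b", "c"]})

# Serialize the table in Arrow's IPC streaming format; the new
# from_arrow signature also accepts object refs to bytes like these.
sink = pa.BufferOutputStream()
with pa.ipc.new_stream(sink, table.schema) as writer:
    writer.write_table(table)
stream_bytes = sink.getvalue().to_pybytes()

# Round-trip check: the stream deserializes back to the same table.
assert pa.ipc.open_stream(stream_bytes).read_all().equals(table)
```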
```diff


 @PublicAPI(stability="beta")
 def from_spark(df: "pyspark.sql.DataFrame", *,
-               parallelism: int = 200) -> Dataset[ArrowRow]:
+               parallelism: int = 0) -> Dataset[ArrowRow]:
```
kira-lin marked this conversation as resolved.
```diff
     """Create a dataset from a Spark dataframe.

     Args:
-        spark: A SparkSession, which must be created by RayDP (Spark-on-Ray).
+        df: A Spark dataframe, which must be created by RayDP (Spark-on-Ray).
         parallelism: The amount of parallelism to use for the dataset.
+            If not provided, it will be equal to the number of partitions of
+            the original Spark dataframe.
```

Review comment: indent
```diff

     Returns:
         Dataset holding Arrow records read from the dataframe.
     """
-    raise NotImplementedError  # P2
+    import raydp
+    return raydp.spark.spark_dataframe_to_ray_dataset(df, parallelism)


 def _df_to_block(df: "pandas.DataFrame") -> Block[ArrowRow]:
```
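Taken together, the new code delegates the Spark-to-Ray conversion to RayDP. A minimal usage sketch; the argument values mirror the test fixture below and are illustrative only:

```python
import ray
import raydp

ray.init()
# RayDP runs Spark executors as Ray actors; the arguments are app name,
# executor count, cores per executor, and memory per executor.
spark = raydp.init_spark("example", 1, 1, "500M")

spark_df = spark.createDataFrame([(1, "a"), (2, "b")], ["one", "two"])

# parallelism defaults to 0, i.e. use the Spark dataframe's own
# partition count rather than a fixed number of blocks.
ds = ray.data.from_spark(spark_df)
print(ds.take(2))
```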
New test file (@@ -0,0 +1,44 @@). kira-lin marked a conversation on this file as resolved.

```python
import pytest

import ray
import raydp


@pytest.fixture(scope="function")
def spark_on_ray_small(request):
    ray.init(num_cpus=2, include_dashboard=False)
    spark = raydp.init_spark("test", 1, 1, "500 M")

    def stop_all():
        raydp.stop_spark()
        ray.shutdown()

    request.addfinalizer(stop_all)
    return spark


def test_raydp_roundtrip(spark_on_ray_small):
    spark = spark_on_ray_small
    spark_df = spark.createDataFrame([(1, "a"), (2, "b"), (3, "c")],
                                     ["one", "two"])
    rows = [(r.one, r.two) for r in spark_df.take(3)]
    ds = ray.data.from_spark(spark_df)
    values = [(r["one"], r["two"]) for r in ds.take(6)]
    assert values == rows
    df = ds.to_spark(spark)
    rows_2 = [(r.one, r.two) for r in df.take(3)]
    assert values == rows_2


def test_raydp_to_spark(spark_on_ray_small):
    spark = spark_on_ray_small
    n = 5
    ds = ray.data.range_arrow(n)
    values = [r["value"] for r in ds.take(5)]
    df = ds.to_spark(spark)
    rows = [r.value for r in df.take(5)]
    assert values == rows


if __name__ == "__main__":
    import sys
    sys.exit(pytest.main(["-v", __file__]))
```
Requirements file change:

```diff
@@ -59,6 +59,8 @@ moto
 mypy
 networkx
 numba
+raydp-nightly; platform_system != "Windows"
+
 # higher version of llvmlite breaks windows
 llvmlite==0.34.0
 openpyxl
```

Review comment: @clarkzinzow @iycheng could you provide the right place to add this for dataset tests?
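The `platform_system != "Windows"` suffix is a standard PEP 508 environment marker, so pip skips `raydp-nightly` on Windows. A quick local check of what the marker evaluates to:

```python
import platform

# pip's platform_system marker maps to platform.system(); on Windows
# this prints False and the raydp-nightly requirement is skipped.
print(platform.system() != "Windows")
```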
Review thread on a newly added function for looking up object locations:

- "This function is added because locations of objects need to be passed to Java, so that we can ensure locality in Spark."
- "Are you calling this from RayDP? This isn't a public API, so we should call it only from within Ray code (and pass the addresses to external libraries)."
- "Is this comment addressed?"
- "Well, I added the developer API; I'm going to directly call it in `to_spark` now."
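To illustrate the locality mechanism the thread describes, here is a hypothetical sketch. The excerpt does not name the developer API that was added; this assumes it resembles `ray.experimental.get_object_locations`, and the hand-off to the JVM side is only described in a comment:

```python
import ray
from ray.experimental import get_object_locations

ray.init()
refs = [ray.put(i) for i in range(3)]

# Maps each ObjectRef to metadata such as the IDs of the nodes
# holding it. A caller like to_spark could forward these node IDs
# to the Java side so Spark schedules tasks next to the blocks.
locations = get_object_locations(refs)
for ref, info in locations.items():
    print(ref, info.get("node_ids"))
```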