[Dataset] implement from_spark, to_spark and some optimizations (#…
kira-lin authored Sep 9, 2021
1 parent fdd5210 commit 2fcd1bc
Showing 7 changed files with 84 additions and 11 deletions.
6 changes: 6 additions & 0 deletions python/ray/_raylet.pyx
@@ -1720,6 +1720,12 @@ cdef class CoreWorker:
        CCoreWorkerProcess.GetCoreWorker().RemoveLocalReference(
            c_object_id)

    def get_owner_address(self, ObjectRef object_ref):
        cdef:
            CObjectID c_object_id = object_ref.native()
        return CCoreWorkerProcess.GetCoreWorker().GetOwnerAddress(
            c_object_id).SerializeAsString()

    def serialize_and_promote_object_ref(self, ObjectRef object_ref):
        cdef:
            CObjectID c_object_id = object_ref.native()
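
The new `get_owner_address` binding is what `Dataset.to_spark` (below) uses to tell RayDP which worker owns each block. A minimal sketch of calling it from Python, assuming a running Ray worker; the `core_worker` handle and the serialized-protobuf return value come from this diff, everything else is illustrative:

import ray

ray.init()
ref = ray.put([1, 2, 3])
# Internal binding: returns the owner's address as serialized protobuf bytes.
core_worker = ray.worker.global_worker.core_worker
owner_address = core_worker.get_owner_address(ref)
print(type(owner_address))  # <class 'bytes'>
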
6 changes: 5 additions & 1 deletion python/ray/data/block.py
@@ -17,7 +17,7 @@
#
# Block data can be accessed in a uniform way via ``BlockAccessors`` such as
# ``SimpleBlockAccessor``, ``ArrowBlockAccessor``, and ``TensorBlockAccessor``.
Block = Union[List[T], np.ndarray, "pyarrow.Table"]
Block = Union[List[T], np.ndarray, "pyarrow.Table", bytes]


@DeveloperAPI
@@ -124,6 +124,10 @@ def for_block(block: Block) -> "BlockAccessor[T]":
            from ray.data.impl.arrow_block import \
                ArrowBlockAccessor
            return ArrowBlockAccessor(block)
        elif isinstance(block, bytes):
            from ray.data.impl.arrow_block import \
                ArrowBlockAccessor
            return ArrowBlockAccessor.from_bytes(block)
        elif isinstance(block, list):
            from ray.data.impl.simple_block import \
                SimpleBlockAccessor
12 changes: 10 additions & 2 deletions python/ray/data/dataset.py
@@ -1269,15 +1269,23 @@ def to_modin(self) -> "modin.DataFrame":
        pd_objs = self.to_pandas()
        return from_partitions(pd_objs, axis=0)

    def to_spark(self) -> "pyspark.sql.DataFrame":
    def to_spark(self,
                 spark: "pyspark.sql.SparkSession") -> "pyspark.sql.DataFrame":
        """Convert this dataset into a Spark dataframe.
        Time complexity: O(dataset size / parallelism)
        Returns:
            A Spark dataframe created from this dataset.
        """
        raise NotImplementedError # P2
        import raydp
        core_worker = ray.worker.global_worker.core_worker
        locations = [
            core_worker.get_owner_address(block)
            for block in self.get_blocks()
        ]
        return raydp.spark.ray_dataset_to_spark_dataframe(
            spark, self.schema(), self.get_blocks(), locations)

    def to_pandas(self) -> List[ObjectRef["pandas.DataFrame"]]:
        """Convert this dataset into a distributed set of Pandas dataframes.
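
A hedged usage sketch of the new `to_spark` path, mirroring the test added in this commit; the app name and sizing arguments are arbitrary, while `range_arrow` and the RayDP calls are taken from the diff:

import ray
import raydp

ray.init(num_cpus=2)
spark = raydp.init_spark("to_spark_example", 1, 1, "500 M")

ds = ray.data.range_arrow(5)        # Arrow-backed dataset holding values 0..4
spark_df = ds.to_spark(spark)       # blocks stay in the object store; RayDP
                                    # builds a dataframe from their locations
spark_df.show()

raydp.stop_spark()
ray.shutdown()
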
5 changes: 5 additions & 0 deletions python/ray/data/impl/arrow_block.py
@@ -145,6 +145,11 @@ def __init__(self, table: "pyarrow.Table"):
            raise ImportError("Run `pip install pyarrow` for Arrow support")
        self._table = table

    @classmethod
    def from_bytes(cls, data: bytes):
        reader = pyarrow.ipc.open_stream(data)
        return cls(reader.read_all())

    def iter_rows(self) -> Iterator[ArrowRow]:
        outer = self

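`from_bytes` expects the Arrow IPC streaming format, which is also what the new `bytes` variant of `Block` carries. A self-contained sketch of that round trip using plain pyarrow; the table contents are arbitrary:

import pyarrow

table = pyarrow.table({"one": [1, 2, 3], "two": ["a", "b", "c"]})

# Serialize to the IPC streaming format, i.e. what a bytes block holds.
sink = pyarrow.BufferOutputStream()
with pyarrow.ipc.new_stream(sink, table.schema) as writer:
    writer.write_table(table)
data = sink.getvalue().to_pybytes()

# Deserialize the same way ArrowBlockAccessor.from_bytes does.
reader = pyarrow.ipc.open_stream(data)
assert reader.read_all().equals(table)
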
21 changes: 13 additions & 8 deletions python/ray/data/read_api.py
@@ -506,7 +506,6 @@ def from_pandas(dfs: List[ObjectRef["pandas.DataFrame"]]) -> Dataset[ArrowRow]:
    return Dataset(BlockList(blocks, ray.get(list(metadata))))


@PublicAPI(stability="beta")
def from_numpy(ndarrays: List[ObjectRef[np.ndarray]]) -> Dataset[np.ndarray]:
"""Create a dataset from a set of NumPy ndarrays.
Expand All @@ -524,34 +523,40 @@ def from_numpy(ndarrays: List[ObjectRef[np.ndarray]]) -> Dataset[np.ndarray]:


@PublicAPI(stability="beta")
def from_arrow(tables: List[ObjectRef["pyarrow.Table"]]) -> Dataset[ArrowRow]:
def from_arrow(tables: List[ObjectRef[Union["pyarrow.Table", bytes]]]
               ) -> Dataset[ArrowRow]:
    """Create a dataset from a set of Arrow tables.
    Args:
        dfs: A list of Ray object references to Arrow tables.
        tables: A list of Ray object references to Arrow tables,
            or its streaming format in bytes.
    Returns:
        Dataset holding Arrow records from the tables.
    """

    get_metadata = cached_remote_fn(_get_metadata)
    metadata = [get_metadata.remote(t) for t in tables]
    return Dataset(BlockList(tables, ray.get(metadata)))


@PublicAPI(stability="beta")
def from_spark(df: "pyspark.sql.DataFrame", *,
               parallelism: int = 200) -> Dataset[ArrowRow]:
def from_spark(df: "pyspark.sql.DataFrame",
               *,
               parallelism: Optional[int] = None) -> Dataset[ArrowRow]:
    """Create a dataset from a Spark dataframe.
    Args:
        spark: A SparkSession, which must be created by RayDP (Spark-on-Ray).
        df: A Spark dataframe, which must be created by RayDP (Spark-on-Ray).
        parallelism: The amount of parallelism to use for the dataset.
        parallelism: The amount of parallelism to use for the dataset.
            If not provided, it will be equal to the number of partitions of
            the original Spark dataframe.
    Returns:
        Dataset holding Arrow records read from the dataframe.
    """
    raise NotImplementedError # P2
    import raydp
    return raydp.spark.spark_dataframe_to_ray_dataset(df, parallelism)


def _df_to_block(df: "pandas.DataFrame") -> Block[ArrowRow]:
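
And a hedged sketch of `from_spark`, again assuming a RayDP-created SparkSession as in the tests below; column names and values are arbitrary:

import ray
import raydp

ray.init(num_cpus=2)
spark = raydp.init_spark("from_spark_example", 1, 1, "500 M")

spark_df = spark.createDataFrame([(1, "a"), (2, "b"), (3, "c")],
                                 ["one", "two"])
# parallelism defaults to the dataframe's number of partitions.
ds = ray.data.from_spark(spark_df)
print(ds.take(3))

raydp.stop_spark()
ray.shutdown()
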
44 changes: 44 additions & 0 deletions python/ray/data/tests/test_raydp_dataset.py
@@ -0,0 +1,44 @@
import pytest
import ray
import raydp


@pytest.fixture(scope="function")
def spark_on_ray_small(request):
    ray.init(num_cpus=2, include_dashboard=False)
    spark = raydp.init_spark("test", 1, 1, "500 M")

    def stop_all():
        raydp.stop_spark()
        ray.shutdown()

    request.addfinalizer(stop_all)
    return spark


def test_raydp_roundtrip(spark_on_ray_small):
    spark = spark_on_ray_small
    spark_df = spark.createDataFrame([(1, "a"), (2, "b"), (3, "c")],
                                     ["one", "two"])
    rows = [(r.one, r.two) for r in spark_df.take(3)]
    ds = ray.data.from_spark(spark_df)
    values = [(r["one"], r["two"]) for r in ds.take(6)]
    assert values == rows
    df = ds.to_spark(spark)
    rows_2 = [(r.one, r.two) for r in df.take(3)]
    assert values == rows_2


def test_raydp_to_spark(spark_on_ray_small):
    spark = spark_on_ray_small
    n = 5
    ds = ray.data.range_arrow(n)
    values = [r["value"] for r in ds.take(5)]
    df = ds.to_spark(spark)
    rows = [r.value for r in df.take(5)]
    assert values == rows


if __name__ == "__main__":
    import sys
    sys.exit(pytest.main(["-v", __file__]))
1 change: 1 addition & 0 deletions python/requirements/data_processing/requirements.txt
@@ -7,3 +7,4 @@ s3fs
modin>=0.8.3; python_version < '3.7'
modin>=0.10.0; python_version >= '3.7'
pytest-repeat
raydp-nightly
