[SPARK-48598][PYTHON][CONNECT] Propagate cached schema in dataframe operations

### What changes were proposed in this pull request?
Propagate the cached schema in the following DataFrame operations (a sketch of the pattern follows the list):

- DataFrame.alias
- DataFrame.coalesce
- DataFrame.repartition
- DataFrame.repartitionByRange
- DataFrame.dropDuplicates
- DataFrame.distinct
- DataFrame.filter
- DataFrame.where
- DataFrame.limit
- DataFrame.sort
- DataFrame.sortWithinPartitions
- DataFrame.orderBy
- DataFrame.sample
- DataFrame.hint
- DataFrame.randomSplit
- DataFrame.observe
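
Each of these operations returns a DataFrame with the same schema as its parent, so the child can simply reuse the parent's cached schema (the `res._cached_schema = self._cached_schema` assignments in the diff below). A minimal, self-contained sketch of the idea; `FakeDataFrame`, `FakeSchema`, and `_analyze_on_server` are made-up stand-ins, not the real `pyspark.sql.connect` classes:

```python
from typing import List, Optional


class FakeSchema:
    """Stand-in for pyspark.sql.types.StructType."""

    def __init__(self, fields: List[str]) -> None:
        self.fields = fields


class FakeDataFrame:
    """Toy model of a Spark Connect DataFrame with a lazily analyzed schema."""

    def __init__(self, plan: object) -> None:
        self._plan = plan
        self._cached_schema: Optional[FakeSchema] = None

    @property
    def schema(self) -> FakeSchema:
        if self._cached_schema is None:
            # In Spark Connect this is where an AnalyzePlan RPC goes to the server.
            self._cached_schema = self._analyze_on_server()
        return self._cached_schema

    def _analyze_on_server(self) -> FakeSchema:
        print("RPC: AnalyzePlan")
        return FakeSchema(["id", "v1"])

    def limit(self, n: int) -> "FakeDataFrame":
        res = FakeDataFrame(("limit", self._plan, n))
        # limit() cannot change the schema, so propagate the cached one;
        # a later res.schema access is then served without another round trip.
        res._cached_schema = self._cached_schema
        return res


df = FakeDataFrame("scan")
_ = df.schema            # prints "RPC: AnalyzePlan" once
_ = df.limit(10).schema  # no further RPC: the cached schema was propagated
```

Running the sketch prints `RPC: AnalyzePlan` only once; the second `.schema` access is answered from the propagated cache.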

### Why are the changes needed?
To avoid unnecessary RPCs where possible: these operations cannot change the schema, so the result can reuse the parent's cached schema instead of asking the server again.

### Does this PR introduce _any_ user-facing change?
No, this is an optimization only.

### How was this patch tested?
Added tests.
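
The added test (see the diff below) builds a chain of schema-preserving operations and asserts that the cached schema flows through it. A rough interactive equivalent, assuming `spark` is a Spark Connect `SparkSession`; note that `_cached_schema` is an internal attribute, not public API:

```python
from pyspark.sql.connect import functions as CF

df = spark.createDataFrame([(1, 1.0), (2, 2.0)], ("id", "v1"))
df1 = df.withColumn("v2", CF.lit(1))
assert df1._cached_schema is None   # nothing analyzed yet
_ = df1.schema                      # triggers one AnalyzePlan RPC
assert df1._cached_schema is not None

df2 = df1.where(df1.v2 > 0).repartition(4).distinct().limit(5)
# Every operation above preserves the schema, so it was propagated:
assert df2._cached_schema == df1._cached_schema
_ = df2.schema                      # answered from the cache, no extra RPC
```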

### Was this patch authored or co-authored using generative AI tooling?
No

Closes apache#46954 from zhengruifeng/py_connect_propagate_schema.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
zhengruifeng authored and HyukjinKwon committed Jun 12, 2024
1 parent 2d0b122 commit d1d29c9
Showing 2 changed files with 85 additions and 19 deletions.
69 changes: 50 additions & 19 deletions python/pyspark/sql/connect/dataframe.py
@@ -262,7 +262,9 @@ def agg(self, *exprs: Union[Column, Dict[str, str]]) -> ParentDataFrame:
return self.groupBy().agg(*exprs)

def alias(self, alias: str) -> ParentDataFrame:
return DataFrame(plan.SubqueryAlias(self._plan, alias), session=self._session)
res = DataFrame(plan.SubqueryAlias(self._plan, alias), session=self._session)
res._cached_schema = self._cached_schema
return res

def colRegex(self, colName: str) -> Column:
from pyspark.sql.connect.column import Column as ConnectColumn
@@ -314,10 +316,12 @@ def coalesce(self, numPartitions: int) -> ParentDataFrame:
error_class="VALUE_NOT_POSITIVE",
message_parameters={"arg_name": "numPartitions", "arg_value": str(numPartitions)},
)
return DataFrame(
res = DataFrame(
plan.Repartition(self._plan, num_partitions=numPartitions, shuffle=False),
self._session,
)
res._cached_schema = self._cached_schema
return res

@overload
def repartition(self, numPartitions: int, *cols: "ColumnOrName") -> ParentDataFrame:
@@ -340,20 +344,20 @@ def repartition( # type: ignore[misc]
},
)
if len(cols) == 0:
return DataFrame(
res = DataFrame(
plan.Repartition(self._plan, numPartitions, shuffle=True),
self._session,
)
else:
return DataFrame(
res = DataFrame(
plan.RepartitionByExpression(
self._plan, numPartitions, [F._to_col(c) for c in cols]
),
self.sparkSession,
)
elif isinstance(numPartitions, (str, Column)):
cols = (numPartitions,) + cols
return DataFrame(
res = DataFrame(
plan.RepartitionByExpression(self._plan, None, [F._to_col(c) for c in cols]),
self.sparkSession,
)
@@ -366,6 +370,9 @@ def repartition( # type: ignore[misc]
},
)

res._cached_schema = self._cached_schema
return res

@overload
def repartitionByRange(self, numPartitions: int, *cols: "ColumnOrName") -> ParentDataFrame:
...
@@ -392,14 +399,14 @@ def repartitionByRange( # type: ignore[misc]
message_parameters={"item": "cols"},
)
else:
return DataFrame(
res = DataFrame(
plan.RepartitionByExpression(
self._plan, numPartitions, [F._sort_col(c) for c in cols]
),
self.sparkSession,
)
elif isinstance(numPartitions, (str, Column)):
return DataFrame(
res = DataFrame(
plan.RepartitionByExpression(
self._plan, None, [F._sort_col(c) for c in [numPartitions] + list(cols)]
),
@@ -414,6 +421,9 @@ def repartitionByRange( # type: ignore[misc]
},
)

res._cached_schema = self._cached_schema
return res

def dropDuplicates(self, *subset: Union[str, List[str]]) -> ParentDataFrame:
# Acceptable args should be str, ... or a single List[str]
# So if subset length is 1, it can be either single str, or a list of str
@@ -422,20 +432,23 @@ def dropDuplicates(self, *subset: Union[str, List[str]]) -> ParentDataFrame:
assert all(isinstance(c, str) for c in subset)

if not subset:
return DataFrame(
res = DataFrame(
plan.Deduplicate(child=self._plan, all_columns_as_keys=True), session=self._session
)
elif len(subset) == 1 and isinstance(subset[0], list):
return DataFrame(
res = DataFrame(
plan.Deduplicate(child=self._plan, column_names=subset[0]),
session=self._session,
)
else:
return DataFrame(
res = DataFrame(
plan.Deduplicate(child=self._plan, column_names=cast(List[str], subset)),
session=self._session,
)

res._cached_schema = self._cached_schema
return res

drop_duplicates = dropDuplicates

def dropDuplicatesWithinWatermark(self, *subset: Union[str, List[str]]) -> ParentDataFrame:
@@ -466,9 +479,11 @@ def dropDuplicatesWithinWatermark(self, *subset: Union[str, List[str]]) -> Paren
)

def distinct(self) -> ParentDataFrame:
return DataFrame(
res = DataFrame(
plan.Deduplicate(child=self._plan, all_columns_as_keys=True), session=self._session
)
res._cached_schema = self._cached_schema
return res

@overload
def drop(self, cols: "ColumnOrName") -> ParentDataFrame:
@@ -499,7 +514,9 @@ def filter(self, condition: Union[Column, str]) -> ParentDataFrame:
expr = F.expr(condition)
else:
expr = condition
return DataFrame(plan.Filter(child=self._plan, filter=expr), session=self._session)
res = DataFrame(plan.Filter(child=self._plan, filter=expr), session=self._session)
res._cached_schema = self._cached_schema
return res

def first(self) -> Optional[Row]:
return self.head()
@@ -709,7 +726,9 @@ def _convert_col(df: ParentDataFrame, col: "ColumnOrName") -> Column:
)

def limit(self, n: int) -> ParentDataFrame:
return DataFrame(plan.Limit(child=self._plan, limit=n), session=self._session)
res = DataFrame(plan.Limit(child=self._plan, limit=n), session=self._session)
res._cached_schema = self._cached_schema
return res

def tail(self, num: int) -> List[Row]:
return DataFrame(plan.Tail(child=self._plan, limit=num), session=self._session).collect()
@@ -766,14 +785,16 @@ def sort(
*cols: Union[int, str, Column, List[Union[int, str, Column]]],
**kwargs: Any,
) -> ParentDataFrame:
return DataFrame(
res = DataFrame(
plan.Sort(
self._plan,
columns=self._sort_cols(cols, kwargs),
is_global=True,
),
session=self._session,
)
res._cached_schema = self._cached_schema
return res

orderBy = sort

@@ -782,14 +803,16 @@ def sortWithinPartitions(
*cols: Union[int, str, Column, List[Union[int, str, Column]]],
**kwargs: Any,
) -> ParentDataFrame:
return DataFrame(
res = DataFrame(
plan.Sort(
self._plan,
columns=self._sort_cols(cols, kwargs),
is_global=False,
),
session=self._session,
)
res._cached_schema = self._cached_schema
return res

def sample(
self,
@@ -837,7 +860,7 @@ def sample(

seed = int(seed) if seed is not None else random.randint(0, sys.maxsize)

return DataFrame(
res = DataFrame(
plan.Sample(
child=self._plan,
lower_bound=0.0,
Expand All @@ -847,6 +870,8 @@ def sample(
),
session=self._session,
)
res._cached_schema = self._cached_schema
return res

def withColumnRenamed(self, existing: str, new: str) -> ParentDataFrame:
return self.withColumnsRenamed({existing: new})
@@ -1050,10 +1075,12 @@ def hint(
},
)

return DataFrame(
res = DataFrame(
plan.Hint(self._plan, name, [F.lit(p) for p in list(parameters)]),
session=self._session,
)
res._cached_schema = self._cached_schema
return res

def randomSplit(
self,
@@ -1094,6 +1121,7 @@ def randomSplit(
),
session=self._session,
)
samplePlan._cached_schema = self._cached_schema
splits.append(samplePlan)
j += 1

@@ -1118,9 +1146,9 @@ def observe(
)

if isinstance(observation, Observation):
return observation._on(self, *exprs)
res = observation._on(self, *exprs)
elif isinstance(observation, str):
return DataFrame(
res = DataFrame(
plan.CollectMetrics(self._plan, observation, list(exprs)),
self._session,
)
@@ -1133,6 +1161,9 @@ def observe(
},
)

res._cached_schema = self._cached_schema
return res

def show(self, n: int = 20, truncate: Union[bool, int] = True, vertical: bool = False) -> None:
print(self._show_string(n, truncate, vertical))

35 changes: 35 additions & 0 deletions python/pyspark/sql/tests/connect/test_connect_dataframe_property.py
@@ -20,6 +20,9 @@
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, LongType, DoubleType
from pyspark.sql.utils import is_remote

from pyspark.sql import functions as SF
from pyspark.sql.connect import functions as CF

from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase
from pyspark.testing.sqlutils import (
have_pandas,
@@ -393,6 +396,38 @@ def test_cached_schema_set_op(self):
# cannot infer when schemas mismatch
self.assertTrue(cdf1.intersectAll(cdf3)._cached_schema is None)

def test_cached_schema_in_chain_op(self):
data = [(1, 1.0), (2, 2.0), (1, 3.0), (2, 4.0)]

cdf = self.connect.createDataFrame(data, ("id", "v1"))
sdf = self.spark.createDataFrame(data, ("id", "v1"))

cdf1 = cdf.withColumn("v2", CF.lit(1))
sdf1 = sdf.withColumn("v2", SF.lit(1))

self.assertTrue(cdf1._cached_schema is None)
# trigger analysis of cdf1.schema
self.assertEqual(cdf1.schema, sdf1.schema)
self.assertTrue(cdf1._cached_schema is not None)

cdf2 = cdf1.where(cdf1.v2 > 0)
sdf2 = sdf1.where(sdf1.v2 > 0)
self.assertEqual(cdf1._cached_schema, cdf2._cached_schema)

cdf3 = cdf2.repartition(10)
sdf3 = sdf2.repartition(10)
self.assertEqual(cdf1._cached_schema, cdf3._cached_schema)

cdf4 = cdf3.distinct()
sdf4 = sdf3.distinct()
self.assertEqual(cdf1._cached_schema, cdf4._cached_schema)

cdf5 = cdf4.sample(fraction=0.5)
sdf5 = sdf4.sample(fraction=0.5)
self.assertEqual(cdf1._cached_schema, cdf5._cached_schema)

self.assertEqual(cdf5.schema, sdf5.schema)


if __name__ == "__main__":
from pyspark.sql.tests.connect.test_connect_dataframe_property import * # noqa: F401
