Skip to content

[SPARK-48564][PYTHON][CONNECT] Propagate cached schema in set operations #46915

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 25 additions & 6 deletions python/pyspark/sql/connect/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1131,24 +1131,33 @@ def observe(
def show(self, n: int = 20, truncate: Union[bool, int] = True, vertical: bool = False) -> None:
print(self._show_string(n, truncate, vertical))

def _merge_cached_schema(self, other: ParentDataFrame) -> Optional[StructType]:
    """Return a schema safe to cache on a set-operation result, or ``None``.

    The schema is propagated only when both inputs already hold cached
    schemas and those schemas are exactly equal; otherwise the server may
    apply type coercion and produce a different result schema, so nothing
    is cached.
    """
    if self._cached_schema is None:
        return None
    if self._cached_schema != other._cached_schema:
        return None
    return self.schema
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can just remove this comment too.


def union(self, other: ParentDataFrame) -> ParentDataFrame:
    """Return a new DataFrame containing the rows of both frames.

    Duplicates are kept (UNION ALL semantics); this is an alias for
    :meth:`unionAll`.
    """
    self._check_same_session(other)
    # Delegate: union and unionAll are the same bag-semantics operation.
    return self.unionAll(other)

def unionAll(self, other: ParentDataFrame) -> ParentDataFrame:
    """Return the bag union (UNION ALL) of this frame and ``other``."""
    self._check_same_session(other)
    res = DataFrame(
        plan.SetOperation(
            self._plan, other._plan, "union", is_all=True  # type: ignore[arg-type]
        ),
        session=self._session,
    )
    # Propagate the cached schema when both inputs agree exactly, saving
    # an analysis round trip on the result frame.
    res._cached_schema = self._merge_cached_schema(other)
    return res

def unionByName(
self, other: ParentDataFrame, allowMissingColumns: bool = False
) -> ParentDataFrame:
self._check_same_session(other)
return DataFrame(
res = DataFrame(
plan.SetOperation(
self._plan,
other._plan, # type: ignore[arg-type]
Expand All @@ -1158,42 +1167,52 @@ def unionByName(
),
session=self._session,
)
res._cached_schema = self._merge_cached_schema(other)
return res

def subtract(self, other: ParentDataFrame) -> ParentDataFrame:
    """Return rows of this frame not present in ``other`` (EXCEPT DISTINCT)."""
    self._check_same_session(other)
    res = DataFrame(
        plan.SetOperation(
            self._plan, other._plan, "except", is_all=False  # type: ignore[arg-type]
        ),
        session=self._session,
    )
    # Propagate the cached schema when both inputs agree exactly, saving
    # an analysis round trip on the result frame.
    res._cached_schema = self._merge_cached_schema(other)
    return res

def exceptAll(self, other: ParentDataFrame) -> ParentDataFrame:
    """Return rows of this frame not in ``other``, keeping duplicates (EXCEPT ALL)."""
    self._check_same_session(other)
    res = DataFrame(
        plan.SetOperation(
            self._plan, other._plan, "except", is_all=True  # type: ignore[arg-type]
        ),
        session=self._session,
    )
    # Propagate the cached schema when both inputs agree exactly, saving
    # an analysis round trip on the result frame.
    res._cached_schema = self._merge_cached_schema(other)
    return res

def intersect(self, other: ParentDataFrame) -> ParentDataFrame:
    """Return rows common to both frames, de-duplicated (INTERSECT DISTINCT)."""
    self._check_same_session(other)
    res = DataFrame(
        plan.SetOperation(
            self._plan, other._plan, "intersect", is_all=False  # type: ignore[arg-type]
        ),
        session=self._session,
    )
    # Propagate the cached schema when both inputs agree exactly, saving
    # an analysis round trip on the result frame.
    res._cached_schema = self._merge_cached_schema(other)
    return res

def intersectAll(self, other: ParentDataFrame) -> ParentDataFrame:
    """Return rows common to both frames, keeping duplicates (INTERSECT ALL)."""
    self._check_same_session(other)
    res = DataFrame(
        plan.SetOperation(
            self._plan, other._plan, "intersect", is_all=True  # type: ignore[arg-type]
        ),
        session=self._session,
    )
    # Propagate the cached schema when both inputs agree exactly, saving
    # an analysis round trip on the result frame.
    res._cached_schema = self._merge_cached_schema(other)
    return res

def where(self, condition: Union[Column, str]) -> ParentDataFrame:
if not isinstance(condition, (str, Column)):
Expand Down
100 changes: 100 additions & 0 deletions python/pyspark/sql/tests/connect/test_connect_dataframe_property.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,106 @@ def summarize(left, right):
self.assertEqual(cdf3.schema, sdf3.schema)
self.assertEqual(cdf3.collect(), sdf3.collect())

def test_cached_schema_set_op(self):
    """Set operations propagate ``_cached_schema`` only when both inputs
    already cache schemas and those schemas are exactly equal."""
    data1 = [(1, 2, 3)]
    data2 = [(6, 2, 5)]
    data3 = [(6, 2, 5.0)]  # float in column "c" -> schema differs from data1/data2

    cdf1 = self.connect.createDataFrame(data1, ["a", "b", "c"])
    sdf1 = self.spark.createDataFrame(data1, ["a", "b", "c"])
    cdf2 = self.connect.createDataFrame(data2, ["a", "b", "c"])
    sdf2 = self.spark.createDataFrame(data2, ["a", "b", "c"])
    cdf3 = self.connect.createDataFrame(data3, ["a", "b", "c"])
    sdf3 = self.spark.createDataFrame(data3, ["a", "b", "c"])

    set_ops = [
        "union",
        "unionAll",
        "unionByName",
        "subtract",
        "exceptAll",
        "intersect",
        "intersectAll",
    ]

    # schema not yet cached on any input
    for cdf in (cdf1, cdf2, cdf3):
        self.assertTrue(cdf._cached_schema is None)

    # without cached input schemas, no result carries a cached schema
    for op in set_ops:
        for other in (cdf1, cdf2, cdf3):
            self.assertTrue(getattr(cdf1, op)(other)._cached_schema is None)

    # accessing .schema triggers analysis and caches it on cdf1 only
    self.assertEqual(cdf1.schema, sdf1.schema)
    self.assertTrue(cdf1._cached_schema is not None)

    self.assertEqual(cdf1.union(cdf1)._cached_schema, cdf1._cached_schema)
    # cannot infer while cdf2 doesn't cache schema
    self.assertTrue(cdf1.union(cdf2)._cached_schema is None)
    # cannot infer while cdf3 doesn't cache schema
    self.assertTrue(cdf1.union(cdf3)._cached_schema is None)

    # trigger analysis of cdf2.schema and cdf3.schema as well
    self.assertEqual(cdf2.schema, sdf2.schema)
    self.assertEqual(cdf3.schema, sdf3.schema)

    # now all the schemas are cached
    for cdf in (cdf1, cdf2, cdf3):
        self.assertTrue(cdf._cached_schema is not None)

    for op in set_ops:
        method = getattr(cdf1, op)
        # identical cached schemas -> propagated to the result
        self.assertEqual(method(cdf1)._cached_schema, cdf1._cached_schema)
        self.assertEqual(method(cdf2)._cached_schema, cdf1._cached_schema)
        # mismatched schemas (int vs double) -> not propagated
        self.assertTrue(method(cdf3)._cached_schema is None)


if __name__ == "__main__":
from pyspark.sql.tests.connect.test_connect_dataframe_property import * # noqa: F401
Expand Down