[SPARK-41034][CONNECT][PYTHON] Connect DataFrame should require a RemoteSparkSession

### What changes were proposed in this pull request?

Connect has marked the session parameter as `Optional` everywhere. This seems useful only for testing, where a test case can omit the session parameter when it is not applicable.

However:
1. Some PySpark DataFrame APIs return a SparkSession, which is not optional. If Connect keeps its session optional, those APIs will diverge from PySpark.
2. `Optional` suggests a `None` check, which is unnecessary in many places (and easy to forget).

This PR proposes removing `Optional` from the session parameter in the API interface.
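
Roughly, the signature change looks like this (a minimal sketch, not the exact diff; `RemoteSparkSession` and `proto` are the Connect types involved):

```python
from typing import Optional

class ExpressionBefore:
    # Before: the session may be None, so call sites need a defensive
    # check that is easy to forget and is usually dead code.
    def to_plan(self, session: Optional["RemoteSparkSession"]) -> "proto.Expression":
        if session is None:
            raise ValueError("Cannot build a plan without a session.")
        ...

class ExpressionAfter:
    # After: a RemoteSparkSession is always required, so the None check
    # disappears and the signature lines up with PySpark APIs whose
    # session is never optional.
    def to_plan(self, session: "RemoteSparkSession") -> "proto.Expression":
        ...
```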

### Why are the changes needed?

Maintainability

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Existing UT

Closes #38541 from amaliujia/dataframe_must_have_a_session.

Authored-by: Rui Wang <rui.wang@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
amaliujia authored and HyukjinKwon committed Nov 10, 2022
1 parent 381b67b commit 110b516
Showing 7 changed files with 53 additions and 48 deletions.
12 changes: 6 additions & 6 deletions python/pyspark/sql/connect/column.py
@@ -15,7 +15,7 @@
# limitations under the License.
#
import uuid
from typing import cast, get_args, TYPE_CHECKING, Optional, Callable, Any
from typing import cast, get_args, TYPE_CHECKING, Callable, Any

import decimal
import datetime
@@ -76,7 +76,7 @@ def __eq__(self, other: Any) -> "Expression":  # type: ignore[override]
def __init__(self) -> None:
pass

def to_plan(self, session: Optional["RemoteSparkSession"]) -> "proto.Expression":
def to_plan(self, session: "RemoteSparkSession") -> "proto.Expression":
...

def __str__(self) -> str:
@@ -93,7 +93,7 @@ def __init__(self, value: Any) -> None:
super().__init__()
self._value = value

def to_plan(self, session: Optional["RemoteSparkSession"]) -> "proto.Expression":
def to_plan(self, session: "RemoteSparkSession") -> "proto.Expression":
"""Converts the literal expression to the literal in proto.
TODO(SPARK-40533) This method always assumes the largest type and can thus
@@ -181,7 +181,7 @@ def name(self) -> str:
"""Returns the qualified name of the column reference."""
return self._unparsed_identifier

def to_plan(self, session: Optional["RemoteSparkSession"]) -> proto.Expression:
def to_plan(self, session: "RemoteSparkSession") -> proto.Expression:
"""Returns the Proto representation of the expression."""
expr = proto.Expression()
expr.unresolved_attribute.unparsed_identifier = self._unparsed_identifier
@@ -207,7 +207,7 @@ def __init__(self, col: ColumnRef, ascending: bool = True, nullsLast: bool = Tru
def __str__(self) -> str:
return str(self.ref) + " ASC" if self.ascending else " DESC"

def to_plan(self, session: Optional["RemoteSparkSession"]) -> proto.Expression:
def to_plan(self, session: "RemoteSparkSession") -> proto.Expression:
return self.ref.to_plan(session)


@@ -221,7 +221,7 @@ def __init__(
self._args = args
self._op = op

def to_plan(self, session: Optional["RemoteSparkSession"]) -> proto.Expression:
def to_plan(self, session: "RemoteSparkSession") -> proto.Expression:
fun = proto.Expression()
fun.unresolved_function.parts.append(self._op)
fun.unresolved_function.arguments.extend([x.to_plan(session) for x in self._args])
16 changes: 9 additions & 7 deletions python/pyspark/sql/connect/dataframe.py
@@ -101,21 +101,23 @@ class DataFrame(object):
of the DataFrame with the changes applied.
"""

def __init__(self, data: Optional[List[Any]] = None, schema: Optional[StructType] = None):
def __init__(
self,
session: "RemoteSparkSession",
data: Optional[List[Any]] = None,
schema: Optional[StructType] = None,
):
"""Creates a new data frame"""
self._schema = schema
self._plan: Optional[plan.LogicalPlan] = None
self._cache: Dict[str, Any] = {}
self._session: Optional["RemoteSparkSession"] = None
self._session: "RemoteSparkSession" = session

@classmethod
def withPlan(
cls, plan: plan.LogicalPlan, session: Optional["RemoteSparkSession"] = None
) -> "DataFrame":
def withPlan(cls, plan: plan.LogicalPlan, session: "RemoteSparkSession") -> "DataFrame":
"""Main initialization method used to construct a new data frame with a child plan."""
new_frame = DataFrame()
new_frame = DataFrame(session=session)
new_frame._plan = plan
new_frame._session = session
return new_frame

def select(self, *cols: ColumnOrName) -> "DataFrame":
2 changes: 1 addition & 1 deletion python/pyspark/sql/connect/function_builder.py
@@ -87,7 +87,7 @@ def __init__(
self._args = []
self._func_name = None

def to_plan(self, session: Optional["RemoteSparkSession"]) -> proto.Expression:
def to_plan(self, session: "RemoteSparkSession") -> proto.Expression:
if session is None:
raise Exception("Cannot create UDF without remote session.")
# Needs to materialize the UDF to the server
44 changes: 20 additions & 24 deletions python/pyspark/sql/connect/plan.py
@@ -58,7 +58,7 @@ def unresolved_attr(self, colName: str) -> proto.Expression:
return exp

def to_attr_or_expression(
self, col: "ColumnOrString", session: Optional["RemoteSparkSession"]
self, col: "ColumnOrString", session: "RemoteSparkSession"
) -> proto.Expression:
"""Returns either an instance of an unresolved attribute or the serialized
expression value of the column."""
@@ -67,7 +67,7 @@ def to_attr_or_expression(
else:
return cast(ColumnRef, col).to_plan(session)

def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
def plan(self, session: "RemoteSparkSession") -> proto.Relation:
...

def _verify(self, session: "RemoteSparkSession") -> bool:
@@ -82,9 +82,7 @@ def _verify(self, session: "RemoteSparkSession") -> bool:

return test_plan == plan

def to_proto(
self, session: Optional["RemoteSparkSession"] = None, debug: bool = False
) -> proto.Plan:
def to_proto(self, session: "RemoteSparkSession", debug: bool = False) -> proto.Plan:
"""
Generates connect proto plan based on this LogicalPlan.
@@ -127,7 +125,7 @@ def __init__(
self.schema = schema
self.options = options

def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
def plan(self, session: "RemoteSparkSession") -> proto.Relation:
plan = proto.Relation()
if self.format is not None:
plan.read.data_source.format = self.format
@@ -158,7 +156,7 @@ def __init__(self, table_name: str) -> None:
super().__init__(None)
self.table_name = table_name

def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
def plan(self, session: "RemoteSparkSession") -> proto.Relation:
plan = proto.Relation()
plan.read.named_table.unparsed_identifier = self.table_name
return plan
@@ -202,7 +200,7 @@ def _verify_expressions(self) -> None:
f"Only Expressions or String can be used for projections: '{c}'."
)

def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
def plan(self, session: "RemoteSparkSession") -> proto.Relation:
assert self._child is not None
proj_exprs = []
for c in self._raw_columns:
@@ -241,7 +239,7 @@ def __init__(self, child: Optional["LogicalPlan"], filter: Expression) -> None:
super().__init__(child)
self.filter = filter

def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
def plan(self, session: "RemoteSparkSession") -> proto.Relation:
assert self._child is not None
plan = proto.Relation()
plan.filter.input.CopyFrom(self._child.plan(session))
@@ -269,7 +267,7 @@ def __init__(self, child: Optional["LogicalPlan"], limit: int) -> None:
super().__init__(child)
self.limit = limit

def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
def plan(self, session: "RemoteSparkSession") -> proto.Relation:
assert self._child is not None
plan = proto.Relation()
plan.limit.input.CopyFrom(self._child.plan(session))
@@ -297,7 +295,7 @@ def __init__(self, child: Optional["LogicalPlan"], offset: int = 0) -> None:
super().__init__(child)
self.offset = offset

def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
def plan(self, session: "RemoteSparkSession") -> proto.Relation:
assert self._child is not None
plan = proto.Relation()
plan.offset.input.CopyFrom(self._child.plan(session))
@@ -331,7 +329,7 @@ def __init__(
self.all_columns_as_keys = all_columns_as_keys
self.column_names = column_names

def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
def plan(self, session: "RemoteSparkSession") -> proto.Relation:
assert self._child is not None
plan = proto.Relation()
plan.deduplicate.all_columns_as_keys = self.all_columns_as_keys
@@ -371,7 +369,7 @@ def __init__(
self.is_global = is_global

def col_to_sort_field(
self, col: Union[SortOrder, ColumnRef, str], session: Optional["RemoteSparkSession"]
self, col: Union[SortOrder, ColumnRef, str], session: "RemoteSparkSession"
) -> proto.Sort.SortField:
if isinstance(col, SortOrder):
sf = proto.Sort.SortField()
@@ -398,7 +396,7 @@ def col_to_sort_field(
sf.nulls = proto.Sort.SortNulls.SORT_NULLS_LAST
return sf

def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
def plan(self, session: "RemoteSparkSession") -> proto.Relation:
assert self._child is not None
plan = proto.Relation()
plan.sort.input.CopyFrom(self._child.plan(session))
@@ -438,7 +436,7 @@ def __init__(
self.with_replacement = with_replacement
self.seed = seed

def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
def plan(self, session: "RemoteSparkSession") -> proto.Relation:
assert self._child is not None
plan = proto.Relation()
plan.sample.input.CopyFrom(self._child.plan(session))
@@ -488,9 +486,7 @@ def __init__(
self.grouping_cols = grouping_cols
self.measures = measures if measures is not None else []

def _convert_measure(
self, m: MeasureType, session: Optional["RemoteSparkSession"]
) -> proto.Expression:
def _convert_measure(self, m: MeasureType, session: "RemoteSparkSession") -> proto.Expression:
exp, fun = m
proto_expr = proto.Expression()
measure = proto_expr.unresolved_function
@@ -501,7 +497,7 @@ def _convert_measure(
measure.arguments.append(cast(Expression, exp).to_plan(session))
return proto_expr

def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
def plan(self, session: "RemoteSparkSession") -> proto.Relation:
assert self._child is not None
groupings = [x.to_plan(session) for x in self.grouping_cols]

@@ -571,7 +567,7 @@ def __init__(
)
self.how = join_type

def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
def plan(self, session: "RemoteSparkSession") -> proto.Relation:
rel = proto.Relation()
rel.join.left.CopyFrom(self.left.plan(session))
rel.join.right.CopyFrom(self.right.plan(session))
@@ -622,7 +618,7 @@ def __init__(
self.is_all = is_all
self.set_op = set_op

def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
def plan(self, session: "RemoteSparkSession") -> proto.Relation:
assert self._child is not None
rel = proto.Relation()
if self._child is not None:
@@ -715,7 +711,7 @@ def __init__(self, child: Optional["LogicalPlan"], alias: str) -> None:
super().__init__(child)
self._alias = alias

def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
def plan(self, session: "RemoteSparkSession") -> proto.Relation:
rel = proto.Relation()
rel.subquery_alias.alias = self._alias
return rel
@@ -741,7 +737,7 @@ def __init__(self, query: str) -> None:
super().__init__(None)
self._query = query

def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
def plan(self, session: "RemoteSparkSession") -> proto.Relation:
rel = proto.Relation()
rel.sql.query = self._query
return rel
@@ -776,7 +772,7 @@ def __init__(
self._step = step
self._num_partitions = num_partitions

def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
def plan(self, session: "RemoteSparkSession") -> proto.Relation:
rel = proto.Relation()
rel.range.start = self._start
rel.range.end = self._end
7 changes: 3 additions & 4 deletions python/pyspark/sql/tests/connect/test_connect_column_expressions.py
@@ -25,7 +25,6 @@

if have_pandas:
from pyspark.sql.connect.proto import Expression as ProtoExpression
import pyspark.sql.connect as c
import pyspark.sql.connect.plan as p
import pyspark.sql.connect.column as col
import pyspark.sql.connect.functions as fun
@@ -34,7 +33,7 @@
@unittest.skipIf(not have_pandas, cast(str, pandas_requirement_message))
class SparkConnectColumnExpressionSuite(PlanOnlyTestFixture):
def test_simple_column_expressions(self):
df = c.DataFrame.withPlan(p.Read("table"))
df = self.connect.with_plan(p.Read("table"))

c1 = df.col_name
self.assertIsInstance(c1, col.ColumnRef)
@@ -79,7 +78,7 @@ def test_uuid_literal(self):
lit.to_plan(None)

def test_column_literals(self):
df = c.DataFrame.withPlan(p.Read("table"))
df = self.connect.with_plan(p.Read("table"))
lit_df = df.select(fun.lit(10))
self.assertIsNotNone(lit_df._plan.to_proto(None))

@@ -138,7 +137,7 @@ def test_list_to_literal(self):
def test_column_expressions(self):
"""Test a more complex combination of expressions and their translation into
the protobuf structure."""
df = c.DataFrame.withPlan(p.Read("table"))
df = self.connect.with_plan(p.Read("table"))

expr = fun.lit(10) < fun.lit(10)
expr_plan = expr.to_plan(None)
13 changes: 7 additions & 6 deletions python/pyspark/sql/tests/connect/test_connect_select_ops.py
@@ -21,7 +21,6 @@
from pyspark.testing.sqlutils import have_pandas, pandas_requirement_message

if have_pandas:
from pyspark.sql.connect import DataFrame
from pyspark.sql.connect.functions import col
from pyspark.sql.connect.plan import Read
import pyspark.sql.connect.proto as proto
@@ -30,17 +29,17 @@
@unittest.skipIf(not have_pandas, cast(str, pandas_requirement_message))
class SparkConnectToProtoSuite(PlanOnlyTestFixture):
def test_select_with_columns_and_strings(self):
df = DataFrame.withPlan(Read("table"))
self.assertIsNotNone(df.select(col("name"))._plan.to_proto())
df = self.connect.with_plan(Read("table"))
self.assertIsNotNone(df.select(col("name"))._plan.to_proto(self.connect))
self.assertIsNotNone(df.select("name"))
self.assertIsNotNone(df.select("name", "name2"))
self.assertIsNotNone(df.select(col("name"), col("name2")))
self.assertIsNotNone(df.select(col("name"), "name2"))
self.assertIsNotNone(df.select("*"))

def test_join_with_join_type(self):
df_left = DataFrame.withPlan(Read("table"))
df_right = DataFrame.withPlan(Read("table"))
df_left = self.connect.with_plan(Read("table"))
df_right = self.connect.with_plan(Read("table"))
for (join_type_str, join_type) in [
(None, proto.Join.JoinType.JOIN_TYPE_INNER),
("inner", proto.Join.JoinType.JOIN_TYPE_INNER),
@@ -50,7 +49,9 @@ def test_join_with_join_type(self):
("leftanti", proto.Join.JoinType.JOIN_TYPE_LEFT_ANTI),
("leftsemi", proto.Join.JoinType.JOIN_TYPE_LEFT_SEMI),
]:
joined_df = df_left.join(df_right, on=col("name"), how=join_type_str)._plan.to_proto()
joined_df = df_left.join(df_right, on=col("name"), how=join_type_str)._plan.to_proto(
self.connect
)
self.assertEqual(joined_df.root.join.join_type, join_type)


7 changes: 7 additions & 0 deletions python/pyspark/testing/connectutils.py
@@ -24,6 +24,7 @@
from pyspark.sql.connect import DataFrame
from pyspark.sql.connect.plan import Read, Range, SQL
from pyspark.testing.utils import search_jar
from pyspark.sql.connect.plan import LogicalPlan

connect_jar = search_jar("connector/connect", "spark-connect-assembly-", "spark-connect")
else:
@@ -91,6 +92,10 @@ def _session_range(
def _session_sql(cls, query: str) -> "DataFrame":
return DataFrame.withPlan(SQL(query), cls.connect) # type: ignore

@classmethod
def _with_plan(cls, plan: LogicalPlan) -> "DataFrame":
return DataFrame.withPlan(plan, cls.connect) # type: ignore

@classmethod
def setUpClass(cls: Any) -> None:
cls.connect = MockRemoteSession()
@@ -100,10 +105,12 @@ def setUpClass(cls: Any) -> None:
cls.connect.set_hook("readTable", cls._read_table)
cls.connect.set_hook("range", cls._session_range)
cls.connect.set_hook("sql", cls._session_sql)
cls.connect.set_hook("with_plan", cls._with_plan)

@classmethod
def tearDownClass(cls: Any) -> None:
cls.connect.drop_hook("register_udf")
cls.connect.drop_hook("readTable")
cls.connect.drop_hook("range")
cls.connect.drop_hook("sql")
cls.connect.drop_hook("with_plan")
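
For context, the fixture above resolves session methods through named hooks, so `self.connect.with_plan(Read("table"))` in the tests dispatches to `_with_plan`. A minimal sketch of how such a hook-based mock could work (a dict-backed assumption, not necessarily the actual `MockRemoteSession`):

```python
from typing import Any, Callable, Dict


class MockRemoteSession:
    """A stand-in for RemoteSparkSession used in plan-only tests."""

    def __init__(self) -> None:
        self.hooks: Dict[str, Callable] = {}

    def set_hook(self, name: str, hook: Callable) -> None:
        # Register a callable under a name, e.g. "with_plan".
        self.hooks[name] = hook

    def drop_hook(self, name: str) -> None:
        self.hooks.pop(name, None)

    def __getattr__(self, item: str) -> Any:
        # Called only when normal attribute lookup fails; fall back to
        # the registered hooks so mock.with_plan(...) invokes the hook.
        if item in self.hooks:
            return self.hooks[item]
        raise LookupError(f"Hook '{item}' is not registered.")
```

With the `with_plan` hook registered in `setUpClass`, the test suites can build a `DataFrame` against the mock without a live connection.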
