diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py
index 42466fa169922..3c9f8c3d7364e 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 #
 import uuid
-from typing import cast, get_args, TYPE_CHECKING, Optional, Callable, Any
+from typing import cast, get_args, TYPE_CHECKING, Callable, Any
 
 import decimal
 import datetime
@@ -76,7 +76,7 @@ def __eq__(self, other: Any) -> "Expression":  # type: ignore[override]
     def __init__(self) -> None:
         pass
 
-    def to_plan(self, session: Optional["RemoteSparkSession"]) -> "proto.Expression":
+    def to_plan(self, session: "RemoteSparkSession") -> "proto.Expression":
         ...
 
     def __str__(self) -> str:
@@ -93,7 +93,7 @@ def __init__(self, value: Any) -> None:
         super().__init__()
         self._value = value
 
-    def to_plan(self, session: Optional["RemoteSparkSession"]) -> "proto.Expression":
+    def to_plan(self, session: "RemoteSparkSession") -> "proto.Expression":
         """Converts the literal expression to the literal in proto.
 
         TODO(SPARK-40533) This method always assumes the largest type and can thus
@@ -181,7 +181,7 @@ def name(self) -> str:
         """Returns the qualified name of the column reference."""
         return self._unparsed_identifier
 
-    def to_plan(self, session: Optional["RemoteSparkSession"]) -> proto.Expression:
+    def to_plan(self, session: "RemoteSparkSession") -> proto.Expression:
         """Returns the Proto representation of the expression."""
         expr = proto.Expression()
         expr.unresolved_attribute.unparsed_identifier = self._unparsed_identifier
@@ -207,7 +207,7 @@ def __init__(self, col: ColumnRef, ascending: bool = True, nullsLast: bool = Tru
     def __str__(self) -> str:
         return str(self.ref) + " ASC" if self.ascending else " DESC"
 
-    def to_plan(self, session: Optional["RemoteSparkSession"]) -> proto.Expression:
+    def to_plan(self, session: "RemoteSparkSession") -> proto.Expression:
         return self.ref.to_plan(session)
@@ -221,7 +221,7 @@ def __init__(
         self._args = args
         self._op = op
 
-    def to_plan(self, session: Optional["RemoteSparkSession"]) -> proto.Expression:
+    def to_plan(self, session: "RemoteSparkSession") -> proto.Expression:
         fun = proto.Expression()
         fun.unresolved_function.parts.append(self._op)
         fun.unresolved_function.arguments.extend([x.to_plan(session) for x in self._args])
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index ccd826cd476fd..6423a22c82241 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -101,21 +101,23 @@ class DataFrame(object):
     of the DataFrame with the changes applied.
""" - def __init__(self, data: Optional[List[Any]] = None, schema: Optional[StructType] = None): + def __init__( + self, + session: "RemoteSparkSession", + data: Optional[List[Any]] = None, + schema: Optional[StructType] = None, + ): """Creates a new data frame""" self._schema = schema self._plan: Optional[plan.LogicalPlan] = None self._cache: Dict[str, Any] = {} - self._session: Optional["RemoteSparkSession"] = None + self._session: "RemoteSparkSession" = session @classmethod - def withPlan( - cls, plan: plan.LogicalPlan, session: Optional["RemoteSparkSession"] = None - ) -> "DataFrame": + def withPlan(cls, plan: plan.LogicalPlan, session: "RemoteSparkSession") -> "DataFrame": """Main initialization method used to construct a new data frame with a child plan.""" - new_frame = DataFrame() + new_frame = DataFrame(session=session) new_frame._plan = plan - new_frame._session = session return new_frame def select(self, *cols: ColumnOrName) -> "DataFrame": diff --git a/python/pyspark/sql/connect/function_builder.py b/python/pyspark/sql/connect/function_builder.py index ae5a59457ecbe..9c519312a4f43 100644 --- a/python/pyspark/sql/connect/function_builder.py +++ b/python/pyspark/sql/connect/function_builder.py @@ -87,7 +87,7 @@ def __init__( self._args = [] self._func_name = None - def to_plan(self, session: Optional["RemoteSparkSession"]) -> proto.Expression: + def to_plan(self, session: "RemoteSparkSession") -> proto.Expression: if session is None: raise Exception("CAnnot create UDF without remote Session.") # Needs to materialize the UDF to the server diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py index acc5927b5194d..acca6f96ea862 100644 --- a/python/pyspark/sql/connect/plan.py +++ b/python/pyspark/sql/connect/plan.py @@ -58,7 +58,7 @@ def unresolved_attr(self, colName: str) -> proto.Expression: return exp def to_attr_or_expression( - self, col: "ColumnOrString", session: Optional["RemoteSparkSession"] + self, col: "ColumnOrString", session: "RemoteSparkSession" ) -> proto.Expression: """Returns either an instance of an unresolved attribute or the serialized expression value of the column.""" @@ -67,7 +67,7 @@ def to_attr_or_expression( else: return cast(ColumnRef, col).to_plan(session) - def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation: + def plan(self, session: "RemoteSparkSession") -> proto.Relation: ... def _verify(self, session: "RemoteSparkSession") -> bool: @@ -82,9 +82,7 @@ def _verify(self, session: "RemoteSparkSession") -> bool: return test_plan == plan - def to_proto( - self, session: Optional["RemoteSparkSession"] = None, debug: bool = False - ) -> proto.Plan: + def to_proto(self, session: "RemoteSparkSession", debug: bool = False) -> proto.Plan: """ Generates connect proto plan based on this LogicalPlan. 
@@ -127,7 +125,7 @@ def __init__(
         self.schema = schema
         self.options = options
 
-    def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
+    def plan(self, session: "RemoteSparkSession") -> proto.Relation:
         plan = proto.Relation()
         if self.format is not None:
             plan.read.data_source.format = self.format
@@ -158,7 +156,7 @@ def __init__(self, table_name: str) -> None:
         super().__init__(None)
         self.table_name = table_name
 
-    def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
+    def plan(self, session: "RemoteSparkSession") -> proto.Relation:
         plan = proto.Relation()
         plan.read.named_table.unparsed_identifier = self.table_name
         return plan
@@ -202,7 +200,7 @@ def _verify_expressions(self) -> None:
                     f"Only Expressions or String can be used for projections: '{c}'."
                 )
 
-    def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
+    def plan(self, session: "RemoteSparkSession") -> proto.Relation:
         assert self._child is not None
         proj_exprs = []
         for c in self._raw_columns:
@@ -241,7 +239,7 @@ def __init__(self, child: Optional["LogicalPlan"], filter: Expression) -> None:
         super().__init__(child)
         self.filter = filter
 
-    def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
+    def plan(self, session: "RemoteSparkSession") -> proto.Relation:
         assert self._child is not None
         plan = proto.Relation()
         plan.filter.input.CopyFrom(self._child.plan(session))
@@ -269,7 +267,7 @@ def __init__(self, child: Optional["LogicalPlan"], limit: int) -> None:
         super().__init__(child)
         self.limit = limit
 
-    def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
+    def plan(self, session: "RemoteSparkSession") -> proto.Relation:
         assert self._child is not None
         plan = proto.Relation()
         plan.limit.input.CopyFrom(self._child.plan(session))
@@ -297,7 +295,7 @@ def __init__(self, child: Optional["LogicalPlan"], offset: int = 0) -> None:
         super().__init__(child)
         self.offset = offset
 
-    def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
+    def plan(self, session: "RemoteSparkSession") -> proto.Relation:
         assert self._child is not None
         plan = proto.Relation()
         plan.offset.input.CopyFrom(self._child.plan(session))
@@ -331,7 +329,7 @@ def __init__(
         self.all_columns_as_keys = all_columns_as_keys
         self.column_names = column_names
 
-    def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
+    def plan(self, session: "RemoteSparkSession") -> proto.Relation:
         assert self._child is not None
         plan = proto.Relation()
         plan.deduplicate.all_columns_as_keys = self.all_columns_as_keys
@@ -371,7 +369,7 @@ def __init__(
         self.is_global = is_global
 
     def col_to_sort_field(
-        self, col: Union[SortOrder, ColumnRef, str], session: Optional["RemoteSparkSession"]
+        self, col: Union[SortOrder, ColumnRef, str], session: "RemoteSparkSession"
     ) -> proto.Sort.SortField:
         if isinstance(col, SortOrder):
             sf = proto.Sort.SortField()
@@ -398,7 +396,7 @@ def col_to_sort_field(
             sf.nulls = proto.Sort.SortNulls.SORT_NULLS_LAST
         return sf
 
-    def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
+    def plan(self, session: "RemoteSparkSession") -> proto.Relation:
         assert self._child is not None
         plan = proto.Relation()
         plan.sort.input.CopyFrom(self._child.plan(session))
@@ -438,7 +436,7 @@ def __init__(
         self.with_replacement = with_replacement
         self.seed = seed
 
-    def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
+    def plan(self, session: "RemoteSparkSession") -> proto.Relation:
         assert self._child is not None
         plan = proto.Relation()
         plan.sample.input.CopyFrom(self._child.plan(session))
@@ -488,9 +486,7 @@ def __init__(
         self.grouping_cols = grouping_cols
         self.measures = measures if measures is not None else []
 
-    def _convert_measure(
-        self, m: MeasureType, session: Optional["RemoteSparkSession"]
-    ) -> proto.Expression:
+    def _convert_measure(self, m: MeasureType, session: "RemoteSparkSession") -> proto.Expression:
         exp, fun = m
         proto_expr = proto.Expression()
         measure = proto_expr.unresolved_function
@@ -501,7 +497,7 @@ def _convert_measure(
         measure.arguments.append(cast(Expression, exp).to_plan(session))
         return proto_expr
 
-    def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
+    def plan(self, session: "RemoteSparkSession") -> proto.Relation:
         assert self._child is not None
 
         groupings = [x.to_plan(session) for x in self.grouping_cols]
@@ -571,7 +567,7 @@ def __init__(
         )
         self.how = join_type
 
-    def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
+    def plan(self, session: "RemoteSparkSession") -> proto.Relation:
         rel = proto.Relation()
         rel.join.left.CopyFrom(self.left.plan(session))
         rel.join.right.CopyFrom(self.right.plan(session))
@@ -622,7 +618,7 @@ def __init__(
         self.is_all = is_all
         self.set_op = set_op
 
-    def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
+    def plan(self, session: "RemoteSparkSession") -> proto.Relation:
         assert self._child is not None
         rel = proto.Relation()
         if self._child is not None:
@@ -715,7 +711,7 @@ def __init__(self, child: Optional["LogicalPlan"], alias: str) -> None:
         super().__init__(child)
         self._alias = alias
 
-    def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
+    def plan(self, session: "RemoteSparkSession") -> proto.Relation:
         rel = proto.Relation()
         rel.subquery_alias.alias = self._alias
         return rel
@@ -741,7 +737,7 @@ def __init__(self, query: str) -> None:
         super().__init__(None)
         self._query = query
 
-    def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
+    def plan(self, session: "RemoteSparkSession") -> proto.Relation:
         rel = proto.Relation()
         rel.sql.query = self._query
         return rel
@@ -776,7 +772,7 @@ def __init__(
         self._step = step
         self._num_partitions = num_partitions
 
-    def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
+    def plan(self, session: "RemoteSparkSession") -> proto.Relation:
         rel = proto.Relation()
         rel.range.start = self._start
         rel.range.end = self._end
diff --git a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
index 8773fe4aceba3..ca75b14bb674d 100644
--- a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
+++ b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
@@ -25,7 +25,6 @@
 
 if have_pandas:
     from pyspark.sql.connect.proto import Expression as ProtoExpression
-    import pyspark.sql.connect as c
     import pyspark.sql.connect.plan as p
     import pyspark.sql.connect.column as col
     import pyspark.sql.connect.functions as fun
@@ -34,7 +33,7 @@
 @unittest.skipIf(not have_pandas, cast(str, pandas_requirement_message))
 class SparkConnectColumnExpressionSuite(PlanOnlyTestFixture):
     def test_simple_column_expressions(self):
-        df = c.DataFrame.withPlan(p.Read("table"))
+        df = self.connect.with_plan(p.Read("table"))
         c1 = df.col_name
         self.assertIsInstance(c1, col.ColumnRef)
@@ -79,7 +78,7 @@ def test_uuid_literal(self):
            lit.to_plan(None)
 
     def test_column_literals(self):
-        df = c.DataFrame.withPlan(p.Read("table"))
+        df = self.connect.with_plan(p.Read("table"))
self.connect.with_plan(p.Read("table")) lit_df = df.select(fun.lit(10)) self.assertIsNotNone(lit_df._plan.to_proto(None)) @@ -138,7 +137,7 @@ def test_list_to_literal(self): def test_column_expressions(self): """Test a more complex combination of expressions and their translation into the protobuf structure.""" - df = c.DataFrame.withPlan(p.Read("table")) + df = self.connect.with_plan(p.Read("table")) expr = fun.lit(10) < fun.lit(10) expr_plan = expr.to_plan(None) diff --git a/python/pyspark/sql/tests/connect/test_connect_select_ops.py b/python/pyspark/sql/tests/connect/test_connect_select_ops.py index a29c705414624..f380087bedb8c 100644 --- a/python/pyspark/sql/tests/connect/test_connect_select_ops.py +++ b/python/pyspark/sql/tests/connect/test_connect_select_ops.py @@ -21,7 +21,6 @@ from pyspark.testing.sqlutils import have_pandas, pandas_requirement_message if have_pandas: - from pyspark.sql.connect import DataFrame from pyspark.sql.connect.functions import col from pyspark.sql.connect.plan import Read import pyspark.sql.connect.proto as proto @@ -30,8 +29,8 @@ @unittest.skipIf(not have_pandas, cast(str, pandas_requirement_message)) class SparkConnectToProtoSuite(PlanOnlyTestFixture): def test_select_with_columns_and_strings(self): - df = DataFrame.withPlan(Read("table")) - self.assertIsNotNone(df.select(col("name"))._plan.to_proto()) + df = self.connect.with_plan(Read("table")) + self.assertIsNotNone(df.select(col("name"))._plan.to_proto(self.connect)) self.assertIsNotNone(df.select("name")) self.assertIsNotNone(df.select("name", "name2")) self.assertIsNotNone(df.select(col("name"), col("name2"))) @@ -39,8 +38,8 @@ def test_select_with_columns_and_strings(self): self.assertIsNotNone(df.select("*")) def test_join_with_join_type(self): - df_left = DataFrame.withPlan(Read("table")) - df_right = DataFrame.withPlan(Read("table")) + df_left = self.connect.with_plan(Read("table")) + df_right = self.connect.with_plan(Read("table")) for (join_type_str, join_type) in [ (None, proto.Join.JoinType.JOIN_TYPE_INNER), ("inner", proto.Join.JoinType.JOIN_TYPE_INNER), @@ -50,7 +49,9 @@ def test_join_with_join_type(self): ("leftanti", proto.Join.JoinType.JOIN_TYPE_LEFT_ANTI), ("leftsemi", proto.Join.JoinType.JOIN_TYPE_LEFT_SEMI), ]: - joined_df = df_left.join(df_right, on=col("name"), how=join_type_str)._plan.to_proto() + joined_df = df_left.join(df_right, on=col("name"), how=join_type_str)._plan.to_proto( + self.connect + ) self.assertEqual(joined_df.root.join.join_type, join_type) diff --git a/python/pyspark/testing/connectutils.py b/python/pyspark/testing/connectutils.py index fed096f45973c..b7c49a6df545e 100644 --- a/python/pyspark/testing/connectutils.py +++ b/python/pyspark/testing/connectutils.py @@ -24,6 +24,7 @@ from pyspark.sql.connect import DataFrame from pyspark.sql.connect.plan import Read, Range, SQL from pyspark.testing.utils import search_jar + from pyspark.sql.connect.plan import LogicalPlan connect_jar = search_jar("connector/connect", "spark-connect-assembly-", "spark-connect") else: @@ -91,6 +92,10 @@ def _session_range( def _session_sql(cls, query: str) -> "DataFrame": return DataFrame.withPlan(SQL(query), cls.connect) # type: ignore + @classmethod + def _with_plan(cls, plan: LogicalPlan) -> "DataFrame": + return DataFrame.withPlan(plan, cls.connect) # type: ignore + @classmethod def setUpClass(cls: Any) -> None: cls.connect = MockRemoteSession() @@ -100,6 +105,7 @@ def setUpClass(cls: Any) -> None: cls.connect.set_hook("readTable", cls._read_table) cls.connect.set_hook("range", 
         cls.connect.set_hook("sql", cls._session_sql)
+        cls.connect.set_hook("with_plan", cls._with_plan)
 
     @classmethod
     def tearDownClass(cls: Any) -> None:
@@ -107,3 +113,4 @@ def tearDownClass(cls: Any) -> None:
         cls.connect.drop_hook("readTable")
         cls.connect.drop_hook("range")
         cls.connect.drop_hook("sql")
+        cls.connect.drop_hook("with_plan")
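
Note: after this change, a DataFrame can no longer be constructed without a session, and LogicalPlan.to_proto no longer defaults its session argument to None. A minimal sketch of the new calling convention follows; StubSession is a hypothetical stand-in for a RemoteSparkSession, used only for illustration (the PR's own tests use the MockRemoteSession fixture instead):

# Hedged sketch of the post-change API, not an excerpt from this PR.
from pyspark.sql.connect.dataframe import DataFrame
from pyspark.sql.connect.plan import Read


class StubSession:
    """Hypothetical placeholder satisfying the session parameter at runtime."""


session = StubSession()

# Previously: DataFrame.withPlan(Read("table")) left the session as None.
# Now the session must be threaded through construction and proto generation.
df = DataFrame.withPlan(Read("table"), session)
plan = df._plan.to_proto(session)  # session is now a required argument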
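
The test-side changes route DataFrame construction through a new "with_plan" hook on the mocked session, mirroring the existing "sql" and "range" hooks. A rough sketch of that dispatch, assuming MockRemoteSession resolves unknown attributes against its registered hooks (this implementation is illustrative, not copied from pyspark):

import functools
from typing import Any, Dict


class MockRemoteSession:
    """Illustrative mock: attribute access falls through to registered hooks."""

    def __init__(self) -> None:
        self.hooks: Dict[str, Any] = {}

    def set_hook(self, name: str, hook: Any) -> None:
        self.hooks[name] = hook

    def drop_hook(self, name: str) -> None:
        self.hooks.pop(name)

    def __getattr__(self, item: str) -> Any:
        # __getattr__ only fires for attributes not found normally, so
        # self.hooks (set in __init__) does not recurse through here.
        if item not in self.hooks:
            raise LookupError(f"{item} is not defined as a method hook")
        return functools.partial(self.hooks[item])


# Usage mirroring the fixture: the hook closes over the session, so a test
# written as self.connect.with_plan(Read("table")) never handles the session.
connect = MockRemoteSession()
connect.set_hook("with_plan", lambda plan: f"DataFrame(plan={plan!r})")
assert connect.with_plan("Read(table)") == "DataFrame(plan='Read(table)')"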