diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 6429645f0e027..12e424b5ef137 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -1584,8 +1584,16 @@ def __getattr__(self, name: str) -> "Column":
                 error_class="NOT_IMPLEMENTED",
                 message_parameters={"feature": f"{name}()"},
             )
+
+        if name not in self.columns:
+            raise AttributeError(
+                "'%s' object has no attribute '%s'" % (self.__class__.__name__, name)
+            )
+
         return self[name]
 
+    __getattr__.__doc__ = PySparkDataFrame.__getattr__.__doc__
+
     @overload
     def __getitem__(self, item: Union[int, str]) -> Column:
         ...
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 5259ea6b5f520..065f1585a9f06 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -157,6 +157,19 @@ def spark_connect_clean_up_test_data(cls):
 
 
 class SparkConnectBasicTests(SparkConnectSQLTestCase):
+    def test_df_getattr_behavior(self):
+        cdf = self.connect.range(10)
+        sdf = self.spark.range(10)
+
+        sdf._simple_extension = 10
+        cdf._simple_extension = 10
+
+        self.assertEqual(sdf._simple_extension, cdf._simple_extension)
+        self.assertEqual(type(sdf._simple_extension), type(cdf._simple_extension))
+
+        self.assertTrue(hasattr(cdf, "_simple_extension"))
+        self.assertFalse(hasattr(cdf, "_simple_extension_does_not_exist"))
+
     def test_df_get_item(self):
         # SPARK-41779: test __getitem__
 
@@ -1296,8 +1309,8 @@ def test_drop(self):
             sdf.drop("a", "x").toPandas(),
         )
         self.assert_eq(
-            cdf.drop(cdf.a, cdf.x).toPandas(),
-            sdf.drop("a", "x").toPandas(),
+            cdf.drop(cdf.a, "x").toPandas(),
+            sdf.drop(sdf.a, "x").toPandas(),
         )
 
     def test_subquery_alias(self) -> None:
diff --git a/python/pyspark/testing/connectutils.py b/python/pyspark/testing/connectutils.py
index 1b3ac10fce874..b6145d0a00618 100644
--- a/python/pyspark/testing/connectutils.py
+++ b/python/pyspark/testing/connectutils.py
@@ -16,6 +16,7 @@
 #
 import shutil
 import tempfile
+import types
 import typing
 import os
 import functools
@@ -67,7 +68,7 @@
 
 if should_test_connect:
     from pyspark.sql.connect.dataframe import DataFrame
-    from pyspark.sql.connect.plan import Read, Range, SQL
+    from pyspark.sql.connect.plan import Read, Range, SQL, LogicalPlan
     from pyspark.sql.connect.session import SparkSession
 
 
@@ -88,16 +89,33 @@ def __getattr__(self, item):
         return functools.partial(self.hooks[item])
 
 
+class MockDF(DataFrame):
+    """Helper class that must only be used for the mock plan tests."""
+
+    def __init__(self, session: SparkSession, plan: LogicalPlan):
+        super().__init__(session)
+        self._plan = plan
+
+    def __getattr__(self, name):
+        """All attributes are resolved to columns, because none really exist in the
+        mocked DataFrame."""
+        return self[name]
+
+
 @unittest.skipIf(not should_test_connect, connect_requirement_message)
 class PlanOnlyTestFixture(unittest.TestCase, PySparkErrorTestUtils):
     @classmethod
     def _read_table(cls, table_name):
-        return DataFrame.withPlan(Read(table_name), cls.connect)
+        return cls._df_mock(Read(table_name))
 
     @classmethod
     def _udf_mock(cls, *args, **kwargs):
         return "internal_name"
 
+    @classmethod
+    def _df_mock(cls, plan: LogicalPlan) -> MockDF:
+        return MockDF(cls.connect, plan)
+
     @classmethod
     def _session_range(
         cls,
@@ -106,17 +124,17 @@ def _session_range(
         step=1,
         num_partitions=None,
     ):
-        return DataFrame.withPlan(Range(start, end, step, num_partitions), cls.connect)
+        return cls._df_mock(Range(start, end, step, num_partitions))
 
     @classmethod
     def _session_sql(cls, query):
-        return DataFrame.withPlan(SQL(query), cls.connect)
+        return cls._df_mock(SQL(query))
 
     if have_pandas:
 
         @classmethod
         def _with_plan(cls, plan):
-            return DataFrame.withPlan(plan, cls.connect)
+            return cls._df_mock(plan)
 
     @classmethod
     def setUpClass(cls):
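
The user-visible effect of the dataframe.py change: unknown attribute names on a Spark Connect DataFrame now raise AttributeError instead of being silently resolved to (possibly invalid) columns, so hasattr() behaves the same as on the classic DataFrame. A minimal sketch of the expected behavior, assuming a reachable Spark Connect endpoint (the remote URL below is illustrative):

    from pyspark.sql import SparkSession

    # Illustrative endpoint; any reachable Spark Connect server works.
    spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

    df = spark.range(10)  # a single column, named "id"

    df.id                 # existing column: still resolves to a Column
    hasattr(df, "id")     # True
    hasattr(df, "bogus")  # False -- previously True, because __getattr__
                          # resolved every unknown name to a column

    try:
        df.bogus
    except AttributeError as e:
        print(e)  # 'DataFrame' object has no attribute 'bogus'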
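
On the testing side, the stricter lookup would break PlanOnlyTestFixture: its mocked DataFrames carry only a plan and no real schema, so every column access would now fail the `name not in self.columns` check (which requires a server to resolve the schema). MockDF opts the plan-only mocks back into the permissive resolve-anything behavior. A hypothetical plan-only test sketching how the fixture is used after this change (the test class, table, and column names are invented for illustration):

    class ExamplePlanOnlyTest(PlanOnlyTestFixture):
        def test_columns_resolve_on_mock(self):
            # _read_table now returns a MockDF wrapping a Read plan.
            df = self._read_table("my_table")
            # No schema exists, yet attribute access still yields columns,
            # so expression-building code can be exercised without a server.
            df2 = df.select(df.col_a, df.col_b)
            self.assertIsNotNone(df2._plan)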