Skip to content

Commit

Permalink
[SPARK-44528][CONNECT] Support proper usage of hasattr() for Connect dataframe
Browse files Browse the repository at this point in the history

### What changes were proposed in this pull request?
Currently Connect does not allow the proper usage of Python's `hasattr()` to identify if an attribute is defined or not. This patch fixes that bug (it's working in regular PySpark).

### Why are the changes needed?
Bugfix

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
UT

Closes apache#42132 from grundprinzip/SPARK-44528.

Lead-authored-by: Martin Grund <martin.grund@databricks.com>
Co-authored-by: Martin Grund <grundprinzip@gmail.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
  • Loading branch information
2 people authored and ragnarok56 committed Mar 2, 2024
1 parent 10223e6 commit 2b284b9
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 7 deletions.
8 changes: 8 additions & 0 deletions python/pyspark/sql/connect/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1584,8 +1584,16 @@ def __getattr__(self, name: str) -> "Column":
error_class="NOT_IMPLEMENTED",
message_parameters={"feature": f"{name}()"},
)

if name not in self.columns:
raise AttributeError(
"'%s' object has no attribute '%s'" % (self.__class__.__name__, name)
)

return self[name]

__getattr__.__doc__ = PySparkDataFrame.__getattr__.__doc__

@overload
def __getitem__(self, item: Union[int, str]) -> Column:
...
Expand Down
17 changes: 15 additions & 2 deletions python/pyspark/sql/tests/connect/test_connect_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,19 @@ def spark_connect_clean_up_test_data(cls):


class SparkConnectBasicTests(SparkConnectSQLTestCase):
def test_df_getattr_behavior(self):
    """SPARK-44528: attribute access and ``hasattr()`` must behave the same
    on a Connect DataFrame as on a classic PySpark DataFrame."""
    connect_df = self.connect.range(10)
    classic_df = self.spark.range(10)

    # Ad-hoc attribute assignment must be supported by both implementations.
    classic_df._simple_extension = 10
    connect_df._simple_extension = 10

    self.assertEqual(classic_df._simple_extension, connect_df._simple_extension)
    self.assertEqual(
        type(classic_df._simple_extension), type(connect_df._simple_extension)
    )

    # hasattr() is True for a set attribute and False for a missing one,
    # i.e. __getattr__ raises AttributeError instead of another error type.
    self.assertTrue(hasattr(connect_df, "_simple_extension"))
    self.assertFalse(hasattr(connect_df, "_simple_extension_does_not_exsit"))

def test_df_get_item(self):
# SPARK-41779: test __getitem__

Expand Down Expand Up @@ -1296,8 +1309,8 @@ def test_drop(self):
sdf.drop("a", "x").toPandas(),
)
self.assert_eq(
cdf.drop(cdf.a, cdf.x).toPandas(),
sdf.drop("a", "x").toPandas(),
cdf.drop(cdf.a, "x").toPandas(),
sdf.drop(sdf.a, "x").toPandas(),
)

def test_subquery_alias(self) -> None:
Expand Down
28 changes: 23 additions & 5 deletions python/pyspark/testing/connectutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#
import shutil
import tempfile
import types
import typing
import os
import functools
Expand Down Expand Up @@ -67,7 +68,7 @@

if should_test_connect:
from pyspark.sql.connect.dataframe import DataFrame
from pyspark.sql.connect.plan import Read, Range, SQL
from pyspark.sql.connect.plan import Read, Range, SQL, LogicalPlan
from pyspark.sql.connect.session import SparkSession


Expand All @@ -88,16 +89,33 @@ def __getattr__(self, item):
return functools.partial(self.hooks[item])


class MockDF(DataFrame):
    """Helper class that must only be used for the mock plan tests.

    It carries an explicit logical plan and pretends that every attribute
    lookup names a column, since no real schema exists in the mock.
    """

    def __init__(self, session: SparkSession, plan: LogicalPlan):
        super().__init__(session)
        self._plan = plan

    def __getattr__(self, name):
        # Only called for attributes Python could not find normally; resolve
        # every such name to a column reference on this DataFrame.
        return self[name]


@unittest.skipIf(not should_test_connect, connect_requirement_message)
class PlanOnlyTestFixture(unittest.TestCase, PySparkErrorTestUtils):
@classmethod
def _read_table(cls, table_name):
    """Return a mock DataFrame backed by a ``Read`` plan for *table_name*.

    The scraped diff retained both the pre-change ``DataFrame.withPlan``
    line and the post-change line; only the post-commit implementation
    (routing through ``cls._df_mock``) is kept.
    """
    return cls._df_mock(Read(table_name))

@classmethod
def _udf_mock(cls, *args, **kwargs):
    """Stand-in for UDF registration; ignores all arguments and always
    yields the same fixed internal name."""
    return "internal_name"

@classmethod
def _df_mock(cls, plan: LogicalPlan) -> MockDF:
    """Wrap *plan* in a ``MockDF`` bound to the class-level connect session."""
    return MockDF(cls.connect, plan)

@classmethod
def _session_range(
cls,
Expand All @@ -106,17 +124,17 @@ def _session_range(
step=1,
num_partitions=None,
):
return DataFrame.withPlan(Range(start, end, step, num_partitions), cls.connect)
return cls._df_mock(Range(start, end, step, num_partitions))

@classmethod
def _session_sql(cls, query):
    """Return a mock DataFrame representing the given SQL *query*.

    The scraped diff retained both the pre-change ``DataFrame.withPlan``
    line and the post-change line; only the post-commit implementation
    (routing through ``cls._df_mock``) is kept.
    """
    return cls._df_mock(SQL(query))

if have_pandas:

    @classmethod
    def _with_plan(cls, plan):
        """Return a mock DataFrame for an arbitrary logical *plan*.

        The scraped diff retained both the pre-change ``DataFrame.withPlan``
        line and the post-change line; only the post-commit implementation
        (routing through ``cls._df_mock``) is kept.
        """
        return cls._df_mock(plan)

@classmethod
def setUpClass(cls):
Expand Down

0 comments on commit 2b284b9

Please sign in to comment.