[SPARK-48533][CONNECT][PYTHON][TESTS] Add test for cached schema
### What changes were proposed in this pull request?
Add tests for the cached schema. To make Spark Classic's `mapInXXX` also work within `SparkConnectSQLTestCase`, this also adds a new `contextmanager` for temporarily overriding `os.environ`.
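
Roughly, the new helper is meant to be used like the standalone sketch below. This is illustrative only, assuming the dict-of-overrides API added in `python/pyspark/testing/sqlutils.py` (a value of `None` unsets the variable, and the previous environment is restored on exit):

```python
import os
from contextlib import contextmanager

@contextmanager
def temp_env(pairs):
    # Remember the current values so they can be restored afterwards.
    old_values = {key: os.environ.get(key) for key in pairs}
    try:
        for key, value in pairs.items():
            if value is None:
                # None means "make sure the variable is unset inside the block".
                os.environ.pop(key, None)
            else:
                os.environ[key] = value
        yield
    finally:
        # Restore the original environment, deleting keys that did not exist before.
        for key, value in old_values.items():
            if value is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = value

# Example: force Spark Connect mode on only inside the block.
with temp_env({"SPARK_CONNECT_MODE_ENABLED": "1"}):
    assert os.environ["SPARK_CONNECT_MODE_ENABLED"] == "1"
```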

### Why are the changes needed?
test coverage

### Does this PR introduce _any_ user-facing change?
no, test only

### How was this patch tested?
CI

### Was this patch authored or co-authored using generative AI tooling?
no

Closes apache#46871 from zhengruifeng/test_cached_schema.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
zhengruifeng authored and HyukjinKwon committed Jun 5, 2024
1 parent 4075ce6 commit adbfd17
Showing 3 changed files with 165 additions and 0 deletions.
1 change: 1 addition & 0 deletions dev/sparktestsupport/modules.py
@@ -1009,6 +1009,7 @@ def __hash__(self):
        # sql unittests
        "pyspark.sql.tests.connect.test_connect_plan",
        "pyspark.sql.tests.connect.test_connect_basic",
        "pyspark.sql.tests.connect.test_connect_dataframe_property",
        "pyspark.sql.tests.connect.test_connect_error",
        "pyspark.sql.tests.connect.test_connect_function",
        "pyspark.sql.tests.connect.test_connect_collection",
141 changes: 141 additions & 0 deletions python/pyspark/sql/tests/connect/test_connect_dataframe_property.py
@@ -0,0 +1,141 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import unittest

from pyspark.sql.types import StructType, StructField, StringType, IntegerType
from pyspark.sql.utils import is_remote

from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase
from pyspark.testing.sqlutils import (
    have_pandas,
    have_pyarrow,
    pandas_requirement_message,
    pyarrow_requirement_message,
)

if have_pyarrow:
    import pyarrow as pa

if have_pandas:
    import pandas as pd


class SparkConnectDataFramePropertyTests(SparkConnectSQLTestCase):
    def test_cached_schema_to(self):
        cdf = self.connect.read.table(self.tbl_name)
        sdf = self.spark.read.table(self.tbl_name)

        schema = StructType(
            [
                StructField("id", IntegerType(), True),
                StructField("name", StringType(), True),
            ]
        )

        cdf1 = cdf.to(schema)
        self.assertEqual(cdf1._cached_schema, schema)

        sdf1 = sdf.to(schema)

        self.assertEqual(cdf1.schema, sdf1.schema)
        self.assertEqual(cdf1.collect(), sdf1.collect())

    @unittest.skipIf(
        not have_pandas or not have_pyarrow,
        pandas_requirement_message or pyarrow_requirement_message,
    )
    def test_cached_schema_map_in_pandas(self):
        data = [(1, "foo"), (2, None), (3, "bar"), (4, "bar")]
        cdf = self.connect.createDataFrame(data, "a int, b string")
        sdf = self.spark.createDataFrame(data, "a int, b string")

        def func(iterator):
            for pdf in iterator:
                assert isinstance(pdf, pd.DataFrame)
                assert [d.name for d in list(pdf.dtypes)] == ["int32", "object"]
                yield pdf

        schema = StructType(
            [
                StructField("a", IntegerType(), True),
                StructField("b", StringType(), True),
            ]
        )

        with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": "1"}):
            self.assertTrue(is_remote())
            cdf1 = cdf.mapInPandas(func, schema)
            self.assertEqual(cdf1._cached_schema, schema)

        with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": None}):
            # 'mapInPandas' depends on the method 'pandas_udf', which is dispatched
            # based on 'is_remote'. However, in SparkConnectSQLTestCase, the remote
            # mode is always on, so 'sdf.mapInPandas' fails with incorrect dispatch.
            # Using this temp env to properly invoke mapInPandas in PySpark Classic.
            self.assertFalse(is_remote())
            sdf1 = sdf.mapInPandas(func, schema)

        self.assertEqual(cdf1.schema, sdf1.schema)
        self.assertEqual(cdf1.collect(), sdf1.collect())

    @unittest.skipIf(
        not have_pandas or not have_pyarrow,
        pandas_requirement_message or pyarrow_requirement_message,
    )
    def test_cached_schema_map_in_arrow(self):
        data = [(1, "foo"), (2, None), (3, "bar"), (4, "bar")]
        cdf = self.connect.createDataFrame(data, "a int, b string")
        sdf = self.spark.createDataFrame(data, "a int, b string")

        def func(iterator):
            for batch in iterator:
                assert isinstance(batch, pa.RecordBatch)
                assert batch.schema.types == [pa.int32(), pa.string()]
                yield batch

        schema = StructType(
            [
                StructField("a", IntegerType(), True),
                StructField("b", StringType(), True),
            ]
        )

        with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": "1"}):
            self.assertTrue(is_remote())
            cdf1 = cdf.mapInArrow(func, schema)
            self.assertEqual(cdf1._cached_schema, schema)

        with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": None}):
            self.assertFalse(is_remote())
            sdf1 = sdf.mapInArrow(func, schema)

        self.assertEqual(cdf1.schema, sdf1.schema)
        self.assertEqual(cdf1.collect(), sdf1.collect())


if __name__ == "__main__":
    from pyspark.sql.tests.connect.test_connect_dataframe_property import *  # noqa: F401

    try:
        import xmlrunner

        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
    except ImportError:
        testRunner = None

    unittest.main(testRunner=testRunner, verbosity=2)
23 changes: 23 additions & 0 deletions python/pyspark/testing/sqlutils.py
Expand Up @@ -247,6 +247,29 @@ def function(self, *functions):
            for f in functions:
                self.spark.sql("DROP FUNCTION IF EXISTS %s" % f)

    @contextmanager
    def temp_env(self, pairs):
        assert isinstance(pairs, dict), "pairs should be a dictionary."

        keys = pairs.keys()
        new_values = pairs.values()
        old_values = [os.environ.get(key, None) for key in keys]
        for key, new_value in zip(keys, new_values):
            if new_value is None:
                if key in os.environ:
                    del os.environ[key]
            else:
                os.environ[key] = new_value
        try:
            yield
        finally:
            for key, old_value in zip(keys, old_values):
                if old_value is None:
                    if key in os.environ:
                        del os.environ[key]
                else:
                    os.environ[key] = old_value

    @staticmethod
    def assert_close(a, b):
        c = [j[0] for j in b]
