From 913a0f7813c5b2d2bf105160bf8e55e08b34513b Mon Sep 17 00:00:00 2001
From: Ruifeng Zheng
Date: Thu, 26 Sep 2024 15:15:37 +0800
Subject: [PATCH] [SPARK-49784][PYTHON][TESTS] Add more test for `spark.sql`

### What changes were proposed in this pull request?
add more test for `spark.sql`

### Why are the changes needed?
for test coverage

### Does this PR introduce _any_ user-facing change?
no, test only

### How was this patch tested?
ci

### Was this patch authored or co-authored using generative AI tooling?
no

Closes #48246 from zhengruifeng/py_sql_test.

Authored-by: Ruifeng Zheng
Signed-off-by: Ruifeng Zheng
---
 dev/sparktestsupport/modules.py               |   2 +
 .../sql/tests/connect/test_parity_sql.py      |  37 ++++
 python/pyspark/sql/tests/test_sql.py          | 185 ++++++++++++++++++
 3 files changed, 224 insertions(+)
 create mode 100644 python/pyspark/sql/tests/connect/test_parity_sql.py
 create mode 100644 python/pyspark/sql/tests/test_sql.py

diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index eda6b063350e5..d2c000b702a64 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -520,6 +520,7 @@ def __hash__(self):
         "pyspark.sql.tests.test_errors",
         "pyspark.sql.tests.test_functions",
         "pyspark.sql.tests.test_group",
+        "pyspark.sql.tests.test_sql",
         "pyspark.sql.tests.pandas.test_pandas_cogrouped_map",
         "pyspark.sql.tests.pandas.test_pandas_grouped_map",
         "pyspark.sql.tests.pandas.test_pandas_grouped_map_with_state",
@@ -1032,6 +1033,7 @@ def __hash__(self):
         "pyspark.sql.tests.connect.test_parity_serde",
         "pyspark.sql.tests.connect.test_parity_functions",
         "pyspark.sql.tests.connect.test_parity_group",
+        "pyspark.sql.tests.connect.test_parity_sql",
         "pyspark.sql.tests.connect.test_parity_dataframe",
         "pyspark.sql.tests.connect.test_parity_collection",
         "pyspark.sql.tests.connect.test_parity_creation",
diff --git a/python/pyspark/sql/tests/connect/test_parity_sql.py b/python/pyspark/sql/tests/connect/test_parity_sql.py
new file mode 100644
index 0000000000000..4c6b11c60cbe9
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/test_parity_sql.py
@@ -0,0 +1,37 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from pyspark.sql.tests.test_sql import SQLTestsMixin
+from pyspark.testing.connectutils import ReusedConnectTestCase
+
+
+class SQLParityTests(SQLTestsMixin, ReusedConnectTestCase):
+    pass
+
+
+if __name__ == "__main__":
+    from pyspark.sql.tests.connect.test_parity_sql import *  # noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/test_sql.py b/python/pyspark/sql/tests/test_sql.py
new file mode 100644
index 0000000000000..bf50bbc11ac33
--- /dev/null
+++ b/python/pyspark/sql/tests/test_sql.py
@@ -0,0 +1,185 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from pyspark.sql import Row
+from pyspark.testing.sqlutils import ReusedSQLTestCase
+
+
+class SQLTestsMixin:
+    def test_simple(self):
+        res = self.spark.sql("SELECT 1 + 1").collect()
+        self.assertEqual(len(res), 1)
+        self.assertEqual(res[0][0], 2)
+
+    def test_args_dict(self):
+        with self.tempView("test"):
+            self.spark.range(10).createOrReplaceTempView("test")
+            df = self.spark.sql(
+                "SELECT * FROM IDENTIFIER(:table_name)",
+                args={"table_name": "test"},
+            )
+
+            self.assertEqual(df.count(), 10)
+            self.assertEqual(df.limit(5).count(), 5)
+            self.assertEqual(df.offset(5).count(), 5)
+
+            self.assertEqual(df.take(1), [Row(id=0)])
+            self.assertEqual(df.tail(1), [Row(id=9)])
+
+    def test_args_list(self):
+        with self.tempView("test"):
+            self.spark.range(10).createOrReplaceTempView("test")
+            df = self.spark.sql(
+                "SELECT * FROM test WHERE ? < id AND id < ?",
+                args=[1, 6],
+            )
+
+            self.assertEqual(df.count(), 4)
+            self.assertEqual(df.limit(3).count(), 3)
+            self.assertEqual(df.offset(3).count(), 1)
+
+            self.assertEqual(df.take(1), [Row(id=2)])
+            self.assertEqual(df.tail(1), [Row(id=5)])
+
+    def test_kwargs_literal(self):
+        with self.tempView("test"):
+            self.spark.range(10).createOrReplaceTempView("test")
+
+            df = self.spark.sql(
+                "SELECT * FROM IDENTIFIER(:table_name) WHERE {m1} < id AND id < {m2} OR id = {m3}",
+                args={"table_name": "test"},
+                m1=3,
+                m2=7,
+                m3=9,
+            )
+
+            self.assertEqual(df.count(), 4)
+            self.assertEqual(df.collect(), [Row(id=4), Row(id=5), Row(id=6), Row(id=9)])
+            self.assertEqual(df.take(1), [Row(id=4)])
+            self.assertEqual(df.tail(1), [Row(id=9)])
+
+    def test_kwargs_literal_multiple_ref(self):
+        with self.tempView("test"):
+            self.spark.range(10).createOrReplaceTempView("test")
+
+            df = self.spark.sql(
+                "SELECT * FROM IDENTIFIER(:table_name) WHERE {m} = id OR id > {m} OR {m} < 0",
+                args={"table_name": "test"},
+                m=6,
+            )
+
+            self.assertEqual(df.count(), 4)
+            self.assertEqual(df.collect(), [Row(id=6), Row(id=7), Row(id=8), Row(id=9)])
+            self.assertEqual(df.take(1), [Row(id=6)])
+            self.assertEqual(df.tail(1), [Row(id=9)])
+
+    def test_kwargs_dataframe(self):
+        df0 = self.spark.range(10)
+        df1 = self.spark.sql(
+            "SELECT * FROM {df} WHERE id > 4",
+            df=df0,
+        )
+
+        self.assertEqual(df0.schema, df1.schema)
+        self.assertEqual(df1.count(), 5)
+        self.assertEqual(df1.take(1), [Row(id=5)])
+        self.assertEqual(df1.tail(1), [Row(id=9)])
+
+    def test_kwargs_dataframe_with_column(self):
+        df0 = self.spark.range(10)
+        df1 = self.spark.sql(
+            "SELECT * FROM {df} WHERE {df.id} > :m1 AND {df[id]} < :m2",
+            {"m1": 4, "m2": 9},
+            df=df0,
+        )
+
+        self.assertEqual(df0.schema, df1.schema)
+        self.assertEqual(df1.count(), 4)
+        self.assertEqual(df1.take(1), [Row(id=5)])
+        self.assertEqual(df1.tail(1), [Row(id=8)])
+
+    def test_nested_view(self):
+        with self.tempView("v1", "v2", "v3", "v4"):
+            self.spark.range(10).createOrReplaceTempView("v1")
+            self.spark.sql(
+                "SELECT * FROM IDENTIFIER(:view) WHERE id > :m",
+                args={"view": "v1", "m": 1},
+            ).createOrReplaceTempView("v2")
+            self.spark.sql(
+                "SELECT * FROM IDENTIFIER(:view) WHERE id > :m",
+                args={"view": "v2", "m": 2},
+            ).createOrReplaceTempView("v3")
+            self.spark.sql(
+                "SELECT * FROM IDENTIFIER(:view) WHERE id > :m",
+                args={"view": "v3", "m": 3},
+            ).createOrReplaceTempView("v4")
+
+            df = self.spark.sql("select * from v4")
+            self.assertEqual(df.count(), 6)
+            self.assertEqual(df.take(1), [Row(id=4)])
+            self.assertEqual(df.tail(1), [Row(id=9)])
+
+    def test_nested_dataframe(self):
+        df0 = self.spark.range(10)
+        df1 = self.spark.sql(
+            "SELECT * FROM {df} WHERE id > ?",
+            args=[1],
+            df=df0,
+        )
+        df2 = self.spark.sql(
+            "SELECT * FROM {df} WHERE id > ?",
+            args=[2],
+            df=df1,
+        )
+        df3 = self.spark.sql(
+            "SELECT * FROM {df} WHERE id > ?",
+            args=[3],
+            df=df2,
+        )
+
+        self.assertEqual(df0.schema, df1.schema)
+        self.assertEqual(df1.count(), 8)
+        self.assertEqual(df1.take(1), [Row(id=2)])
+        self.assertEqual(df1.tail(1), [Row(id=9)])
+
+        self.assertEqual(df0.schema, df2.schema)
+        self.assertEqual(df2.count(), 7)
+        self.assertEqual(df2.take(1), [Row(id=3)])
+        self.assertEqual(df2.tail(1), [Row(id=9)])
+
+        self.assertEqual(df0.schema, df3.schema)
+        self.assertEqual(df3.count(), 6)
+        self.assertEqual(df3.take(1), [Row(id=4)])
+        self.assertEqual(df3.tail(1), [Row(id=9)])
+
+
+class SQLTests(SQLTestsMixin, ReusedSQLTestCase):
+    pass
+
+
+if __name__ == "__main__":
+    from pyspark.sql.tests.test_sql import *  # noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)