Skip to content

Commit

Permalink
[SPARK-49784][PYTHON][TESTS] Add more test for spark.sql
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
add more test for `spark.sql`

### Why are the changes needed?
for test coverage

### Does this PR introduce _any_ user-facing change?
no, test only

### How was this patch tested?
ci

### Was this patch authored or co-authored using generative AI tooling?
no

Closes apache#48246 from zhengruifeng/py_sql_test.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
  • Loading branch information
zhengruifeng committed Sep 26, 2024
1 parent 5629779 commit 913a0f7
Show file tree
Hide file tree
Showing 3 changed files with 224 additions and 0 deletions.
2 changes: 2 additions & 0 deletions dev/sparktestsupport/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,7 @@ def __hash__(self):
"pyspark.sql.tests.test_errors",
"pyspark.sql.tests.test_functions",
"pyspark.sql.tests.test_group",
"pyspark.sql.tests.test_sql",
"pyspark.sql.tests.pandas.test_pandas_cogrouped_map",
"pyspark.sql.tests.pandas.test_pandas_grouped_map",
"pyspark.sql.tests.pandas.test_pandas_grouped_map_with_state",
Expand Down Expand Up @@ -1032,6 +1033,7 @@ def __hash__(self):
"pyspark.sql.tests.connect.test_parity_serde",
"pyspark.sql.tests.connect.test_parity_functions",
"pyspark.sql.tests.connect.test_parity_group",
"pyspark.sql.tests.connect.test_parity_sql",
"pyspark.sql.tests.connect.test_parity_dataframe",
"pyspark.sql.tests.connect.test_parity_collection",
"pyspark.sql.tests.connect.test_parity_creation",
Expand Down
37 changes: 37 additions & 0 deletions python/pyspark/sql/tests/connect/test_parity_sql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import unittest

from pyspark.sql.tests.test_sql import SQLTestsMixin
from pyspark.testing.connectutils import ReusedConnectTestCase


class SQLParityTests(SQLTestsMixin, ReusedConnectTestCase):
pass


if __name__ == "__main__":
from pyspark.sql.tests.connect.test_parity_sql import * # noqa: F401

try:
import xmlrunner # type: ignore[import]

testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)
185 changes: 185 additions & 0 deletions python/pyspark/sql/tests/test_sql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import unittest

from pyspark.sql import Row
from pyspark.testing.sqlutils import ReusedSQLTestCase


class SQLTestsMixin:
def test_simple(self):
res = self.spark.sql("SELECT 1 + 1").collect()
self.assertEqual(len(res), 1)
self.assertEqual(res[0][0], 2)

def test_args_dict(self):
with self.tempView("test"):
self.spark.range(10).createOrReplaceTempView("test")
df = self.spark.sql(
"SELECT * FROM IDENTIFIER(:table_name)",
args={"table_name": "test"},
)

self.assertEqual(df.count(), 10)
self.assertEqual(df.limit(5).count(), 5)
self.assertEqual(df.offset(5).count(), 5)

self.assertEqual(df.take(1), [Row(id=0)])
self.assertEqual(df.tail(1), [Row(id=9)])

def test_args_list(self):
with self.tempView("test"):
self.spark.range(10).createOrReplaceTempView("test")
df = self.spark.sql(
"SELECT * FROM test WHERE ? < id AND id < ?",
args=[1, 6],
)

self.assertEqual(df.count(), 4)
self.assertEqual(df.limit(3).count(), 3)
self.assertEqual(df.offset(3).count(), 1)

self.assertEqual(df.take(1), [Row(id=2)])
self.assertEqual(df.tail(1), [Row(id=5)])

def test_kwargs_literal(self):
with self.tempView("test"):
self.spark.range(10).createOrReplaceTempView("test")

df = self.spark.sql(
"SELECT * FROM IDENTIFIER(:table_name) WHERE {m1} < id AND id < {m2} OR id = {m3}",
args={"table_name": "test"},
m1=3,
m2=7,
m3=9,
)

self.assertEqual(df.count(), 4)
self.assertEqual(df.collect(), [Row(id=4), Row(id=5), Row(id=6), Row(id=9)])
self.assertEqual(df.take(1), [Row(id=4)])
self.assertEqual(df.tail(1), [Row(id=9)])

def test_kwargs_literal_multiple_ref(self):
with self.tempView("test"):
self.spark.range(10).createOrReplaceTempView("test")

df = self.spark.sql(
"SELECT * FROM IDENTIFIER(:table_name) WHERE {m} = id OR id > {m} OR {m} < 0",
args={"table_name": "test"},
m=6,
)

self.assertEqual(df.count(), 4)
self.assertEqual(df.collect(), [Row(id=6), Row(id=7), Row(id=8), Row(id=9)])
self.assertEqual(df.take(1), [Row(id=6)])
self.assertEqual(df.tail(1), [Row(id=9)])

def test_kwargs_dataframe(self):
df0 = self.spark.range(10)
df1 = self.spark.sql(
"SELECT * FROM {df} WHERE id > 4",
df=df0,
)

self.assertEqual(df0.schema, df1.schema)
self.assertEqual(df1.count(), 5)
self.assertEqual(df1.take(1), [Row(id=5)])
self.assertEqual(df1.tail(1), [Row(id=9)])

def test_kwargs_dataframe_with_column(self):
df0 = self.spark.range(10)
df1 = self.spark.sql(
"SELECT * FROM {df} WHERE {df.id} > :m1 AND {df[id]} < :m2",
{"m1": 4, "m2": 9},
df=df0,
)

self.assertEqual(df0.schema, df1.schema)
self.assertEqual(df1.count(), 4)
self.assertEqual(df1.take(1), [Row(id=5)])
self.assertEqual(df1.tail(1), [Row(id=8)])

def test_nested_view(self):
with self.tempView("v1", "v2", "v3", "v4"):
self.spark.range(10).createOrReplaceTempView("v1")
self.spark.sql(
"SELECT * FROM IDENTIFIER(:view) WHERE id > :m",
args={"view": "v1", "m": 1},
).createOrReplaceTempView("v2")
self.spark.sql(
"SELECT * FROM IDENTIFIER(:view) WHERE id > :m",
args={"view": "v2", "m": 2},
).createOrReplaceTempView("v3")
self.spark.sql(
"SELECT * FROM IDENTIFIER(:view) WHERE id > :m",
args={"view": "v3", "m": 3},
).createOrReplaceTempView("v4")

df = self.spark.sql("select * from v4")
self.assertEqual(df.count(), 6)
self.assertEqual(df.take(1), [Row(id=4)])
self.assertEqual(df.tail(1), [Row(id=9)])

def test_nested_dataframe(self):
df0 = self.spark.range(10)
df1 = self.spark.sql(
"SELECT * FROM {df} WHERE id > ?",
args=[1],
df=df0,
)
df2 = self.spark.sql(
"SELECT * FROM {df} WHERE id > ?",
args=[2],
df=df1,
)
df3 = self.spark.sql(
"SELECT * FROM {df} WHERE id > ?",
args=[3],
df=df2,
)

self.assertEqual(df0.schema, df1.schema)
self.assertEqual(df1.count(), 8)
self.assertEqual(df1.take(1), [Row(id=2)])
self.assertEqual(df1.tail(1), [Row(id=9)])

self.assertEqual(df0.schema, df2.schema)
self.assertEqual(df2.count(), 7)
self.assertEqual(df2.take(1), [Row(id=3)])
self.assertEqual(df2.tail(1), [Row(id=9)])

self.assertEqual(df0.schema, df3.schema)
self.assertEqual(df3.count(), 6)
self.assertEqual(df3.take(1), [Row(id=4)])
self.assertEqual(df3.tail(1), [Row(id=9)])


class SQLTests(SQLTestsMixin, ReusedSQLTestCase):
pass


if __name__ == "__main__":
from pyspark.sql.tests.test_sql import * # noqa: F401

try:
import xmlrunner # type: ignore

testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)

0 comments on commit 913a0f7

Please sign in to comment.