
Commit 6cff481 (parent: d60b704)

add more tests and don't run filter pushdown if no filter is supported

2 files changed (+62 -6)

python/pyspark/sql/tests/test_python_datasource.py (53 additions, 4 deletions)
@@ -254,8 +254,8 @@ def __init__(self):
                 self.has_filter = False

             def pushdownFilters(self, filters: List[Filter]) -> Iterable[Filter]:
-                assert len(filters) == 2
-                assert set(filters) == {EqualTo(("x",), 1), EqualTo(("y",), 2)}
+                assert len(filters) == 2, filters
+                assert set(filters) == {EqualTo(("x",), 1), EqualTo(("y",), 2)}, filters
                 self.has_filter = True
                 # pretend we support x = 1 filter but in fact we don't
                 # so we only return y = 2 filter
@@ -293,10 +293,10 @@ def pushdownFilters(self, filters: List[Filter]) -> Iterable[Filter]:
                 yield EqualTo(("",), 1)

             def partitions(self):
-                ...
+                assert False

             def read(self, partition):
-                ...
+                assert False

         class TestDataSource(DataSource):
             @classmethod
@@ -313,6 +313,55 @@ def reader(self, schema) -> "DataSourceReader":
         with self.assertRaisesRegex(Exception, "DATA_SOURCE_EXTRANEOUS_FILTERS"):
             self.spark.read.format("test").load().filter("x = 1").show()

+    def test_filter_pushdown_error(self):
+        class TestDataSourceReader(DataSourceReader):
+            def pushdownFilters(self, filters: List[Filter]) -> Iterable[Filter]:
+                raise Exception("dummy error")
+
+            def read(self, partition):
+                yield [1]
+
+        class TestDataSource(DataSource):
+            @classmethod
+            def name(cls):
+                return "test"
+
+            def schema(self):
+                return "x int"
+
+            def reader(self, schema) -> "DataSourceReader":
+                return TestDataSourceReader()
+
+        self.spark.dataSource.register(TestDataSource)
+        df = self.spark.read.format("test").load().filter("cos(x) > 0")
+        assertDataFrameEqual(df, [Row(x=1)])  # works when not pushing down filters
+        with self.assertRaisesRegex(Exception, "dummy error"):
+            df.filter("x = 1").show()
+
+    def test_unsupported_filter(self):
+        class TestDataSourceReader(DataSourceReader):
+            def pushdownFilters(self, filters: List[Filter]) -> Iterable[Filter]:
+                assert filters == [EqualTo(("x",), 1)], filters
+                return filters
+
+            def read(self, partition):
+                yield [1, 2, 3]
+
+        class TestDataSource(DataSource):
+            @classmethod
+            def name(cls):
+                return "test"
+
+            def schema(self):
+                return "x int, y int, z int"
+
+            def reader(self, schema) -> "DataSourceReader":
+                return TestDataSourceReader()
+
+        self.spark.dataSource.register(TestDataSource)
+        df = self.spark.read.format("test").load().filter("x = 1 and y = z")
+        assertDataFrameEqual(df, [])
+
     def _get_test_json_data_source(self):
         import json
         import os
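
For readers unfamiliar with the API under test, here is a minimal sketch of the contract these tests exercise, written against the API as it appears in this commit (pushdownFilters on DataSourceReader, with Filter and EqualTo from pyspark.sql.datasource). The DemoReader and DemoSource names and the sample rows are illustrative, not part of the change:

from typing import Iterable, List

from pyspark.sql.datasource import DataSource, DataSourceReader, EqualTo, Filter


class DemoReader(DataSourceReader):
    def __init__(self):
        self.pushed = []  # filters we promise to apply ourselves in read()

    def pushdownFilters(self, filters: List[Filter]) -> Iterable[Filter]:
        # Spark passes only the predicates it could translate into Filter
        # objects. Keep the ones we can handle; yield back the rest so
        # Spark re-applies them after the scan.
        for f in filters:
            if isinstance(f, EqualTo):
                self.pushed.append(f)
            else:
                yield f  # unsupported: Spark must evaluate this one

    def read(self, partition):
        rows = [{"x": 1, "y": 10}, {"x": 2, "y": 20}]
        for row in rows:
            # Honor every filter kept in pushdownFilters; f.attribute is a
            # tuple of name parts, e.g. ("x",).
            if all(row[f.attribute[0]] == f.value for f in self.pushed):
                yield row["x"], row["y"]


class DemoSource(DataSource):
    @classmethod
    def name(cls):
        return "demo"

    def schema(self):
        return "x int, y int"

    def reader(self, schema) -> "DataSourceReader":
        return DemoReader()

After spark.dataSource.register(DemoSource), a query such as spark.read.format("demo").load().filter("x = 1 and cos(y) > 0") hands only EqualTo(("x",), 1) to pushdownFilters, since cos(y) > 0 has no Filter translation. As the tests above show, filters yielded back (or never translated) are still applied by Spark, so results stay correct either way.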

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala (9 additions, 2 deletions)
@@ -80,10 +80,15 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
   def pushdownFiltersInPython(
       pythonResult: PythonDataSourceReader,
       filters: Array[Filter]): PythonFilterPushdownResult = {
-    new UserDefinedPythonDataSourceFilterPushdownRunner(
+    val runner = new UserDefinedPythonDataSourceFilterPushdownRunner(
       createPythonFunction(pythonResult.reader),
       filters
-    ).runInPython()
+    )
+    if (runner.isAnyFilterSupported) {
+      runner.runInPython()
+    } else {
+      PythonFilterPushdownResult(pythonResult, filters.map(_ => false))
+    }
   }

   /**
@@ -409,6 +414,8 @@ private class UserDefinedPythonDataSourceFilterPushdownRunner(
   // See the logic in `pyspark.sql.worker.data_source_pushdown_filters.py`.
   override val workerModule = "pyspark.sql.worker.data_source_pushdown_filters"

+  def isAnyFilterSupported: Boolean = !serializedFilters.isEmpty
+
   override protected def writeToPython(dataOut: DataOutputStream, pickler: Pickler): Unit = {
     // Send Python data source
     PythonWorkerUtils.writePythonFunction(reader, dataOut)
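
A note on the guard: isAnyFilterSupported is false exactly when none of the query's predicates could be serialized into supported Filter objects, and in that case the JVM now returns an all-false pushdown result directly instead of starting a Python worker round-trip. The sketch below (the FragileReader name and "fragile" source are illustrative) shows the observable consequence, mirroring test_filter_pushdown_error above:

from typing import Iterable, List

from pyspark.sql.datasource import DataSource, DataSourceReader, Filter


class FragileReader(DataSourceReader):
    def pushdownFilters(self, filters: List[Filter]) -> Iterable[Filter]:
        # With this commit, never reached for queries whose predicates
        # have no Filter translation (e.g. cos(x) > 0).
        raise RuntimeError("pushdown should have been skipped")

    def read(self, partition):
        yield (1,)


class FragileSource(DataSource):
    @classmethod
    def name(cls):
        return "fragile"

    def schema(self):
        return "x int"

    def reader(self, schema) -> "DataSourceReader":
        return FragileReader()

After registration, spark.read.format("fragile").load().filter("cos(x) > 0") succeeds because no filter is supported and pushdownFilters is skipped, while adding a translatable predicate like x = 1 does reach pushdownFilters and surfaces the error.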
