@@ -254,8 +254,8 @@ def __init__(self):
                 self.has_filter = False
 
             def pushdownFilters(self, filters: List[Filter]) -> Iterable[Filter]:
-                assert len(filters) == 2
-                assert set(filters) == {EqualTo(("x",), 1), EqualTo(("y",), 2)}
+                assert len(filters) == 2, filters
+                assert set(filters) == {EqualTo(("x",), 1), EqualTo(("y",), 2)}, filters
                 self.has_filter = True
                 # pretend we support x = 1 filter but in fact we don't
                 # so we only return y = 2 filter
@@ -293,10 +293,10 @@ def pushdownFilters(self, filters: List[Filter]) -> Iterable[Filter]:
                 yield EqualTo(("",), 1)
 
             def partitions(self):
-                ...
+                assert False
 
             def read(self, partition):
-                ...
+                assert False
 
         class TestDataSource(DataSource):
             @classmethod
@@ -313,6 +313,55 @@ def reader(self, schema) -> "DataSourceReader":
         with self.assertRaisesRegex(Exception, "DATA_SOURCE_EXTRANEOUS_FILTERS"):
             self.spark.read.format("test").load().filter("x = 1").show()
 
+    def test_filter_pushdown_error(self):
+        class TestDataSourceReader(DataSourceReader):
+            def pushdownFilters(self, filters: List[Filter]) -> Iterable[Filter]:
+                raise Exception("dummy error")
+
+            def read(self, partition):
+                yield [1]
+
+        class TestDataSource(DataSource):
+            @classmethod
+            def name(cls):
+                return "test"
+
+            def schema(self):
+                return "x int"
+
+            def reader(self, schema) -> "DataSourceReader":
+                return TestDataSourceReader()
+
+        self.spark.dataSource.register(TestDataSource)
+        df = self.spark.read.format("test").load().filter("cos(x) > 0")
+        assertDataFrameEqual(df, [Row(x=1)])  # works when not pushing down filters
+        with self.assertRaisesRegex(Exception, "dummy error"):
+            df.filter("x = 1").show()
+
+    def test_unsupported_filter(self):
+        class TestDataSourceReader(DataSourceReader):
+            def pushdownFilters(self, filters: List[Filter]) -> Iterable[Filter]:
+                assert filters == [EqualTo(("x",), 1)], filters
+                return filters
+
+            def read(self, partition):
+                yield [1, 2, 3]
+
+        class TestDataSource(DataSource):
+            @classmethod
+            def name(cls):
+                return "test"
+
+            def schema(self):
+                return "x int, y int, z int"
+
+            def reader(self, schema) -> "DataSourceReader":
+                return TestDataSourceReader()
+
+        self.spark.dataSource.register(TestDataSource)
+        df = self.spark.read.format("test").load().filter("x = 1 and y = z")
+        assertDataFrameEqual(df, [])
+
     def _get_test_json_data_source(self):
         import json
         import os
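For context, the contract these tests exercise: Spark passes the pushable filters to pushdownFilters, the reader returns (or yields) the ones it cannot evaluate, and Spark re-applies those on top of read(); any filter not returned is assumed to be fully handled by the source, and returning a filter that was never passed in trips DATA_SOURCE_EXTRANEOUS_FILTERS. Below is a minimal sketch of a reader honoring that contract. It is not part of the patch: the class name, the data, the import path, and the attribute/value fields on EqualTo are assumptions for illustration; only DataSourceReader, Filter, EqualTo, and the pushdownFilters signature come from the diff itself.

    from typing import Iterable, List

    from pyspark.sql.datasource import DataSourceReader, EqualTo, Filter


    class InMemoryReader(DataSourceReader):
        def __init__(self):
            # Filters this reader has promised Spark it will apply itself.
            self.pushed: List[EqualTo] = []

        def pushdownFilters(self, filters: List[Filter]) -> Iterable[Filter]:
            for f in filters:
                # Claim only equality on the top-level column "x" (assumed
                # fields: f.attribute is the column path, f.value the literal).
                if isinstance(f, EqualTo) and f.attribute == ("x",):
                    self.pushed.append(f)
                else:
                    yield f  # unhandled: Spark will re-apply this filter

        def read(self, partition):
            data = [(1, 10), (2, 20), (3, 30)]  # made-up (x, y) rows
            for x, y in data:
                # Honor every filter we kept, instead of silently dropping it
                # the way the first test's reader deliberately does.
                if all(x == f.value for f in self.pushed):
                    yield (x, y)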