Commit 2bda8e3

WIP filter serialization

1 parent 2c584a1 commit 2bda8e3
5 files changed: +273 -42 lines

python/pyspark/sql/datasource.py

Lines changed: 64 additions & 0 deletions
@@ -312,6 +312,70 @@ class EqualTo(Filter):
     value: Any
 
 
+@dataclass(frozen=True)
+class EqualNullSafe(Filter):
+    attribute: ColumnPath
+    value: Any
+
+
+@dataclass(frozen=True)
+class GreaterThan(Filter):
+    attribute: ColumnPath
+    value: Any
+
+
+@dataclass(frozen=True)
+class GreaterThanOrEqual(Filter):
+    attribute: ColumnPath
+    value: Any
+
+
+@dataclass(frozen=True)
+class LessThan(Filter):
+    attribute: ColumnPath
+    value: Any
+
+
+@dataclass(frozen=True)
+class LessThanOrEqual(Filter):
+    attribute: ColumnPath
+    value: Any
+
+
+@dataclass(frozen=True)
+class In(Filter):
+    attribute: ColumnPath
+    value: Tuple[Any, ...]
+
+
+@dataclass(frozen=True)
+class IsNull(Filter):
+    attribute: ColumnPath
+
+
+@dataclass(frozen=True)
+class IsNotNull(Filter):
+    attribute: ColumnPath
+
+
+@dataclass(frozen=True)
+class StringStartsWith(Filter):
+    attribute: ColumnPath
+    value: str
+
+
+@dataclass(frozen=True)
+class StringEndsWith(Filter):
+    attribute: ColumnPath
+    value: str
+
+
+@dataclass(frozen=True)
+class StringContains(Filter):
+    attribute: ColumnPath
+    value: str
+
+
 class InputPartition:
     """
     A base class representing an input partition returned by the `partitions()`
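
Note on the new classes above: because they are frozen dataclasses, the filters compare by value and are hashable, which is what lets the tests below assert on set(filters) directly. A minimal sketch, assuming a PySpark build that includes this commit:

from pyspark.sql.datasource import EqualTo, IsNotNull

# Frozen dataclasses get field-wise __eq__ and __hash__ for free,
# so structurally identical filters are interchangeable in sets.
f1 = EqualTo(("x",), 1)
f2 = EqualTo(("x",), 1)
assert f1 == f2 and hash(f1) == hash(f2)
assert {f1, IsNotNull(("x",))} == {f2, IsNotNull(("x",))}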

python/pyspark/sql/tests/test_python_datasource.py

Lines changed: 63 additions & 7 deletions
@@ -18,29 +18,43 @@
 import platform
 import tempfile
 import unittest
+from datetime import datetime
+from decimal import Decimal
 from typing import Callable, Iterable, List, Union
 
-from pyspark.errors import PythonException, AnalysisException
+from pyspark.errors import AnalysisException, PythonException
 from pyspark.sql.datasource import (
+    CaseInsensitiveDict,
     DataSource,
+    DataSourceArrowWriter,
     DataSourceReader,
+    DataSourceWriter,
+    EqualNullSafe,
     EqualTo,
     Filter,
+    GreaterThan,
+    GreaterThanOrEqual,
+    In,
     InputPartition,
-    DataSourceWriter,
-    DataSourceArrowWriter,
+    IsNotNull,
+    IsNull,
+    LessThan,
+    LessThanOrEqual,
+    StringContains,
+    StringEndsWith,
+    StringStartsWith,
     WriterCommitMessage,
-    CaseInsensitiveDict,
 )
 from pyspark.sql.functions import spark_partition_id
 from pyspark.sql.session import SparkSession
 from pyspark.sql.types import Row, StructType
+from pyspark.testing import assertDataFrameEqual
 from pyspark.testing.sqlutils import (
+    SPARK_HOME,
+    ReusedSQLTestCase,
     have_pyarrow,
     pyarrow_requirement_message,
 )
-from pyspark.testing import assertDataFrameEqual
-from pyspark.testing.sqlutils import ReusedSQLTestCase, SPARK_HOME
 
 
 @unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
@@ -258,6 +272,8 @@ def __init__(self):
 
             def pushFilters(self, filters: List[Filter]) -> Iterable[Filter]:
                 assert set(filters) == {
+                    IsNotNull(("x",)),
+                    IsNotNull(("y",)),
                     EqualTo(("x",), 1),
                     EqualTo(("y",), 2),
                 }, filters
@@ -376,8 +392,9 @@ def _check_filters(self, sql_type, sql_filter, python_filters):
 
         class TestDataSourceReader(DataSourceReader):
             def pushFilters(self, filters: List[Filter]) -> Iterable[Filter]:
+                actual = [f for f in filters if not isinstance(f, IsNotNull)]
                 expected = python_filters
-                assert filters == expected, (filters, expected)
+                assert actual == expected, (actual, expected)
                 return filters
 
             def read(self, partition):
@@ -406,6 +423,45 @@ def test_unsupported_filter(self):
         self._check_filters("boolean", "not x", [])
         self._check_filters("array<int>", "x[0] = 1", [])
 
+    def test_filter_value_type(self):
+        self._check_filters("int", "x = 1", [EqualTo(("x",), 1)])
+        self._check_filters("int", "x = null", [EqualTo(("x",), None)])
+        self._check_filters("float", "x = 3 / 2", [EqualTo(("x",), 1.5)])
+        self._check_filters("string", "x = '1'", [EqualTo(("x",), "1")])
+        self._check_filters("array<int>", "x = array(1, 2)", [EqualTo(("x",), [1, 2])])
+        self._check_filters(
+            "struct<x:int>", "x = named_struct('x', 1)", [EqualTo(("x",), {"x": 1})]
+        )
+        self._check_filters(
+            "decimal", "x in (1.1, 2.1)", [In(("x",), [Decimal(1.1), Decimal(2.1)])]
+        )
+        self._check_filters(
+            "timestamp_ntz",
+            "x = timestamp_ntz '2020-01-01 00:00:00'",
+            [EqualTo(("x",), datetime.strptime("2020-01-01 00:00:00", "%Y-%m-%d %H:%M:%S"))],
+        )
+        self._check_filters(
+            "interval second",
+            "x = interval '2' second",
+            [],  # intervals are not supported
+        )
+
+    def test_filter_type(self):
+        self._check_filters("boolean", "x", [EqualTo(("x",), True)])
+        self._check_filters("int", "x is null", [IsNull(("x",))])
+        self._check_filters("int", "x <=> 1", [EqualNullSafe(("x",), 1)])
+        self._check_filters("int", "1 < x", [GreaterThan(("x",), 1)])
+        self._check_filters("int", "1 <= x", [GreaterThanOrEqual(("x",), 1)])
+        self._check_filters("int", "x < 1", [LessThan(("x",), 1)])
+        self._check_filters("int", "x <= 1", [LessThanOrEqual(("x",), 1)])
+        self._check_filters("string", "startswith(x, 'a')", [StringStartsWith(("x",), "a")])
+        self._check_filters("string", "endswith(x, 'a')", [StringEndsWith(("x",), "a")])
+        self._check_filters("string", "contains(x, 'a')", [StringContains(("x",), "a")])
+        self._check_filters("int", "x in (1, 2)", [In(("x",), [1, 2])])
+
+    def test_filter_nested_column(self):
+        self._check_filters("struct<y:int>", "x.y = 1", [EqualTo(("x", "y"), 1)])
+
     def _get_test_json_data_source(self):
         import json
         import os
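
Note on the contract these tests exercise: pushFilters receives the translated filters and returns the ones the source cannot handle; the worker (next file) treats everything returned as unsupported and leaves it for Spark to re-evaluate. A rough sketch of a reader that consumes equality filters and hands the rest back (SketchReader and its pushed field are illustrative names, not part of this commit):

from typing import Iterable, List

from pyspark.sql.datasource import DataSourceReader, EqualTo, Filter


class SketchReader(DataSourceReader):  # illustrative only
    def __init__(self):
        self.pushed: List[Filter] = []

    def pushFilters(self, filters: List[Filter]) -> Iterable[Filter]:
        for f in filters:
            if isinstance(f, EqualTo):
                self.pushed.append(f)  # apply this one at scan time
            else:
                yield f  # anything yielded back stays with Spark

    def read(self, partition):
        # a real reader would filter rows using self.pushed while scanning
        yield from ()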

python/pyspark/sql/worker/data_source_pushdown_filters.py

Lines changed: 69 additions & 19 deletions
@@ -15,17 +15,37 @@
 # limitations under the License.
 #
 
+import base64
 import faulthandler
+import json
 import os
 import sys
+import typing
 from dataclasses import dataclass, field
-from typing import IO, List
+from typing import IO, Type, Union
 
 from pyspark.accumulators import _accumulatorRegistry
 from pyspark.errors import PySparkAssertionError, PySparkValueError
+from pyspark.errors.exceptions.base import PySparkNotImplementedError
 from pyspark.serializers import SpecialLengths, UTF8Deserializer, read_int, write_int
-from pyspark.sql.datasource import DataSource, DataSourceReader, EqualTo, Filter
-from pyspark.sql.types import StructType, _parse_datatype_json_string
+from pyspark.sql.datasource import (
+    DataSource,
+    DataSourceReader,
+    EqualNullSafe,
+    EqualTo,
+    Filter,
+    GreaterThan,
+    GreaterThanOrEqual,
+    In,
+    IsNotNull,
+    IsNull,
+    LessThan,
+    LessThanOrEqual,
+    StringContains,
+    StringEndsWith,
+    StringStartsWith,
+)
+from pyspark.sql.types import StructType, VariantVal, _parse_datatype_json_string
 from pyspark.util import handle_worker_exception, local_connect_and_auth
 from pyspark.worker_util import (
     check_python_version,
@@ -39,6 +59,25 @@
 
 utf8_deserializer = UTF8Deserializer()
 
+BinaryFilter = Union[
+    EqualTo,
+    EqualNullSafe,
+    GreaterThan,
+    GreaterThanOrEqual,
+    LessThan,
+    LessThanOrEqual,
+    In,
+    StringStartsWith,
+    StringEndsWith,
+    StringContains,
+]
+
+binary_filters = {cls.__name__: cls for cls in typing.get_args(BinaryFilter)}
+
+UnaryFilter = Union[IsNotNull, IsNull]
+
+unary_filters = {cls.__name__: cls for cls in typing.get_args(UnaryFilter)}
+
 
 @dataclass(frozen=True)
 class FilterRef:
@@ -49,6 +88,30 @@ def __post_init__(self) -> None:
         object.__setattr__(self, "id", id(self.filter))
 
 
+def deserializeVariant(variantDict: dict) -> VariantVal:
+    value = base64.b64decode(variantDict["value"])
+    metadata = base64.b64decode(variantDict["metadata"])
+    return VariantVal(value, metadata)
+
+
+def deserializeFilter(jsonDict: dict) -> Filter:
+    name = jsonDict["name"]
+    if name in binary_filters:
+        binary_filter_cls: Type[BinaryFilter] = binary_filters[name]
+        return binary_filter_cls(
+            attribute=tuple(jsonDict["columnPath"]),
+            value=deserializeVariant(jsonDict["value"]).toPython(),
+        )
+    elif name in unary_filters:
+        unary_filter_cls: Type[UnaryFilter] = unary_filters[name]
+        return unary_filter_cls(attribute=tuple(jsonDict["columnPath"]))
+    else:
+        raise PySparkNotImplementedError(
+            errorClass="UNSUPPORTED_FILTER",
+            messageParameters={"name": name},
+        )
+
+
 def main(infile: IO, outfile: IO) -> None:
     """
     Main method for planning a data source read with filter pushdown.
@@ -126,22 +189,9 @@ def main(infile: IO, outfile: IO) -> None:
     )
 
     # Receive the pushdown filters.
-    num_filters = read_int(infile)
-    filters: List[FilterRef] = []
-    for _ in range(num_filters):
-        name = utf8_deserializer.loads(infile)
-        if name == "EqualTo":
-            num_parts = read_int(infile)
-            column_path = tuple(utf8_deserializer.loads(infile) for _ in range(num_parts))
-            value = read_int(infile)
-            filters.append(FilterRef(EqualTo(column_path, value)))
-        else:
-            raise PySparkAssertionError(
-                errorClass="DATA_SOURCE_UNSUPPORTED_FILTER",
-                messageParameters={
-                    "name": name,
-                },
-            )
+    json_str = utf8_deserializer.loads(infile)
+    filter_dicts = json.loads(json_str)
+    filters = [FilterRef(deserializeFilter(f)) for f in filter_dicts]
 
     # Push down the filters and get the indices of the unsupported filters.
     unsupported_filters = set(
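
Note on the wire format implied above: the worker now reads a single UTF-8 string containing a JSON array, one object per filter. Unary filters carry only name and columnPath; binary filters add a value field holding base64-encoded Variant bytes that deserializeVariant decodes and toPython() converts. A hand-written unary example, assuming a build with this commit (the Variant bytes for binary filters are produced on the JVM side and are not reproduced here):

import json

from pyspark.sql.datasource import IsNotNull
from pyspark.sql.worker.data_source_pushdown_filters import deserializeFilter

# One JSON object per filter; a binary filter would additionally carry
# "value": {"value": "<base64>", "metadata": "<base64>"}.
json_str = '[{"name": "IsNotNull", "columnPath": ["x", "y"]}]'
filters = [deserializeFilter(d) for d in json.loads(json_str)]
assert filters == [IsNotNull(attribute=("x", "y"))]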
