Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 150 additions & 4 deletions python/pyspark/sql/datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,9 +282,6 @@ class Filter(ABC):
See `Data Types <https://spark.apache.org/docs/latest/sql-ref-datatypes.html>`_
for more information about how values are represented in Python.

Currently only the equality of attribute and literal value is supported for
filter pushdown. Other types of filters cannot be pushed down.

Examples
--------
Supported filters
Expand All @@ -293,25 +290,174 @@ class Filter(ABC):
| SQL filter | Representation |
+---------------------+--------------------------------------------+
| `a.b.c = 1` | `EqualTo(("a", "b", "c"), 1)` |
| `a = 1` | `EqualTo(("a", "b", "c"), 1)` |
| `a = 1` | `EqualTo(("a",), 1)` |
| `a = 'hi'` | `EqualTo(("a",), "hi")` |
| `a = array(1, 2)` | `EqualTo(("a",), [1, 2])` |
| `a` | `EqualTo(("a",), True)` |
| `not a` | `Not(EqualTo(("a",), True))` |
| `a <> 1` | `Not(EqualTo(("a",), 1))` |
| `a > 1` | `GreaterThan(("a",), 1)` |
| `a >= 1` | `GreaterThanOrEqual(("a",), 1)` |
| `a < 1` | `LessThan(("a",), 1)` |
| `a <= 1` | `LessThanOrEqual(("a",), 1)` |
| `a in (1, 2, 3)` | `In(("a",), (1, 2, 3))` |
| `a is null` | `IsNull(("a",))` |
| `a is not null` | `IsNotNull(("a",))` |
| `a like 'abc%'` | `StringStartsWith(("a",), "abc")` |
| `a like '%abc'` | `StringEndsWith(("a",), "abc")` |
| `a like '%abc%'` | `StringContains(("a",), "abc")` |
+---------------------+--------------------------------------------+

Unsupported filters
- `a = b`
- `f(a, b) = 1`
- `a % 2 = 1`
- `a[0] = 1`
- `a < 0 or a > 1`
- `a like 'c%c%'`
- `a ilike 'hi'`
- `a = 'hi' collate zh`
"""


@dataclass(frozen=True)
class EqualTo(Filter):
    """
    A filter that evaluates to `True` iff the attribute evaluates to a value
    equal to `value`.
    """

    # Path of the attribute (column) being compared.
    attribute: ColumnPath
    # Value the attribute is compared against.
    value: Any


@dataclass(frozen=True)
class EqualNullSafe(Filter):
    """
    Performs an equality comparison, similar to `EqualTo`. It differs from `EqualTo`
    in that it returns `True` (rather than NULL) if both inputs are NULL, and `False`
    (rather than NULL) if one of the inputs is NULL and the other is not NULL.
    """

    # Path of the attribute (column) being compared.
    attribute: ColumnPath
    # Value the attribute is compared against.
    value: Any


@dataclass(frozen=True)
class GreaterThan(Filter):
    """
    A filter that evaluates to `True` iff the attribute evaluates to a value
    strictly greater than `value`.
    """

    # Path of the attribute (column) being compared.
    attribute: ColumnPath
    # Exclusive lower bound the attribute is compared against.
    value: Any


@dataclass(frozen=True)
class GreaterThanOrEqual(Filter):
    """
    A filter that evaluates to `True` iff the attribute evaluates to a value
    greater than or equal to `value`.
    """

    # Path of the attribute (column) being compared.
    attribute: ColumnPath
    # Inclusive lower bound the attribute is compared against.
    value: Any


@dataclass(frozen=True)
class LessThan(Filter):
    """
    A filter that evaluates to `True` iff the attribute evaluates to a value
    strictly less than `value`.
    """

    # Path of the attribute (column) being compared.
    attribute: ColumnPath
    # Exclusive upper bound the attribute is compared against.
    value: Any


@dataclass(frozen=True)
class LessThanOrEqual(Filter):
    """
    A filter that evaluates to `True` iff the attribute evaluates to a value
    less than or equal to `value`.
    """

    # Path of the attribute (column) being compared.
    attribute: ColumnPath
    # Inclusive upper bound the attribute is compared against.
    value: Any


@dataclass(frozen=True)
class In(Filter):
    """
    A filter that evaluates to `True` iff the attribute evaluates to one of the
    values contained in `value`.
    """

    # Path of the attribute (column) being tested for membership.
    attribute: ColumnPath
    # Candidate values; annotated as a tuple of arbitrary length.
    value: Tuple[Any, ...]


@dataclass(frozen=True)
class IsNull(Filter):
    """
    A filter that evaluates to `True` iff the attribute evaluates to null.
    """

    # Path of the attribute (column) being tested for null.
    attribute: ColumnPath


@dataclass(frozen=True)
class IsNotNull(Filter):
    """
    A filter that evaluates to `True` iff the attribute evaluates to a non-null value.
    """

    # Path of the attribute (column) being tested for non-null.
    attribute: ColumnPath


@dataclass(frozen=True)
class Not(Filter):
    """
    A filter that evaluates to `True` iff `child` evaluates to `False`.
    """

    # The filter whose result is negated.
    child: Filter


@dataclass(frozen=True)
class StringStartsWith(Filter):
    """
    A filter that evaluates to `True` iff the attribute evaluates to
    a string that starts with `value`.
    """

    # Path of the string attribute (column) being tested.
    attribute: ColumnPath
    # Prefix the attribute's string value must start with.
    value: str


@dataclass(frozen=True)
class StringEndsWith(Filter):
    """
    A filter that evaluates to `True` iff the attribute evaluates to
    a string that ends with `value`.
    """

    # Path of the string attribute (column) being tested.
    attribute: ColumnPath
    # Suffix the attribute's string value must end with.
    value: str


@dataclass(frozen=True)
class StringContains(Filter):
    """
    A filter that evaluates to `True` iff the attribute evaluates to
    a string that contains the string `value`.
    """

    # Path of the string attribute (column) being tested.
    attribute: ColumnPath
    # Substring the attribute's string value must contain.
    value: str


class InputPartition:
"""
A base class representing an input partition returned by the `partitions()`
Expand Down
81 changes: 72 additions & 9 deletions python/pyspark/sql/tests/test_python_datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,29 +18,44 @@
import platform
import tempfile
import unittest
from datetime import datetime
from decimal import Decimal
from typing import Callable, Iterable, List, Union

from pyspark.errors import PythonException, AnalysisException
from pyspark.errors import AnalysisException, PythonException
from pyspark.sql.datasource import (
CaseInsensitiveDict,
DataSource,
DataSourceArrowWriter,
DataSourceReader,
DataSourceWriter,
EqualNullSafe,
EqualTo,
Filter,
GreaterThan,
GreaterThanOrEqual,
In,
InputPartition,
DataSourceWriter,
DataSourceArrowWriter,
IsNotNull,
IsNull,
LessThan,
LessThanOrEqual,
Not,
StringContains,
StringEndsWith,
StringStartsWith,
WriterCommitMessage,
CaseInsensitiveDict,
)
from pyspark.sql.functions import spark_partition_id
from pyspark.sql.session import SparkSession
from pyspark.sql.types import Row, StructType
from pyspark.testing import assertDataFrameEqual
from pyspark.testing.sqlutils import (
SPARK_HOME,
ReusedSQLTestCase,
have_pyarrow,
pyarrow_requirement_message,
)
from pyspark.testing import assertDataFrameEqual
from pyspark.testing.sqlutils import ReusedSQLTestCase, SPARK_HOME


@unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
Expand Down Expand Up @@ -258,6 +273,8 @@ def __init__(self):

def pushFilters(self, filters: List[Filter]) -> Iterable[Filter]:
assert set(filters) == {
IsNotNull(("x",)),
IsNotNull(("y",)),
EqualTo(("x",), 1),
EqualTo(("y",), 2),
}, filters
Expand Down Expand Up @@ -376,8 +393,9 @@ def _check_filters(self, sql_type, sql_filter, python_filters):

class TestDataSourceReader(DataSourceReader):
def pushFilters(self, filters: List[Filter]) -> Iterable[Filter]:
    """
    Assert that the pushed filters match the expected ones, ignoring `IsNotNull`.

    :param filters: the filters handed to this reader.
    :return: `filters`, unchanged.
    :raises AssertionError: if the filters that are not `IsNotNull` differ
        from `python_filters` taken from the enclosing scope.
    """
    # IsNotNull filters are dropped before comparing, so the expectation only
    # has to list the filters under test. The comparison must therefore use
    # the filtered list, never the raw `filters`.
    actual = [f for f in filters if not isinstance(f, IsNotNull)]
    expected = python_filters
    assert actual == expected, (actual, expected)
    return filters

def read(self, partition):
Expand All @@ -399,12 +417,57 @@ def test_unsupported_filter(self):
self._check_filters(
"struct<a:int, b:int, c:int>", "x.a = 1 and x.b = x.c", [EqualTo(("x", "a"), 1)]
)
self._check_filters("int", "x <> 0", [])
self._check_filters("int", "x = 1 or x > 2", [])
self._check_filters("int", "(0 < x and x < 1) or x = 2", [])
self._check_filters("int", "x % 5 = 1", [])
self._check_filters("boolean", "not x", [])
self._check_filters("array<int>", "x[0] = 1", [])
self._check_filters("string", "x like 'a%a%'", [])
self._check_filters("string", "x ilike 'a'", [])
self._check_filters("string", "x = 'a' collate zh", [])

def test_filter_value_type(self):
    """
    Run `_check_filters` over filter values of several SQL types.

    Each case is `(sql_type, sql_filter, expected_python_filters)` and the
    cases are checked in the order listed.
    """
    cases = [
        ("int", "x = 1", [EqualTo(("x",), 1)]),
        ("int", "x = null", [EqualTo(("x",), None)]),
        ("float", "x = 3 / 2", [EqualTo(("x",), 1.5)]),
        ("string", "x = '1'", [EqualTo(("x",), "1")]),
        ("array<int>", "x = array(1, 2)", [EqualTo(("x",), [1, 2])]),
        ("struct<x:int>", "x = named_struct('x', 1)", [EqualTo(("x",), {"x": 1})]),
        ("decimal", "x in (1.1, 2.1)", [In(("x",), [Decimal(1.1), Decimal(2.1)])]),
        (
            "timestamp_ntz",
            "x = timestamp_ntz '2020-01-01 00:00:00'",
            [EqualTo(("x",), datetime.strptime("2020-01-01 00:00:00", "%Y-%m-%d %H:%M:%S"))],
        ),
        # intervals are not supported
        ("interval second", "x = interval '2' second", []),
    ]
    for sql_type, sql_filter, expected in cases:
        self._check_filters(sql_type, sql_filter, expected)

def test_filter_type(self):
    """
    Run `_check_filters` once per kind of filter representation.

    Each case is `(sql_type, sql_filter, expected_python_filters)` and the
    cases are checked in the order listed.
    """
    col = ("x",)
    cases = [
        ("boolean", "x", [EqualTo(col, True)]),
        ("boolean", "not x", [Not(EqualTo(col, True))]),
        ("int", "x is null", [IsNull(col)]),
        ("int", "x <> 0", [Not(EqualTo(col, 0))]),
        ("int", "x <=> 1", [EqualNullSafe(col, 1)]),
        ("int", "1 < x", [GreaterThan(col, 1)]),
        ("int", "1 <= x", [GreaterThanOrEqual(col, 1)]),
        ("int", "x < 1", [LessThan(col, 1)]),
        ("int", "x <= 1", [LessThanOrEqual(col, 1)]),
        ("string", "x like 'a%'", [StringStartsWith(col, "a")]),
        ("string", "x like '%a'", [StringEndsWith(col, "a")]),
        ("string", "x like '%a%'", [StringContains(col, "a")]),
        ("string", "x like 'a%b'", [StringStartsWith(col, "a"), StringEndsWith(col, "b")]),
        ("int", "x in (1, 2)", [In(col, [1, 2])]),
    ]
    for sql_type, sql_filter, expected in cases:
        self._check_filters(sql_type, sql_filter, expected)

def test_filter_nested_column(self):
    """Check that a filter on struct field `x.y` is expressed with the path `("x", "y")`."""
    expected = [EqualTo(("x", "y"), 1)]
    self._check_filters("struct<y:int>", "x.y = 1", expected)

def _get_test_json_data_source(self):
import json
Expand Down
Loading