Skip to content

Commit

Permalink
skip pd.NA similar to np.nan / None
Browse files Browse the repository at this point in the history
  • Loading branch information
maxbachmann committed Oct 21, 2023
1 parent 1ccb012 commit ac83d6a
Show file tree
Hide file tree
Showing 7 changed files with 67 additions and 29 deletions.
2 changes: 1 addition & 1 deletion extern/rapidfuzz-cpp
6 changes: 5 additions & 1 deletion src/rapidfuzz/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@

from rapidfuzz._feature_detector import AVX2, SSE2, supports

# pandas is an optional dependency. When it cannot be imported, fall back to
# None so that the identity check `s is pandas_NA` collapses into the plain
# `s is None` check.
try:
    from pandas import NA as pandas_NA
except Exception:
    # Catch Exception rather than using a bare `except:` so that
    # KeyboardInterrupt / SystemExit still propagate, while a missing or
    # broken pandas installation does not prevent this module from importing.
    pandas_NA = None

class ScorerFlag(IntFlag):
RESULT_F64 = 1 << 5
Expand Down Expand Up @@ -51,7 +55,7 @@ def _get_scorer_flags_normalized_similarity(**_kwargs: Any) -> dict[str, Any]:


def is_none(s: Any) -> bool:
if s is None:
if s is None or s is pandas_NA:
return True

if isinstance(s, float) and isnan(s):
Expand Down
1 change: 0 additions & 1 deletion src/rapidfuzz/cpp_common.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ from rapidfuzz cimport (

from array import array


cdef extern from "rapidfuzz/details/types.hpp" namespace "rapidfuzz" nogil:
cpdef enum class EditType:
None = 0,
Expand Down
7 changes: 6 additions & 1 deletion src/rapidfuzz/process_cpp_impl.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ from rapidfuzz cimport (
)


# pandas is an optional dependency. When it cannot be imported, fall back to
# None so that the identity check `s is pandas_NA` collapses into the plain
# `s is None` check.
try:
    from pandas import NA as pandas_NA
except Exception:
    # Catch Exception rather than using a bare `except:` so that
    # KeyboardInterrupt / SystemExit still propagate, while a missing or
    # broken pandas installation does not prevent this module from importing.
    pandas_NA = None

cdef extern from "process_cpp.hpp":
cdef cppclass ExtractComp:
ExtractComp()
Expand Down Expand Up @@ -129,7 +134,7 @@ cdef extern from "process_cpp.hpp":
const vector[RF_StringWrapper]&, const vector[RF_StringWrapper]&, MatrixType, int, T, T, T) except +

cdef inline bool is_none(s):
if s is None:
if s is None or s is pandas_NA:
return True

if isinstance(s, float) and isnan(<double>s):
Expand Down
25 changes: 7 additions & 18 deletions src/rapidfuzz/process_py.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
overload,
)

from rapidfuzz._utils import ScorerFlag
from rapidfuzz._utils import ScorerFlag, is_none
from rapidfuzz.fuzz import WRatio, ratio

__all__ = ["extract", "extract_iter", "extractOne", "cdist"]
Expand All @@ -30,17 +30,6 @@ def _get_scorer_flags_py(scorer: Any, scorer_kwargs: dict[str, Any]) -> tuple[in
return (flags["worst_score"], flags["optimal_score"])
return (0, 100)


def _is_none(s: Any) -> bool:
    """
    Report whether ``s`` counts as a missing value.

    Returns True when ``s`` is ``None`` or a float NaN, and False for
    everything else.
    """
    # isnan() is only reached for floats, so non-numeric inputs never hit it
    return s is None or (isinstance(s, float) and isnan(s))


@overload
def extract_iter(
query: Sequence[Hashable] | None,
Expand Down Expand Up @@ -141,7 +130,7 @@ def extract_iter(
worst_score, optimal_score = _get_scorer_flags_py(scorer, scorer_kwargs)
lowest_score_worst = optimal_score > worst_score

if _is_none(query):
if is_none(query):
return

if score_cutoff is None:
Expand All @@ -154,7 +143,7 @@ def extract_iter(
choices_iter: Iterable[tuple[Any, Sequence[Hashable] | None]]
choices_iter = choices.items() if hasattr(choices, "items") else enumerate(choices) # type: ignore[union-attr]
for key, choice in choices_iter:
if _is_none(choice):
if is_none(choice):
continue

if processor is None:
Expand Down Expand Up @@ -334,7 +323,7 @@ def extractOne(
worst_score, optimal_score = _get_scorer_flags_py(scorer, scorer_kwargs)
lowest_score_worst = optimal_score > worst_score

if _is_none(query):
if is_none(query):
return None

if score_cutoff is None:
Expand All @@ -349,7 +338,7 @@ def extractOne(
choices_iter: Iterable[tuple[Any, Sequence[Hashable] | None]]
choices_iter = choices.items() if hasattr(choices, "items") else enumerate(choices) # type: ignore[union-attr]
for key, choice in choices_iter:
if _is_none(choice):
if is_none(choice):
continue

if processor is None:
Expand Down Expand Up @@ -611,7 +600,7 @@ def cdist(
if processor is None:
proc_choices = list(choices)
else:
proc_choices = [x if _is_none(x) else processor(x) for x in choices]
proc_choices = [x if is_none(x) else processor(x) for x in choices]

if queries is choices and _is_symmetric(scorer, scorer_kwargs):
for i, query in enumerate(proc_choices):
Expand All @@ -625,7 +614,7 @@ def cdist(
)
else:
for i, query in enumerate(queries):
proc_query = processor(query) if (processor and not _is_none(query)) else query
proc_query = processor(query) if (processor and not is_none(query)) else query
for j, choice in enumerate(proc_choices):
results[i, j] = scorer(
proc_query,
Expand Down
6 changes: 5 additions & 1 deletion tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@

from rapidfuzz import process_cpp, process_py

# pandas is an optional dependency. When it cannot be imported, fall back to
# None so that the identity check `s is pandas_NA` collapses into the plain
# `s is None` check.
try:
    from pandas import NA as pandas_NA
except Exception:
    # Catch Exception rather than using a bare `except:` so that
    # KeyboardInterrupt / SystemExit still propagate, while a missing or
    # broken pandas installation does not prevent this module from importing.
    pandas_NA = None

def _get_scorer_flags_py(scorer: Any, scorer_kwargs: dict[str, Any]) -> tuple[int, int]:
params = getattr(scorer, "_RF_ScorerPy", None)
Expand All @@ -21,7 +25,7 @@ def _get_scorer_flags_py(scorer: Any, scorer_kwargs: dict[str, Any]) -> tuple[in


def is_none(s):
if s is None:
if s is None or s is pandas_NA:
return True

if isinstance(s, float) and isnan(s):
Expand Down
49 changes: 43 additions & 6 deletions tests/test_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,6 @@
from rapidfuzz import fuzz, process_cpp, process_py
from rapidfuzz.distance import Levenshtein, Levenshtein_py

with suppress(BaseException):
import numpy as np


def wrapped(func):
from functools import wraps

Expand Down Expand Up @@ -47,6 +43,7 @@ def extract(*args, **kwargs):

@staticmethod
def cdist(*args, **kwargs):
import numpy as np
res1 = process_cpp.cdist(*args, **kwargs)
res2 = process_py.cdist(*args, **kwargs)
assert res1.dtype == res2.dtype
Expand Down Expand Up @@ -292,6 +289,38 @@ def test_none_elements():
assert best == []


def test_numpy_nan_elements():
    """
    np.nan entries among the choices are skipped while the reported index of
    the remaining choices stays correct; a np.nan query produces no match
    """
    np = pytest.importorskip("numpy")

    # the NaN sits at index 0, so the surviving choice must be reported at index 1
    match = process.extractOne("test", [np.nan, "tes"])
    assert match[2] == 1
    match = process.extractOne(np.nan, [np.nan, "tes"])
    assert match is None

    matches = process.extract("test", [np.nan, "tes"])
    assert matches[0][2] == 1
    matches = process.extract(np.nan, [np.nan, "tes"])
    assert matches == []


def test_pandas_nan_elements():
    """
    pd.NA entries among the choices are skipped while the reported index of
    the remaining choices stays correct; a pd.NA query produces no match
    """
    pd = pytest.importorskip("pandas")

    # pd.NA sits at index 0, so the surviving choice must be reported at index 1
    match = process.extractOne("test", [pd.NA, "tes"])
    assert match[2] == 1
    match = process.extractOne(pd.NA, [pd.NA, "tes"])
    assert match is None

    matches = process.extract("test", [pd.NA, "tes"])
    assert matches[0][2] == 1
    matches = process.extract(pd.NA, [pd.NA, "tes"])
    assert matches == []


def test_result_order():
"""
when multiple elements have the same score, the first one should be returned
Expand Down Expand Up @@ -405,9 +434,17 @@ def test_wrapped_function(scorer):
assert process.cdist(["test"], [None], scorer=scorer)[0, 0] == 100
assert process.cdist(["test"], ["tes"], scorer=scorer)[0, 0] == 100

try:
import pandas as pd
except Exception:
pd = None

if pd is not None:
assert process.cdist(["test"], [pd.NA], scorer=scorer)[0, 0] == 100


def test_cdist_not_symmetric():
pytest.importorskip("numpy")
np = pytest.importorskip("numpy")
strings = ["test", "test2"]
expected_res = np.array([[0, 1], [2, 0]])
assert np.array_equal(
Expand All @@ -434,7 +471,7 @@ def generate_choices():


def test_cdist_pure_python_dtype():
pytest.importorskip("numpy")
np = pytest.importorskip("numpy")
assert process.cdist(["test"], ["test"], scorer=Levenshtein_py.distance).dtype == np.int32
assert process.cdist(["test"], ["test"], scorer=Levenshtein_py.similarity).dtype == np.int32
assert process.cdist(["test"], ["test"], scorer=Levenshtein_py.normalized_distance).dtype == np.float32
Expand Down

0 comments on commit ac83d6a

Please sign in to comment.