Skip to content

Commit 11fe33f

Browse files
feat: Add {Expr,Series}.is_close (#2962)
--------- Co-authored-by: dangotbanned <125183946+dangotbanned@users.noreply.github.com>
1 parent c441460 commit 11fe33f

File tree

11 files changed

+581
-8
lines changed

11 files changed

+581
-8
lines changed

docs/api-reference/expr.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
- filter
2424
- clip
2525
- is_between
26+
- is_close
2627
- is_duplicated
2728
- is_finite
2829
- is_first_distinct

docs/api-reference/series.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
- hist
3636
- implementation
3737
- is_between
38+
- is_close
3839
- is_duplicated
3940
- is_empty
4041
- is_finite

narwhals/_compliant/expr.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
LazyExprT,
2828
NativeExprT,
2929
)
30+
from narwhals._compliant.utils import IsClose
3031
from narwhals._utils import _StoresCompliant
3132
from narwhals.dependencies import get_numpy, is_numpy_array
3233

@@ -75,7 +76,7 @@ def __eq__(self, value: Any, /) -> Self: ... # type: ignore[override]
7576
def __ne__(self, value: Any, /) -> Self: ... # type: ignore[override]
7677

7778

78-
class CompliantExpr(Protocol[CompliantFrameT, CompliantSeriesOrNativeExprT_co]):
79+
class CompliantExpr(IsClose, Protocol[CompliantFrameT, CompliantSeriesOrNativeExprT_co]):
7980
_implementation: Implementation
8081
_version: Version
8182
_evaluate_output_names: EvalNames[CompliantFrameT]
@@ -902,6 +903,22 @@ def exp(self) -> Self:
902903
def sqrt(self) -> Self:
903904
return self._reuse_series("sqrt")
904905

906+
def is_close(
    self,
    other: Self | NumericLiteral,
    *,
    abs_tol: float,
    rel_tol: float,
    nans_equal: bool,
) -> Self:
    """Evaluate `is_close` by handing the call to `self._reuse_series`.

    Arguments:
        other: Values to compare with.
        abs_tol: Absolute tolerance.
        rel_tol: Relative tolerance.
        nans_equal: Whether NaN values should be considered equal.

    Returns:
        Whatever `self._reuse_series` returns for the method name `"is_close"`.
    """
    # Every argument is passed along by keyword, unchanged.
    forwarded = {
        "other": other,
        "abs_tol": abs_tol,
        "rel_tol": rel_tol,
        "nans_equal": nans_equal,
    }
    return self._reuse_series("is_close", **forwarded)
921+
905922
@property
906923
def cat(self) -> EagerExprCatNamespace[Self]:
907924
return EagerExprCatNamespace(self)

narwhals/_compliant/series.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
NativeSeriesT,
1717
NativeSeriesT_co,
1818
)
19+
from narwhals._compliant.utils import IsClose
1920
from narwhals._translate import FromIterable, FromNative, NumpyConvertible, ToNarwhals
2021
from narwhals._typing_compat import TypeVar, assert_never
2122
from narwhals._utils import (
@@ -78,6 +79,7 @@ class HistData(TypedDict, Generic[NativeSeriesT, "_CountsT_co"]):
7879

7980

8081
class CompliantSeries(
82+
IsClose,
8183
NumpyConvertible["_1DArray", "Into1DArray"],
8284
FromIterable,
8385
FromNative[NativeSeriesT],

narwhals/_compliant/utils.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING, Any, Protocol
4+
5+
if TYPE_CHECKING:
6+
from typing_extensions import Self
7+
8+
from narwhals.typing import NumericLiteral, TemporalLiteral
9+
10+
11+
class IsClose(Protocol):
    """Operations that the default `is_close` implementation is built from.

    Every stub declared here is used by `is_close` below.
    """

    # Boolean combinators.
    def __and__(self, other: Any) -> Self: ...
    def __or__(self, other: Any) -> Self: ...
    def __invert__(self) -> Self: ...

    # Arithmetic.
    def __sub__(self, other: Any) -> Self: ...
    def __mul__(self, other: Any) -> Self: ...

    # Comparisons.
    def __eq__(self, other: Self | Any) -> Self: ...  # type: ignore[override]
    def __gt__(self, other: Any) -> Self: ...
    def __le__(self, other: Any) -> Self: ...

    # Element-wise helpers.
    def abs(self) -> Self: ...
    def is_nan(self) -> Self: ...
    def is_finite(self) -> Self: ...
    def clip(
        self,
        lower_bound: Self | NumericLiteral | TemporalLiteral | None,
        upper_bound: Self | NumericLiteral | TemporalLiteral | None,
    ) -> Self: ...

    def is_close(
        self,
        other: Self | NumericLiteral,
        *,
        abs_tol: float,
        rel_tol: float,
        nans_equal: bool,
    ) -> Self:
        """Check whether `self` and `other` are approximately equal.

        Values count as close when both are non-infinite and
        `|self - other| <= max(rel_tol * max(|self|, |other|), abs_tol)`,
        or when both are infinite with the same sign. A NaN on either side
        gives False, unless `nans_equal` is set and both sides are NaN.

        Arguments:
            other: Values to compare with; either another object of this kind
                or a plain `float`/`int`/`Decimal`.
            abs_tol: Absolute tolerance.
            rel_tol: Relative tolerance.
            nans_equal: Whether two NaN values should be considered equal.

        Returns:
            The boolean result of combining the checks above.
        """
        from decimal import Decimal

        rhs_abs: Self | NumericLiteral
        rhs_nan: Self | bool
        rhs_inf: Self | bool
        rhs_not_inf: Self | bool

        if not isinstance(other, (float, int, Decimal)):
            rhs_abs = other.abs()
            rhs_nan = other.is_nan()
            # "Not infinite" means finite or NaN.
            rhs_not_inf = other.is_finite() | rhs_nan
            rhs_inf = ~rhs_not_inf
        else:
            from math import isinf, isnan

            # NOTE: See https://discuss.python.org/t/inferred-type-of-function-that-calls-dunder-abs-abs/101447
            rhs_abs = other.__abs__()
            rhs_nan = isnan(other)
            rhs_inf = isinf(other)
            # Use `not` rather than `~`: bitwise inversion on a plain bool
            # emits a DeprecationWarning (slated for removal in Python 3.16).
            rhs_not_inf = not rhs_inf

        # tolerance = max(rel_tol * max(|self|, |other|), abs_tol)
        larger_abs = self.abs().clip(lower_bound=rhs_abs, upper_bound=None)
        tolerance = (larger_abs * rel_tol).clip(lower_bound=abs_tol, upper_bound=None)

        lhs_nan = self.is_nan()
        lhs_not_inf = self.is_finite() | lhs_nan

        # Within tolerance, and neither side is infinite.
        within_tol = (self - other).abs() <= tolerance
        finite_match = within_tol & lhs_not_inf & rhs_not_inf

        # Two infinities match only when they carry the same sign.
        same_sign = (self > 0) == (other > 0)
        inf_match = (~lhs_not_inf) & rhs_inf & same_sign

        # Any NaN makes the comparison False ...
        any_nan = lhs_nan | rhs_nan
        verdict = (finite_match | inf_match) & ~any_nan
        if not nans_equal:
            return verdict

        # ... except that, with `nans_equal=True`, NaN vs NaN is True.
        return verdict | (lhs_nan & rhs_nan)

narwhals/_polars/expr.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from narwhals._polars.dataframe import Method, PolarsDataFrame
2828
from narwhals._polars.namespace import PolarsNamespace
2929
from narwhals._utils import Version, _LimitedContext
30-
from narwhals.typing import IntoDType
30+
from narwhals.typing import IntoDType, NumericLiteral
3131

3232

3333
class PolarsExpr:
@@ -232,6 +232,46 @@ def __narwhals_namespace__(self) -> PolarsNamespace: # pragma: no cover
232232

233233
return PolarsNamespace(version=self._version)
234234

235+
def is_close(
    self,
    other: Self | NumericLiteral,
    *,
    abs_tol: float,
    rel_tol: float,
    nans_equal: bool,
) -> Self:
    """Check whether this expression is approximately equal to `other`.

    When `self._backend_version` is at least `(1, 32, 0)` the native
    `is_close` expression method is used. Older versions get a fallback
    assembled from primitive expression operations.

    Arguments:
        other: Values to compare with; a non-`PolarsExpr` value is wrapped
            with `pl.lit`.
        abs_tol: Absolute tolerance.
        rel_tol: Relative tolerance.
        nans_equal: Whether two NaN values should be considered equal.

    Returns:
        A new expression wrapping the boolean result.
    """
    lhs = self.native
    rhs = other.native if isinstance(other, PolarsExpr) else pl.lit(other)

    if self._backend_version >= (1, 32, 0):
        native = lhs.is_close(
            rhs, abs_tol=abs_tol, rel_tol=rel_tol, nans_equal=nans_equal
        )
        return self._with_native(native)

    every = pl.all_horizontal

    # tolerance = max(rel_tol * max(|lhs|, |rhs|), abs_tol)
    larger_abs = lhs.abs().clip(rhs.abs())
    tolerance = (larger_abs * rel_tol).clip(abs_tol)

    # Within tolerance, and both sides finite.
    within_tol = (lhs - rhs).abs() <= tolerance
    finite_match = every(within_tol, lhs.is_finite(), rhs.is_finite())

    # Infinities are "close" only if they have the same sign.
    inf_match = every(
        lhs.is_infinite(), rhs.is_infinite(), (lhs.sign() == rhs.sign())
    )

    # NaN handling:
    # * nans_equal = False => any NaN gives False
    # * nans_equal = True  => NaN vs NaN gives True
    lhs_nan, rhs_nan = lhs.is_nan(), rhs.is_nan()
    fallback = (finite_match | inf_match) & (lhs_nan | rhs_nan).not_()
    if nans_equal:
        fallback = fallback | (lhs_nan & rhs_nan)
    return self._with_native(fallback)
274+
235275
@property
236276
def dt(self) -> PolarsExprDateTimeNamespace:
237277
return PolarsExprDateTimeNamespace(self)

narwhals/_polars/series.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
IntoDType,
4242
MultiIndexSelector,
4343
NonNestedLiteral,
44+
NumericLiteral,
4445
_1DArray,
4546
)
4647

@@ -90,6 +91,7 @@
9091
"gather_every",
9192
"head",
9293
"is_between",
94+
"is_close",
9395
"is_finite",
9496
"is_first_distinct",
9597
"is_in",
@@ -131,6 +133,11 @@
131133
class PolarsSeries:
132134
_implementation = Implementation.POLARS
133135

136+
_HIST_EMPTY_SCHEMA: ClassVar[Mapping[IncludeBreakpoint, Sequence[str]]] = {
137+
True: ["breakpoint", "count"],
138+
False: ["count"],
139+
}
140+
134141
def __init__(self, series: pl.Series, *, version: Version) -> None:
135142
self._native_series: pl.Series = series
136143
self._version = version
@@ -478,10 +485,27 @@ def __contains__(self, other: Any) -> bool:
478485
except Exception as e: # noqa: BLE001
479486
raise catch_polars_exception(e) from None
480487

481-
_HIST_EMPTY_SCHEMA: ClassVar[Mapping[IncludeBreakpoint, Sequence[str]]] = {
482-
True: ["breakpoint", "count"],
483-
False: ["count"],
484-
}
488+
def is_close(
    self,
    other: Self | NumericLiteral,
    *,
    abs_tol: float,
    rel_tol: float,
    nans_equal: bool,
) -> PolarsSeries:
    """Check whether this series is approximately equal to `other`.

    When `self._backend_version` is at least `(1, 32, 0)` the native series
    `is_close` is called directly. Older versions route the computation
    through the expression-level `is_close` on a one-column frame.

    Arguments:
        other: Values to compare with.
        abs_tol: Absolute tolerance.
        rel_tol: Relative tolerance.
        nans_equal: Whether two NaN values should be considered equal.

    Returns:
        A series holding the result of the comparison.
    """
    if self._backend_version >= (1, 32, 0):
        rhs_native = other.native if isinstance(other, PolarsSeries) else other
        native = self.native.is_close(
            rhs_native, abs_tol=abs_tol, rel_tol=rel_tol, nans_equal=nans_equal
        )
        return self._with_native(native)

    # Fallback: select the expression from a frame built from this series,
    # then pull the resulting column back out under the same name.
    column = self.name
    plx = self.__narwhals_namespace__()
    rhs_expr = other._to_expr() if isinstance(other, PolarsSeries) else other
    comparison = plx.col(column).is_close(
        rhs_expr, abs_tol=abs_tol, rel_tol=rel_tol, nans_equal=nans_equal
    )
    return self.to_frame().select(comparison).get_column(column)
485509

486510
def hist_from_bins(
487511
self, bins: list[float], *, include_breakpoint: bool

narwhals/expr.py

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
)
1313
from narwhals._utils import _validate_rolling_arguments, ensure_type, flatten
1414
from narwhals.dtypes import _validate_dtype
15-
from narwhals.exceptions import InvalidOperationError
15+
from narwhals.exceptions import ComputeError, InvalidOperationError
1616
from narwhals.expr_cat import ExprCatNamespace
1717
from narwhals.expr_dt import ExprDateTimeNamespace
1818
from narwhals.expr_list import ExprListNamespace
@@ -2347,6 +2347,97 @@ def sqrt(self) -> Self:
23472347
"""
23482348
return self._with_elementwise(lambda plx: self._to_compliant_expr(plx).sqrt())
23492349

2350+
def is_close(
    self,
    other: Self | NumericLiteral,
    *,
    abs_tol: float = 0.0,
    rel_tol: float = 1e-09,
    nans_equal: bool = False,
) -> Self:
    r"""Check if this expression is close, i.e. almost equal, to the other expression.

    Two values `a` and `b` are considered close if the following condition holds:

    $$
    |a-b| \le \max \{ \text{rel\_tol} \cdot \max \{ |a|, |b| \}, \text{abs\_tol} \}
    $$

    Arguments:
        other: Values to compare with.
        abs_tol: Absolute tolerance. This is the maximum allowed absolute difference
            between two values. Must be non-negative.
        rel_tol: Relative tolerance. This is the maximum allowed difference between
            two values, relative to the larger absolute value. Must be in the range
            [0, 1).
        nans_equal: Whether NaN values should be considered equal.

    Returns:
        Expression of Boolean data type.

    Raises:
        ComputeError: If `abs_tol` is negative, or `rel_tol` is outside [0, 1).

    Notes:
        The implementation of this method is symmetric and mirrors the behavior of
        `math.isclose`. Specifically note that this behavior is different to
        `numpy.isclose`.

    Examples:
        >>> import duckdb
        >>> import pyarrow as pa
        >>> import narwhals as nw
        >>>
        >>> data = {
        ...     "x": [1.0, float("inf"), 1.41, None, float("nan")],
        ...     "y": [1.2, float("inf"), 1.40, None, float("nan")],
        ... }
        >>> _table = pa.table(data)
        >>> df_native = duckdb.table("_table")
        >>> df = nw.from_native(df_native)
        >>> df.with_columns(
        ...     is_close=nw.col("x").is_close(
        ...         nw.col("y"), abs_tol=0.1, nans_equal=True
        ...     )
        ... )
        ┌──────────────────────────────┐
        |      Narwhals LazyFrame      |
        |------------------------------|
        |┌────────┬────────┬──────────┐|
        |│   x    │   y    │ is_close │|
        |│ double │ double │ boolean  │|
        |├────────┼────────┼──────────┤|
        |│    1.0 │    1.2 │ false    │|
        |│    inf │    inf │ true     │|
        |│   1.41 │    1.4 │ true     │|
        |│   NULL │   NULL │ NULL     │|
        |│    nan │    nan │ true     │|
        |└────────┴────────┴──────────┘|
        └──────────────────────────────┘
    """
    # Validate the tolerances eagerly, before any expression is built.
    if abs_tol < 0:
        msg = f"`abs_tol` must be non-negative but got {abs_tol}"
        raise ComputeError(msg)

    # Keep the chained comparison: it also rejects NaN, which a
    # `rel_tol < 0 or rel_tol >= 1` rewrite would let through.
    if not (0 <= rel_tol < 1):
        msg = f"`rel_tol` must be in the range [0, 1) but got {rel_tol}"
        raise ComputeError(msg)

    tolerances = {"abs_tol": abs_tol, "rel_tol": rel_tol, "nans_equal": nans_equal}
    metadata = combine_metadata(
        self,
        other,
        str_as_lit=False,
        allow_multi_output=False,
        to_single_output=False,
    )
    return self.__class__(
        lambda plx: apply_n_ary_operation(
            plx,
            # exprs[0] is `self`, exprs[1] is `other`.
            lambda *exprs: exprs[0].is_close(exprs[1], **tolerances),
            self,
            other,
            str_as_lit=False,
        ),
        metadata,
    )
2440+
23502441
@property
23512442
def str(self) -> ExprStringNamespace[Self]:
23522443
return ExprStringNamespace(self)

0 commit comments

Comments
 (0)