|
12 | 12 | ) |
13 | 13 | from narwhals._utils import _validate_rolling_arguments, ensure_type, flatten |
14 | 14 | from narwhals.dtypes import _validate_dtype |
15 | | -from narwhals.exceptions import InvalidOperationError |
| 15 | +from narwhals.exceptions import ComputeError, InvalidOperationError |
16 | 16 | from narwhals.expr_cat import ExprCatNamespace |
17 | 17 | from narwhals.expr_dt import ExprDateTimeNamespace |
18 | 18 | from narwhals.expr_list import ExprListNamespace |
@@ -2347,6 +2347,97 @@ def sqrt(self) -> Self: |
2347 | 2347 | """ |
2348 | 2348 | return self._with_elementwise(lambda plx: self._to_compliant_expr(plx).sqrt()) |
2349 | 2349 |
|
| 2350 | + def is_close( |
| 2351 | + self, |
| 2352 | + other: Self | NumericLiteral, |
| 2353 | + *, |
| 2354 | + abs_tol: float = 0.0, |
| 2355 | + rel_tol: float = 1e-09, |
| 2356 | + nans_equal: bool = False, |
| 2357 | + ) -> Self: |
| 2358 | + r"""Check if this expression is close, i.e. almost equal, to the other expression. |
| 2359 | +
|
| 2360 | + Two values `a` and `b` are considered close if the following condition holds: |
| 2361 | +
|
| 2362 | + $$ |
| 2363 | + |a-b| \le max \{ \text{rel\_tol} \cdot max \{ |a|, |b| \}, \text{abs\_tol} \} |
| 2364 | + $$ |
| 2365 | +
|
| 2366 | + Arguments: |
| 2367 | + other: Values to compare with. |
| 2368 | + abs_tol: Absolute tolerance. This is the maximum allowed absolute difference |
| 2369 | + between two values. Must be non-negative. |
| 2370 | + rel_tol: Relative tolerance. This is the maximum allowed difference between |
| 2371 | + two values, relative to the larger absolute value. Must be in the range |
| 2372 | + [0, 1). |
| 2373 | + nans_equal: Whether NaN values should be considered equal. |
| 2374 | +
|
| 2375 | + Returns: |
| 2376 | + Expression of Boolean data type. |
| 2377 | +
|
| 2378 | + Notes: |
| 2379 | + The implementation of this method is symmetric and mirrors the behavior of |
| 2380 | + `math.isclose`. Specifically note that this behavior is different to |
| 2381 | + `numpy.isclose`. |
| 2382 | +
|
| 2383 | + Examples: |
| 2384 | + >>> import duckdb |
| 2385 | + >>> import pyarrow as pa |
| 2386 | + >>> import narwhals as nw |
| 2387 | + >>> |
| 2388 | + >>> data = { |
| 2389 | + ... "x": [1.0, float("inf"), 1.41, None, float("nan")], |
| 2390 | + ... "y": [1.2, float("inf"), 1.40, None, float("nan")], |
| 2391 | + ... } |
| 2392 | + >>> _table = pa.table(data) |
| 2393 | + >>> df_native = duckdb.table("_table") |
| 2394 | + >>> df = nw.from_native(df_native) |
| 2395 | + >>> df.with_columns( |
| 2396 | + ... is_close=nw.col("x").is_close( |
| 2397 | + ... nw.col("y"), abs_tol=0.1, nans_equal=True |
| 2398 | + ... ) |
| 2399 | + ... ) |
| 2400 | + ┌──────────────────────────────┐ |
| 2401 | + | Narwhals LazyFrame | |
| 2402 | + |------------------------------| |
| 2403 | + |┌────────┬────────┬──────────┐| |
| 2404 | + |│ x │ y │ is_close │| |
| 2405 | + |│ double │ double │ boolean │| |
| 2406 | + |├────────┼────────┼──────────┤| |
| 2407 | + |│ 1.0 │ 1.2 │ false │| |
| 2408 | + |│ inf │ inf │ true │| |
| 2409 | + |│ 1.41 │ 1.4 │ true │| |
| 2410 | + |│ NULL │ NULL │ NULL │| |
| 2411 | + |│ nan │ nan │ true │| |
| 2412 | + |└────────┴────────┴──────────┘| |
| 2413 | + └──────────────────────────────┘ |
| 2414 | + """ |
| 2415 | + if abs_tol < 0: |
| 2416 | + msg = f"`abs_tol` must be non-negative but got {abs_tol}" |
| 2417 | + raise ComputeError(msg) |
| 2418 | + |
| 2419 | + if not (0 <= rel_tol < 1): |
| 2420 | + msg = f"`rel_tol` must be in the range [0, 1) but got {rel_tol}" |
| 2421 | + raise ComputeError(msg) |
| 2422 | + |
| 2423 | + kwargs = {"abs_tol": abs_tol, "rel_tol": rel_tol, "nans_equal": nans_equal} |
| 2424 | + return self.__class__( |
| 2425 | + lambda plx: apply_n_ary_operation( |
| 2426 | + plx, |
| 2427 | + lambda *exprs: exprs[0].is_close(exprs[1], **kwargs), |
| 2428 | + self, |
| 2429 | + other, |
| 2430 | + str_as_lit=False, |
| 2431 | + ), |
| 2432 | + combine_metadata( |
| 2433 | + self, |
| 2434 | + other, |
| 2435 | + str_as_lit=False, |
| 2436 | + allow_multi_output=False, |
| 2437 | + to_single_output=False, |
| 2438 | + ), |
| 2439 | + ) |
| 2440 | + |
2350 | 2441 | @property |
2351 | 2442 | def str(self) -> ExprStringNamespace[Self]: |
2352 | 2443 | return ExprStringNamespace(self) |
|
0 commit comments