Skip to content

Commit

Permalink
fix: Expr.sign should preserve dtype (#18446)
Browse files Browse the repository at this point in the history
  • Loading branch information
orlp authored Aug 29, 2024
1 parent 6146350 commit b3172aa
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 64 deletions.
53 changes: 23 additions & 30 deletions crates/polars-plan/src/dsl/function_expr/sign.rs
Original file line number Diff line number Diff line change
@@ -1,41 +1,34 @@
use num::{One, Zero};
use polars_core::export::num;
use DataType::*;
use polars_core::with_match_physical_numeric_polars_type;

use super::*;

// NOTE(review): the +/- diff markers were lost in this rendering, so two
// versions of the body appear back to back. The `match s.dtype()` expression
// looks like the removed implementation (floats go through `sign_float`, other
// numeric dtypes are cast to Float64 first) and the lines from `let dt = ...`
// onward look like the added one — confirm against the repository.
/// Computes the element-wise sign of a numeric `Series`.
///
/// Both visible versions reject non-numeric dtypes with an `opq = sign` error.
pub(super) fn sign(s: &Series) -> PolarsResult<Series> {
match s.dtype() {
Float32 => {
let ca = s.f32().unwrap();
sign_float(ca)
},
Float64 => {
let ca = s.f64().unwrap();
sign_float(ca)
},
// Any other numeric dtype is cast to Float64 and re-dispatched through `sign`.
dt if dt.is_numeric() => {
let s = s.cast(&Float64)?;
sign(&s)
},
dt => polars_bail!(opq = sign, dt),
}
let dt = s.dtype();
polars_ensure!(dt.is_numeric(), opq = sign, dt);
// Presumably binds `$T` to the Polars numeric type matching `dt`, so that
// `sign_impl` is instantiated for that type — TODO confirm macro semantics.
with_match_physical_numeric_polars_type!(dt, |$T| {
let ca: &ChunkedArray<$T> = s.as_ref().as_ref();
Ok(sign_impl(ca))
})
}

// NOTE(review): diff markers were lost here as well. `sign_float` and
// `signum_improved` appear to be the removed float-only path (its result is
// cast to Int64), while `sign_impl` appears to be the added generic path,
// which has no trailing cast — confirm against the repository.
fn sign_float<T>(ca: &ChunkedArray<T>) -> PolarsResult<Series>
// Maps each value to -1, 1, or the value itself, then converts to a `Series`.
fn sign_impl<T>(ca: &ChunkedArray<T>) -> Series
where
T: PolarsFloatType,
T::Native: num::Float,
T: PolarsNumericType,
ChunkedArray<T>: IntoSeries,
{
ca.apply_values(signum_improved).into_series().cast(&Int64)
}

// Wrapper for the signum function that handles +/-0.0 inputs differently
// See discussion here: https://github.com/rust-lang/rust/issues/57543
fn signum_improved<F: num::Float>(v: F) -> F {
// Zero (of either sign) is returned unchanged instead of going through `signum`.
if v.is_zero() {
v
} else {
v.signum()
}
ca.apply_values(|x| {
if x < T::Native::zero() {
// -1 is spelled as `0 - 1`; presumably so no negation bound is needed
// on `T::Native` — TODO confirm.
T::Native::zero() - T::Native::one()
} else if x > T::Native::zero() {
T::Native::one()
} else {
// Returning x here ensures we return NaN for NaN input, and
// maintain the sign for signed zeroes (although we don't really
// care about the latter).
x
}
})
.into_series()
}
29 changes: 15 additions & 14 deletions py-polars/polars/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -8746,30 +8746,31 @@ def upper_bound(self) -> Expr:

def sign(self) -> Expr:
"""
Compute the element-wise indication of the sign.
Compute the element-wise sign function on numeric types.
The returned values can be -1, 0, or 1:
The returned value is computed as follows:
* -1 if x < 0.
* 0 if x == 0.
* 1 if x > 0.
* -1 if x < 0.
* 1 if x > 0.
* x otherwise (typically 0, but could be NaN if the input is).
(null values are preserved as-is).
Null values are preserved as-is, and the dtype of the input is preserved.
Examples
--------
>>> df = pl.DataFrame({"a": [-9.0, -0.0, 0.0, 4.0, None]})
>>> df.select(pl.col("a").sign())
shape: (5, 1)
>>> df = pl.DataFrame({"a": [-9.0, -0.0, 0.0, 4.0, float("nan"), None]})
>>> df.select(pl.col.a.sign())
shape: (6, 1)
┌──────┐
│ a │
│ --- │
i64
f64
╞══════╡
│ -1 │
│ 0 │
│ 0 │
│ 1 │
│ -1.0 │
│ -0.0 │
│ 0.0 │
│ 1.0 │
│ NaN │
│ null │
└──────┘
"""
Expand Down
27 changes: 14 additions & 13 deletions py-polars/polars/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -4991,27 +4991,28 @@ def mode(self) -> Series:

def sign(self) -> Series:
"""
Compute the element-wise indication of the sign.
Compute the element-wise sign function on numeric types.
The returned values can be -1, 0, or 1:
The returned value is computed as follows:
* -1 if x < 0.
* 0 if x == 0.
* 1 if x > 0.
* -1 if x < 0.
* 1 if x > 0.
* x otherwise (typically 0, but could be NaN if the input is).
(null values are preserved as-is).
Null values are preserved as-is, and the dtype of the input is preserved.
Examples
--------
>>> s = pl.Series("a", [-9.0, -0.0, 0.0, 4.0, None])
>>> s = pl.Series("a", [-9.0, -0.0, 0.0, 4.0, float("nan"), None])
>>> s.sign()
shape: (5,)
Series: 'a' [i64]
shape: (6,)
Series: 'a' [f64]
[
-1
0
0
1
-1.0
-0.0
0.0
1.0
NaN
null
]
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -343,10 +343,6 @@ def test_parse_apply_raw_functions() -> None:
):
df1 = lf.select(pl.col("a").map_elements(func)).collect()
df2 = lf.select(getattr(pl.col("a"), func_name)()).collect()
if func_name == "sign":
# note: Polars' 'sign' function returns an Int64, while numpy's
# 'sign' function returns a Float64
df1 = df1.with_columns(pl.col("a").cast(pl.Int64))
assert_frame_equal(df1, df2)

# test bare 'json.loads'
Expand Down
4 changes: 2 additions & 2 deletions py-polars/tests/unit/series/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1747,8 +1747,8 @@ def test_sign() -> None:
assert_series_equal(a.sign(), expected)

# Floats
a = pl.Series("a", [-9.0, -0.0, 0.0, 4.0, None])
expected = pl.Series("a", [-1, 0, 0, 1, None])
a = pl.Series("a", [-9.0, -0.0, 0.0, 4.0, float("nan"), None])
expected = pl.Series("a", [-1.0, 0.0, 0.0, 1.0, float("nan"), None])
assert_series_equal(a.sign(), expected)

# Invalid input
Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/unit/sql/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def test_div() -> None:
[
[-0.0995024875621891, 2.85714285714286, 12.0, None, -15.92356687898089],
[-1, 2, 12, None, -16],
[-1, 1, 1, None, -1],
[-1.0, 1.0, 1.0, None, -1.0],
],
schema=["a_div_b", "a_floordiv_b", "b_sign"],
),
Expand Down

0 comments on commit b3172aa

Please sign in to comment.