feat: Allow float in interpolate_by by column (#18015)
agossard authored Aug 18, 2024
1 parent a284174 commit 49747c1
Showing 2 changed files with 81 additions and 24 deletions.
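A rough sketch of what this change enables (hypothetical data, not taken from the changed files; previously the `by` column had to be Date, Datetime, or an integer type):

import polars as pl

df = pl.DataFrame(
    {
        "x": [0.2, 0.4, 0.5, 0.6],          # float `by` column, now accepted
        "values": [1.0, None, None, 5.0],
    }
)
out = df.select(pl.col("values").interpolate_by("x"))
# linear interpolation weighted by the float positions: [1.0, 3.0, 4.0, 5.0]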
20 changes: 16 additions & 4 deletions crates/polars-ops/src/series/ops/interpolation/interpolate_by.rs
@@ -87,7 +87,7 @@ fn interpolate_impl_by_sorted<T, F, I>(
) -> PolarsResult<ChunkedArray<T>>
where
T: PolarsNumericType,
F: PolarsIntegerType,
F: PolarsNumericType,
I: Fn(T::Native, T::Native, &[F::Native], &mut Vec<T::Native>),
{
// This implementation differs from pandas in that boundary None's are not removed.
@@ -169,7 +169,7 @@ fn interpolate_impl_by<T, F, I>(
) -> PolarsResult<ChunkedArray<T>>
where
T: PolarsNumericType,
F: PolarsIntegerType,
F: PolarsNumericType,
I: Fn(T::Native, T::Native, &[F::Native], &mut [T::Native], &[IdxSize]),
{
// This implementation differs from pandas in that boundary None's are not removed.
@@ -273,7 +273,7 @@ pub fn interpolate_by(s: &Series, by: &Series, by_is_sorted: bool) -> PolarsResu
) -> PolarsResult<Series>
where
T: PolarsNumericType,
F: PolarsIntegerType,
F: PolarsNumericType,
ChunkedArray<T>: IntoSeries,
{
if is_sorted {
@@ -290,6 +290,18 @@ pub fn interpolate_by(s: &Series, by: &Series, by_is_sorted: bool) -> PolarsResu
}

match (s.dtype(), by.dtype()) {
(DataType::Float64, DataType::Float64) => {
func(s.f64().unwrap(), by.f64().unwrap(), by_is_sorted)
},
(DataType::Float64, DataType::Float32) => {
func(s.f64().unwrap(), by.f32().unwrap(), by_is_sorted)
},
(DataType::Float32, DataType::Float64) => {
func(s.f32().unwrap(), by.f64().unwrap(), by_is_sorted)
},
(DataType::Float32, DataType::Float32) => {
func(s.f32().unwrap(), by.f32().unwrap(), by_is_sorted)
},
(DataType::Float64, DataType::Int64) => {
func(s.f64().unwrap(), by.i64().unwrap(), by_is_sorted)
},
@@ -326,7 +338,7 @@ pub fn interpolate_by(s: &Series, by: &Series, by_is_sorted: bool) -> PolarsResu
_ => {
polars_bail!(InvalidOperation: "expected series to be Float64, Float32, \
Int64, Int32, UInt64, UInt32, and `by` to be Date, Datetime, Int64, Int32, \
UInt64, or UInt32")
UInt64, UInt32, Float32 or Float64")
},
}
}
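A hedged sketch of the mixed-precision combinations the new match arms dispatch on (hypothetical data; a Float32 `by` Series paired with Float64 values goes through the (Float64, Float32) arm added above):

import polars as pl

df = pl.DataFrame(
    {
        "x": pl.Series([0.0, 1.0, 3.0], dtype=pl.Float32),   # `by` as Float32
        "y": pl.Series([0.0, None, 9.0], dtype=pl.Float64),  # values as Float64
    }
)
out = df.select(pl.col("y").interpolate_by("x"))
# expected: [0.0, 3.0, 9.0]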
85 changes: 65 additions & 20 deletions py-polars/tests/unit/operations/test_interpolate_by.py
@@ -28,6 +28,8 @@
pl.Int32,
pl.UInt64,
pl.UInt32,
pl.Float32,
pl.Float64,
],
)
@pytest.mark.parametrize(
@@ -116,22 +118,42 @@ def test_interpolate_by_leading_nulls() -> None:
assert_frame_equal(result, expected)


def test_interpolate_by_trailing_nulls() -> None:
df = pl.DataFrame(
{
"times": [
date(2020, 1, 1),
date(2020, 1, 3),
date(2020, 1, 10),
date(2020, 1, 11),
date(2020, 1, 12),
date(2020, 1, 13),
],
"values": [1, None, None, 5, None, None],
}
)
@pytest.mark.parametrize("dataset", ["floats", "dates"])
def test_interpolate_by_trailing_nulls(dataset: str) -> None:
input_data = {
"dates": pl.DataFrame(
{
"times": [
date(2020, 1, 1),
date(2020, 1, 3),
date(2020, 1, 10),
date(2020, 1, 11),
date(2020, 1, 12),
date(2020, 1, 13),
],
"values": [1, None, None, 5, None, None],
}
),
"floats": pl.DataFrame(
{
"times": [0.2, 0.4, 0.5, 0.6, 0.9, 1.1],
"values": [1, None, None, 5, None, None],
}
),
}

expected_data = {
"dates": pl.DataFrame(
{"values": [1.0, 1.7999999999999998, 4.6, 5.0, None, None]}
),
"floats": pl.DataFrame({"values": [1.0, 3.0, 4.0, 5.0, None, None]}),
}

df = input_data[dataset]
expected = expected_data[dataset]

result = df.select(pl.col("values").interpolate_by("times"))
expected = pl.DataFrame({"values": [1.0, 1.7999999999999998, 4.6, 5.0, None, None]})

assert_frame_equal(result, expected)
result = (
df.sort("times", descending=True)
@@ -142,16 +164,28 @@ def test_interpolate_by_trailing_nulls() -> None:
assert_frame_equal(result, expected)

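As a sanity check on the expected values above (worked by hand, not taken from the test):

# "floats" case, interpolating between (0.2, 1.0) and (0.6, 5.0):
#   at 0.4 -> 1.0 + (0.4 - 0.2) / (0.6 - 0.2) * (5.0 - 1.0) = 3.0
#   at 0.5 -> 1.0 + (0.5 - 0.2) / (0.6 - 0.2) * (5.0 - 1.0) = 4.0
# "dates" case, interpolating between (2020-01-01, 1.0) and (2020-01-11, 5.0):
#   2020-01-03 -> 1.0 + 2/10 * (5.0 - 1.0) ≈ 1.8 (1.7999999999999998 after floating-point rounding)
#   2020-01-10 -> 1.0 + 9/10 * (5.0 - 1.0) = 4.6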

@given(data=st.data())
def test_interpolate_vs_numpy(data: st.DataObject) -> None:
@given(data=st.data(), x_dtype=st.sampled_from([pl.Date, pl.Float64]))
def test_interpolate_vs_numpy(data: st.DataObject, x_dtype: pl.DataType) -> None:
if x_dtype == pl.Float64:
by_strategy = st.floats(
min_value=-1e150,
max_value=1e150,
allow_nan=False,
allow_infinity=False,
allow_subnormal=False,
)
else:
by_strategy = None

dataframe = (
data.draw(
dataframes(
[
column(
"ts",
dtype=pl.Date,
dtype=x_dtype,
allow_null=False,
strategy=by_strategy,
),
column(
"value",
@@ -166,13 +200,24 @@ def test_interpolate_vs_numpy(data: st.DataObject) -> None:
.fill_nan(None)
.unique("ts")
)

if x_dtype == pl.Float64:
assume(not dataframe["ts"].is_nan().any())
assume(not dataframe["ts"].is_null().any())
assume(not dataframe["ts"].is_in([float("-inf"), float("inf")]).any())

assume(not dataframe["value"].is_null().all())
assume(not dataframe["value"].is_in([float("-inf"), float("inf")]).any())

dataframe = dataframe.sort("ts")

result = dataframe.select(pl.col("value").interpolate_by("ts"))["value"]

mask = dataframe["value"].is_not_null()
x = dataframe["ts"].to_numpy().astype("int64")
xp = dataframe["ts"].filter(mask).to_numpy().astype("int64")

np_dtype = "int64" if x_dtype == pl.Date else "float64"
x = dataframe["ts"].to_numpy().astype(np_dtype)
xp = dataframe["ts"].filter(mask).to_numpy().astype(np_dtype)
yp = dataframe["value"].filter(mask).to_numpy().astype("float64")
interp = np.interp(x, xp, yp)
# Polars preserves nulls on boundaries, but NumPy doesn't.

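One behavioural difference the final comment refers to, illustrated with hypothetical data: np.interp clamps query points that fall outside the known range to the nearest known value, while interpolate_by keeps boundary nulls. For ts = [1, 2, 3, 4] and value = [None, 2.0, 4.0, None], NumPy yields [2.0, 2.0, 4.0, 4.0] over the full ts range, whereas interpolate_by returns [None, 2.0, 4.0, None].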