From efa92f90bd23069ac2286e161d6c3b7541eb3915 Mon Sep 17 00:00:00 2001 From: Marco Edward Gorelli Date: Sun, 21 Apr 2024 10:31:32 +0100 Subject: [PATCH] fix: ewm_mean_by was skipping initial nulls when it was already sorted by "by" column (#15812) --- crates/polars-ops/src/series/ops/ewm_by.rs | 4 +++- py-polars/tests/unit/operations/test_ewm_by.py | 5 ++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/crates/polars-ops/src/series/ops/ewm_by.rs b/crates/polars-ops/src/series/ops/ewm_by.rs index f47e03239ae3..a14947467f52 100644 --- a/crates/polars-ops/src/series/ops/ewm_by.rs +++ b/crates/polars-ops/src/series/ops/ewm_by.rs @@ -129,7 +129,9 @@ where out.push(Some(prev_result)); skip_rows = idx + 1; break; - }; + } else { + out.push(None) + } } values .iter() diff --git a/py-polars/tests/unit/operations/test_ewm_by.py b/py-polars/tests/unit/operations/test_ewm_by.py index fcd87fd83f5d..aaace3e67cf4 100644 --- a/py-polars/tests/unit/operations/test_ewm_by.py +++ b/py-polars/tests/unit/operations/test_ewm_by.py @@ -12,7 +12,8 @@ from polars.type_aliases import PolarsIntegerType, TimeUnit -def test_ewma_by_date() -> None: +@pytest.mark.parametrize("sort", [True, False]) +def test_ewma_by_date(sort: bool) -> None: df = pl.LazyFrame( { "values": [3.0, 1.0, 2.0, None, 4.0], @@ -25,6 +26,8 @@ def test_ewma_by_date() -> None: ], } ) + if sort: + df = df.sort("times") result = df.select( pl.col("values").ewm_mean_by("times", half_life=timedelta(days=2)), )