Skip to content

Commit 5b9b85c

Browse files
fix: Polars >=1.0 fixes, kindly provided by tjader in #256
1 parent ef23263 commit 5b9b85c

File tree

8 files changed

+149
-88
lines changed

8 files changed

+149
-88
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,4 @@ faer-ext = {version = "0.1.0", features = ["ndarray"]}
1717
ndarray = "0.15.6"
1818
serde = {version = "*", features=["derive"]}
1919
hashbrown = {version = "0.14.2", features=["nightly"]}
20-
numpy = "*"
20+
numpy = "*"

functime/_compat.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import polars as pl
2+
from typing import Any, Iterable
3+
from pathlib import Path
4+
try:
5+
from polars.plugins import register_plugin_function
6+
except ImportError:
7+
8+
def register_plugin_function(
    *,
    plugin_path: Path | str,
    function_name: str,
    args: "IntoExpr | Iterable[IntoExpr]",
    kwargs: dict[str, Any] | None = None,
    is_elementwise: bool = False,
    returns_scalar: bool = False,
    cast_to_supertype: bool = False,
):
    """Compatibility shim for ``polars.plugins.register_plugin_function``.

    Used when running under polars < 1.0, where the module-level
    ``polars.plugins.register_plugin_function`` does not exist.  It maps the
    new keyword-only API onto the legacy ``Expr.register_plugin`` method.

    Parameters
    ----------
    plugin_path : Path | str
        Location of the compiled plugin shared library (passed as ``lib``).
    function_name : str
        Symbol to call inside the plugin (passed as ``symbol``).
    args : IntoExpr | Iterable[IntoExpr]
        Expressions for the plugin; the FIRST element is the expression the
        legacy method is invoked on, the rest are forwarded as ``args``.
    kwargs : dict[str, Any] | None
        Extra keyword arguments serialized for the plugin function.
    is_elementwise, returns_scalar, cast_to_supertype : bool
        Flags forwarded to the legacy API (note the old spelling
        ``cast_to_supertypes`` on the receiving side).

    Returns
    -------
    The expression produced by ``Expr.register_plugin``.

    Notes
    -----
    Bug fix: the previous version dropped the result of
    ``expr.register_plugin(...)`` and implicitly returned ``None``, so every
    caller doing ``return register_plugin_function(...)`` got ``None`` on
    polars < 1.0.  The result is now returned.
    """
    expr = args[0]
    remaining_args = args[1:]
    return expr.register_plugin(
        lib=plugin_path,
        args=remaining_args,
        symbol=function_name,
        is_elementwise=is_elementwise,
        returns_scalar=returns_scalar,
        kwargs=kwargs,
        cast_to_supertypes=cast_to_supertype,
    )

functime/feature_extractors.py

Lines changed: 51 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,19 @@
22

33
import logging
44
import math
5+
from pathlib import Path
56
from typing import List, Mapping, Optional, Sequence, Union
67

78
import numpy as np
89
import polars as pl
910
from polars.type_aliases import ClosedInterval
10-
from polars.utils.udfs import _get_shared_lib_location
1111

1212
# from numpy.linalg import lstsq
1313
from scipy.linalg import lstsq
1414
from scipy.signal import find_peaks_cwt, ricker, welch
1515
from scipy.spatial import KDTree
1616

17+
from functime._compat import register_plugin_function, rle_fields
1718
from functime._functime_rust import rs_faer_lstsq1
1819
from functime._utils import warn_is_unstable
1920
from functime.type_aliases import DetrendMethod
@@ -34,7 +35,12 @@
3435

3536
# from polars.type_aliases import IntoExpr
3637

37-
lib = _get_shared_lib_location(__file__)
38+
try:
39+
from polars.utils.udfs import _get_shared_lib_location
40+
41+
lib = _get_shared_lib_location(__file__)
42+
except ImportError:
43+
lib = Path(__file__).parent
3844

3945

4046
def absolute_energy(x: TIME_SERIES_T) -> FLOAT_INT_EXPR:
@@ -995,12 +1001,16 @@ def longest_streak_above_mean(x: TIME_SERIES_T) -> INT_EXPR:
9951001
"""
9961002
y = (x > x.mean()).rle()
9971003
if isinstance(x, pl.Series):
998-
result = y.filter(y.struct.field("values")).struct.field("lengths").max()
1004+
result = (
1005+
y.filter(y.struct.field(rle_fields["value"]))
1006+
.struct.field(rle_fields["len"])
1007+
.max()
1008+
)
9991009
return 0 if result is None else result
10001010
else:
10011011
return (
1002-
y.filter(y.struct.field("values"))
1003-
.struct.field("lengths")
1012+
y.filter(y.struct.field(rle_fields["value"]))
1013+
.struct.field(rle_fields["len"])
10041014
.max()
10051015
.fill_null(0)
10061016
)
@@ -1024,12 +1034,16 @@ def longest_streak_below_mean(x: TIME_SERIES_T) -> INT_EXPR:
10241034
"""
10251035
y = (x < x.mean()).rle()
10261036
if isinstance(x, pl.Series):
1027-
result = y.filter(y.struct.field("values")).struct.field("lengths").max()
1037+
result = (
1038+
y.filter(y.struct.field(rle_fields["value"]))
1039+
.struct.field(rle_fields["len"])
1040+
.max()
1041+
)
10281042
return 0 if result is None else result
10291043
else:
10301044
return (
1031-
y.filter(y.struct.field("values"))
1032-
.struct.field("lengths")
1045+
y.filter(y.struct.field(rle_fields["value"]))
1046+
.struct.field(rle_fields["len"])
10331047
.max()
10341048
.fill_null(0)
10351049
)
@@ -1752,7 +1766,7 @@ def streak_length_stats(x: TIME_SERIES_T, above: bool, threshold: float) -> MAP_
17521766
else:
17531767
y = (x.diff() <= threshold).rle()
17541768

1755-
y = y.filter(y.struct.field("values")).struct.field("lengths")
1769+
y = y.filter(y.struct.field(rle_fields["value"])).struct.field(rle_fields["len"])
17561770
if isinstance(x, pl.Series):
17571771
return {
17581772
"min": y.min() or 0,
@@ -1797,12 +1811,16 @@ def longest_streak_above(x: TIME_SERIES_T, threshold: float) -> TIME_SERIES_T:
17971811

17981812
y = (x.diff() >= threshold).rle()
17991813
if isinstance(x, pl.Series):
1800-
streak_max = y.filter(y.struct.field("values")).struct.field("lengths").max()
1814+
streak_max = (
1815+
y.filter(y.struct.field(rle_fields["value"]))
1816+
.struct.field(rle_fields["len"])
1817+
.max()
1818+
)
18011819
return 0 if streak_max is None else streak_max
18021820
else:
18031821
return (
1804-
y.filter(y.struct.field("values"))
1805-
.struct.field("lengths")
1822+
y.filter(y.struct.field(rle_fields["value"]))
1823+
.struct.field(rle_fields["len"])
18061824
.max()
18071825
.fill_null(0)
18081826
)
@@ -1827,12 +1845,16 @@ def longest_streak_below(x: TIME_SERIES_T, threshold: float) -> TIME_SERIES_T:
18271845
"""
18281846
y = (x.diff() <= threshold).rle()
18291847
if isinstance(x, pl.Series):
1830-
streak_max = y.filter(y.struct.field("values")).struct.field("lengths").max()
1848+
streak_max = (
1849+
y.filter(y.struct.field(rle_fields["value"]))
1850+
.struct.field(rle_fields["len"])
1851+
.max()
1852+
)
18311853
return 0 if streak_max is None else streak_max
18321854
else:
18331855
return (
1834-
y.filter(y.struct.field("values"))
1835-
.struct.field("lengths")
1856+
y.filter(y.struct.field(rle_fields["value"]))
1857+
.struct.field(rle_fields["len"])
18361858
.max()
18371859
.fill_null(0)
18381860
)
@@ -2255,9 +2277,10 @@ def lempel_ziv_complexity(
22552277
https://github.com/Naereen/Lempel-Ziv_Complexity/tree/master
22562278
https://en.wikipedia.org/wiki/Lempel%E2%80%93Ziv_complexity
22572279
"""
2258-
out = (self._expr > threshold).register_plugin(
2259-
lib=lib,
2260-
symbol="pl_lempel_ziv_complexity",
2280+
out = register_plugin_function(
2281+
args=[self._expr > threshold],
2282+
plugin_path=lib,
2283+
function_name="pl_lempel_ziv_complexity",
22612284
is_elementwise=False,
22622285
returns_scalar=True,
22632286
)
@@ -2766,16 +2789,17 @@ def cusum(
27662789
-------
27672790
An expression of the output
27682791
"""
2769-
return self._expr.register_plugin(
2770-
lib=lib,
2771-
symbol="cusum",
2792+
return register_plugin_function(
2793+
args=[self._expr],
2794+
plugin_path=lib,
2795+
function_name="cusum",
27722796
kwargs={
27732797
"threshold": threshold,
27742798
"drift": drift,
27752799
"warmup_period": warmup_period,
27762800
},
27772801
is_elementwise=False,
2778-
cast_to_supertypes=True,
2802+
cast_to_supertype=True,
27792803
)
27802804

27812805
def frac_diff(
@@ -2815,14 +2839,15 @@ def frac_diff(
28152839
if min_weight is None and window_size is None:
28162840
raise ValueError("Either min_weight or window_size must be specified.")
28172841

2818-
return self._expr.register_plugin(
2819-
lib=lib,
2820-
symbol="frac_diff",
2842+
return register_plugin_function(
2843+
args=[self._expr],
2844+
plugin_path=lib,
2845+
function_name="frac_diff",
28212846
kwargs={
28222847
"d": d,
28232848
"min_weight": min_weight,
28242849
"window_size": window_size,
28252850
},
28262851
is_elementwise=False,
2827-
cast_to_supertypes=True,
2852+
cast_to_supertype=True,
28282853
)

functime/forecasting/snaive.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,7 @@ def _fit(self, y: pl.LazyFrame, X: Optional[pl.LazyFrame] = None):
3030
sp = self.sp
3131
# BUG: Cannot run the following in lazy streaming mode?
3232
# Causes internal error: entered unreachable code
33-
y_pred = (
34-
y.sort(idx_cols)
35-
.set_sorted(idx_cols)
36-
.group_by(entity_col)
37-
.agg(pl.col(target_col).tail(sp))
38-
)
33+
y_pred = y.sort(idx_cols).group_by(entity_col).agg(pl.col(target_col).tail(sp))
3934
artifacts = {"y_pred": y_pred}
4035
return artifacts
4136

functime/preprocessing.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -602,9 +602,10 @@ def optimizer(fun):
602602
lmbds = gb.agg(
603603
PL_NUMERIC_COLS(entity_col, time_col)
604604
.map_elements(
605-
lambda x: boxcox_normmax(x, method=method, optimizer=optimizer)
605+
lambda x: boxcox_normmax(x, method=method, optimizer=optimizer),
606+
returns_scalar=True,
607+
return_dtype=pl.Float64,
606608
)
607-
.cast(pl.Float64())
608609
.name.suffix("__lmbd")
609610
)
610611
# Step 2. Transform
@@ -667,6 +668,7 @@ def transform(X: pl.LazyFrame) -> pl.LazyFrame:
667668
PL_NUMERIC_COLS(entity_col, time_col)
668669
.map_elements(
669670
lambda x: yeojohnson_normmax(x.to_numpy(), brack),
671+
returns_scalar=True,
670672
return_dtype=pl.Float64,
671673
)
672674
.name.suffix("__lmbd")

tests/conftest.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -87,18 +87,12 @@ def m4_dataset(request):
8787
def load_panel_data(path: str) -> pl.DataFrame:
    """Load an M4 panel parquet file into a normalized, sorted frame.

    Normalization applied to the raw parquet:
    - ``series``: strip the first space and cast to Categorical (entity id).
    - ``time``: cast to Int16 (small integer time index).
    - every other column: cast to Float32 (target values).
    The frame is sorted by ``["series", "time"]`` so downstream code can rely
    on per-entity time ordering.

    Note: ``pl.read_parquet`` is eager, so this returns a ``pl.DataFrame``;
    the previous ``-> pl.LazyFrame`` annotation was inaccurate.
    """
    return (
        pl.read_parquet(path)
        .with_columns(
            pl.col("series").str.replace(" ", "").cast(pl.Categorical),
            pl.col("time").cast(pl.Int16),
            pl.all().exclude(["series", "time"]).cast(pl.Float32),
        )
        .sort(["series", "time"])
    )
10397

10498
def update_test_time_ranges(y_train, y_test):

tests/test_fourier.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,25 @@
1010
from functime.seasonality import add_fourier_terms
1111

1212

13-
@pytest.mark.parametrize("freq,sp", [("1h", 24), ("1d", 365), ("1w", 52)])
14-
def test_fourier_with_dates(freq: str, sp: int):
15-
timestamps = pl.date_range(
16-
date(2020, 1, 1), date(2021, 1, 1), interval=freq, eager=True
17-
)
13+
@pytest.mark.parametrize(
14+
"freq,sp, use_date",
15+
[
16+
("1h", 24, False),
17+
("1d", 365, False),
18+
("1w", 52, False),
19+
("1d", 365, True),
20+
("1w", 52, True),
21+
],
22+
)
23+
def test_fourier_with_timestamps(freq: str, sp: int, use_date: bool):
24+
if use_date:
25+
timestamps = pl.date_range(
26+
date(2020, 1, 1), date(2021, 1, 1), interval=freq, eager=True
27+
)
28+
else:
29+
timestamps = pl.datetime_range(
30+
date(2020, 1, 1), date(2021, 1, 1), interval=freq, eager=True
31+
)
1832
n_timestamps = len(timestamps)
1933
idx_timestamps = timestamps.arg_sort() + 1
2034
entities = pl.concat(

0 commit comments

Comments
 (0)