2
2
3
3
import logging
4
4
import math
5
+ from pathlib import Path
5
6
from typing import List , Mapping , Optional , Sequence , Union
6
7
7
8
import numpy as np
8
9
import polars as pl
9
10
from polars .type_aliases import ClosedInterval
10
- from polars .utils .udfs import _get_shared_lib_location
11
11
12
12
# from numpy.linalg import lstsq
13
13
from scipy .linalg import lstsq
14
14
from scipy .signal import find_peaks_cwt , ricker , welch
15
15
from scipy .spatial import KDTree
16
16
17
+ from functime ._compat import register_plugin_function , rle_fields
17
18
from functime ._functime_rust import rs_faer_lstsq1
18
19
from functime ._utils import warn_is_unstable
19
20
from functime .type_aliases import DetrendMethod
34
35
35
36
# from polars.type_aliases import IntoExpr
36
37
37
- lib = _get_shared_lib_location (__file__ )
38
+ try :
39
+ from polars .utils .udfs import _get_shared_lib_location
40
+
41
+ lib = _get_shared_lib_location (__file__ )
42
+ except ImportError :
43
+ lib = Path (__file__ ).parent
38
44
39
45
40
46
def absolute_energy (x : TIME_SERIES_T ) -> FLOAT_INT_EXPR :
@@ -995,12 +1001,16 @@ def longest_streak_above_mean(x: TIME_SERIES_T) -> INT_EXPR:
995
1001
"""
996
1002
y = (x > x .mean ()).rle ()
997
1003
if isinstance (x , pl .Series ):
998
- result = y .filter (y .struct .field ("values" )).struct .field ("lengths" ).max ()
1004
+ result = (
1005
+ y .filter (y .struct .field (rle_fields ["value" ]))
1006
+ .struct .field (rle_fields ["len" ])
1007
+ .max ()
1008
+ )
999
1009
return 0 if result is None else result
1000
1010
else :
1001
1011
return (
1002
- y .filter (y .struct .field ("values" ))
1003
- .struct .field ("lengths" )
1012
+ y .filter (y .struct .field (rle_fields [ "value" ] ))
1013
+ .struct .field (rle_fields [ "len" ] )
1004
1014
.max ()
1005
1015
.fill_null (0 )
1006
1016
)
@@ -1024,12 +1034,16 @@ def longest_streak_below_mean(x: TIME_SERIES_T) -> INT_EXPR:
1024
1034
"""
1025
1035
y = (x < x .mean ()).rle ()
1026
1036
if isinstance (x , pl .Series ):
1027
- result = y .filter (y .struct .field ("values" )).struct .field ("lengths" ).max ()
1037
+ result = (
1038
+ y .filter (y .struct .field (rle_fields ["value" ]))
1039
+ .struct .field (rle_fields ["len" ])
1040
+ .max ()
1041
+ )
1028
1042
return 0 if result is None else result
1029
1043
else :
1030
1044
return (
1031
- y .filter (y .struct .field ("values" ))
1032
- .struct .field ("lengths" )
1045
+ y .filter (y .struct .field (rle_fields [ "value" ] ))
1046
+ .struct .field (rle_fields [ "len" ] )
1033
1047
.max ()
1034
1048
.fill_null (0 )
1035
1049
)
@@ -1752,7 +1766,7 @@ def streak_length_stats(x: TIME_SERIES_T, above: bool, threshold: float) -> MAP_
1752
1766
else :
1753
1767
y = (x .diff () <= threshold ).rle ()
1754
1768
1755
- y = y .filter (y .struct .field ("values" )).struct .field ("lengths" )
1769
+ y = y .filter (y .struct .field (rle_fields [ "value" ] )).struct .field (rle_fields [ "len" ] )
1756
1770
if isinstance (x , pl .Series ):
1757
1771
return {
1758
1772
"min" : y .min () or 0 ,
@@ -1797,12 +1811,16 @@ def longest_streak_above(x: TIME_SERIES_T, threshold: float) -> TIME_SERIES_T:
1797
1811
1798
1812
y = (x .diff () >= threshold ).rle ()
1799
1813
if isinstance (x , pl .Series ):
1800
- streak_max = y .filter (y .struct .field ("values" )).struct .field ("lengths" ).max ()
1814
+ streak_max = (
1815
+ y .filter (y .struct .field (rle_fields ["value" ]))
1816
+ .struct .field (rle_fields ["len" ])
1817
+ .max ()
1818
+ )
1801
1819
return 0 if streak_max is None else streak_max
1802
1820
else :
1803
1821
return (
1804
- y .filter (y .struct .field ("values" ))
1805
- .struct .field ("lengths" )
1822
+ y .filter (y .struct .field (rle_fields [ "value" ] ))
1823
+ .struct .field (rle_fields [ "len" ] )
1806
1824
.max ()
1807
1825
.fill_null (0 )
1808
1826
)
@@ -1827,12 +1845,16 @@ def longest_streak_below(x: TIME_SERIES_T, threshold: float) -> TIME_SERIES_T:
1827
1845
"""
1828
1846
y = (x .diff () <= threshold ).rle ()
1829
1847
if isinstance (x , pl .Series ):
1830
- streak_max = y .filter (y .struct .field ("values" )).struct .field ("lengths" ).max ()
1848
+ streak_max = (
1849
+ y .filter (y .struct .field (rle_fields ["value" ]))
1850
+ .struct .field (rle_fields ["len" ])
1851
+ .max ()
1852
+ )
1831
1853
return 0 if streak_max is None else streak_max
1832
1854
else :
1833
1855
return (
1834
- y .filter (y .struct .field ("values" ))
1835
- .struct .field ("lengths" )
1856
+ y .filter (y .struct .field (rle_fields [ "value" ] ))
1857
+ .struct .field (rle_fields [ "len" ] )
1836
1858
.max ()
1837
1859
.fill_null (0 )
1838
1860
)
@@ -2255,9 +2277,10 @@ def lempel_ziv_complexity(
2255
2277
https://github.com/Naereen/Lempel-Ziv_Complexity/tree/master
2256
2278
https://en.wikipedia.org/wiki/Lempel%E2%80%93Ziv_complexity
2257
2279
"""
2258
- out = (self ._expr > threshold ).register_plugin (
2259
- lib = lib ,
2260
- symbol = "pl_lempel_ziv_complexity" ,
2280
+ out = register_plugin_function (
2281
+ args = [self ._expr > threshold ],
2282
+ plugin_path = lib ,
2283
+ function_name = "pl_lempel_ziv_complexity" ,
2261
2284
is_elementwise = False ,
2262
2285
returns_scalar = True ,
2263
2286
)
@@ -2766,16 +2789,17 @@ def cusum(
2766
2789
-------
2767
2790
An expression of the output
2768
2791
"""
2769
- return self ._expr .register_plugin (
2770
- lib = lib ,
2771
- symbol = "cusum" ,
2792
+ return register_plugin_function (
2793
+ args = [self ._expr ],
2794
+ plugin_path = lib ,
2795
+ function_name = "cusum" ,
2772
2796
kwargs = {
2773
2797
"threshold" : threshold ,
2774
2798
"drift" : drift ,
2775
2799
"warmup_period" : warmup_period ,
2776
2800
},
2777
2801
is_elementwise = False ,
2778
- cast_to_supertypes = True ,
2802
+ cast_to_supertype = True ,
2779
2803
)
2780
2804
2781
2805
def frac_diff (
@@ -2815,14 +2839,15 @@ def frac_diff(
2815
2839
if min_weight is None and window_size is None :
2816
2840
raise ValueError ("Either min_weight or window_size must be specified." )
2817
2841
2818
- return self ._expr .register_plugin (
2819
- lib = lib ,
2820
- symbol = "frac_diff" ,
2842
+ return register_plugin_function (
2843
+ args = [self ._expr ],
2844
+ plugin_path = lib ,
2845
+ function_name = "frac_diff" ,
2821
2846
kwargs = {
2822
2847
"d" : d ,
2823
2848
"min_weight" : min_weight ,
2824
2849
"window_size" : window_size ,
2825
2850
},
2826
2851
is_elementwise = False ,
2827
- cast_to_supertypes = True ,
2852
+ cast_to_supertype = True ,
2828
2853
)
0 commit comments