Skip to content

Commit 575122d

Browse files
type hints for primitives in foreacasting module
1 parent 1fa24e5 commit 575122d

File tree

3 files changed

+47
-40
lines changed

3 files changed

+47
-40
lines changed

aeon/forecasting/_ets.py

Lines changed: 44 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -83,15 +83,15 @@ class ETSForecaster(BaseForecaster):
8383

8484
def __init__(
8585
self,
86-
error_type=ADDITIVE,
87-
trend_type=NONE,
88-
seasonality_type=NONE,
89-
seasonal_period=1,
86+
error_type: int = ADDITIVE,
87+
trend_type: int = NONE,
88+
seasonality_type: int = NONE,
89+
seasonal_period: int = 1,
9090
alpha: float = 0.1,
91-
beta=0.01,
92-
gamma=0.01,
93-
phi=0.99,
94-
horizon=1,
91+
beta: float = 0.01,
92+
gamma: float = 0.01,
93+
phi: float = 0.99,
94+
horizon: int = 1,
9595
):
9696
self.error_type = error_type
9797
self.trend_type = trend_type
@@ -190,14 +190,14 @@ def _predict(self, y=None, exog=None):
190190
@njit(nogil=NOGIL, cache=CACHE)
191191
def _fit_numba(
192192
data,
193-
error_type,
194-
trend_type,
195-
seasonality_type,
196-
seasonal_period,
197-
alpha,
198-
beta,
199-
gamma,
200-
phi,
193+
error_type: int = ADDITIVE,
194+
trend_type: int = NONE,
195+
seasonality_type: int = NONE,
196+
seasonal_period: int = 1,
197+
alpha: float = 0.1,
198+
beta: float = 0.01,
199+
gamma: float = 0.01,
200+
phi: float = 0.99,
201201
):
202202
n_timepoints = len(data)
203203
level, trend, seasonality = _initialise(
@@ -236,15 +236,15 @@ def _fit_numba(
236236

237237

238238
def _predict_numba(
239-
trend_type,
240-
seasonality_type,
241-
level,
242-
trend,
243-
seasonality,
244-
phi,
245-
horizon,
246-
n_timepoints,
247-
seasonal_period,
239+
trend_type: int = ADDITIVE,
240+
seasonality_type: int = NONE,
241+
level: float = 0.1,
242+
seasonality: float = 0.05,
243+
phi: float = 0.99,
244+
horizon: int = 1,
245+
n_timepoints: int = 5,
246+
seasonal_period: int = 1,
247+
trend: float = 0.01,
248248
):
249249
# Generate forecasts based on the final values of level, trend, and seasonals
250250
if phi == 1: # No damping case
@@ -264,7 +264,7 @@ def _predict_numba(
264264

265265

266266
@njit(nogil=NOGIL, cache=CACHE)
267-
def _initialise(trend_type, seasonality_type, seasonal_period, data):
267+
def _initialise(trend_type: int, seasonality_type: int, seasonal_period: int, data):
268268
"""
269269
Initialize level, trend, and seasonality values for the ETS model.
270270
@@ -307,17 +307,17 @@ def _initialise(trend_type, seasonality_type, seasonal_period, data):
307307

308308
@njit(nogil=NOGIL, cache=CACHE)
309309
def _update_states(
310-
error_type,
311-
trend_type,
312-
seasonality_type,
313-
level,
314-
trend,
315-
seasonality,
310+
error_type: int,
311+
trend_type: int,
312+
seasonality_type: int,
313+
level: float,
314+
trend: float,
315+
seasonality: float,
316316
data_item: int,
317-
alpha,
318-
beta,
319-
gamma,
320-
phi,
317+
alpha: float,
318+
beta: float,
319+
gamma: float,
320+
phi: float,
321321
):
322322
"""
323323
Update level, trend, and seasonality components.
@@ -374,7 +374,14 @@ def _update_states(
374374

375375

376376
@njit(nogil=NOGIL, cache=CACHE)
377-
def _predict_value(trend_type, seasonality_type, level, trend, seasonality, phi):
377+
def _predict_value(
378+
trend_type: int,
379+
seasonality_type: int,
380+
level: float,
381+
trend: float,
382+
seasonality: float,
383+
phi: float,
384+
):
378385
"""
379386
380387
Generate various useful values, including the next fitted value.

aeon/forecasting/_regression.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class RegressionForecaster(BaseForecaster):
3737
with sklearn regressors.
3838
"""
3939

40-
def __init__(self, window, horizon=1, regressor=None):
40+
def __init__(self, window: int, horizon: int = 1, regressor=None):
4141
self.window = window
4242
self.regressor = regressor
4343
super().__init__(horizon=horizon, axis=1)
@@ -123,7 +123,7 @@ def _forecast(self, y, exog=None):
123123
return self.predict()
124124

125125
@classmethod
126-
def _get_test_params(cls, parameter_set="default"):
126+
def _get_test_params(cls, parameter_set: str = "default"):
127127
"""Return testing parameter settings for the estimator.
128128
129129
Parameters

aeon/forecasting/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ class BaseForecaster(BaseSeriesEstimator):
3636
"y_inner_type": "np.ndarray",
3737
}
3838

39-
def __init__(self, horizon, axis):
39+
def __init__(self, horizon: int, axis: int):
4040
self.horizon = horizon
4141
self.meta_ = None # Meta data related to y on the last fit
4242
super().__init__(axis)

0 commit comments

Comments
 (0)