@@ -83,15 +83,15 @@ class ETSForecaster(BaseForecaster):
83
83
84
84
def __init__ (
85
85
self ,
86
- error_type = ADDITIVE ,
87
- trend_type = NONE ,
88
- seasonality_type = NONE ,
89
- seasonal_period = 1 ,
86
+ error_type : int = ADDITIVE ,
87
+ trend_type : int = NONE ,
88
+ seasonality_type : int = NONE ,
89
+ seasonal_period : int = 1 ,
90
90
alpha : float = 0.1 ,
91
- beta = 0.01 ,
92
- gamma = 0.01 ,
93
- phi = 0.99 ,
94
- horizon = 1 ,
91
+ beta : float = 0.01 ,
92
+ gamma : float = 0.01 ,
93
+ phi : float = 0.99 ,
94
+ horizon : int = 1 ,
95
95
):
96
96
self .error_type = error_type
97
97
self .trend_type = trend_type
@@ -190,14 +190,14 @@ def _predict(self, y=None, exog=None):
190
190
@njit (nogil = NOGIL , cache = CACHE )
191
191
def _fit_numba (
192
192
data ,
193
- error_type ,
194
- trend_type ,
195
- seasonality_type ,
196
- seasonal_period ,
197
- alpha ,
198
- beta ,
199
- gamma ,
200
- phi ,
193
+ error_type : int = ADDITIVE ,
194
+ trend_type : int = NONE ,
195
+ seasonality_type : int = NONE ,
196
+ seasonal_period : int = 1 ,
197
+ alpha : float = 0.1 ,
198
+ beta : float = 0.01 ,
199
+ gamma : float = 0.01 ,
200
+ phi : float = 0.99 ,
201
201
):
202
202
n_timepoints = len (data )
203
203
level , trend , seasonality = _initialise (
@@ -236,15 +236,15 @@ def _fit_numba(
236
236
237
237
238
238
def _predict_numba (
239
- trend_type ,
240
- seasonality_type ,
241
- level ,
242
- trend ,
243
- seasonality ,
244
- phi ,
245
- horizon ,
246
- n_timepoints ,
247
- seasonal_period ,
239
+ trend_type : int = ADDITIVE ,
240
+ seasonality_type : int = NONE ,
241
+ level : float = 0.1 ,
242
+ seasonality : float = 0.05 ,
243
+ phi : float = 0.99 ,
244
+ horizon : int = 1 ,
245
+ n_timepoints : int = 5 ,
246
+ seasonal_period : int = 1 ,
247
+ trend : float = 0.01 ,
248
248
):
249
249
# Generate forecasts based on the final values of level, trend, and seasonals
250
250
if phi == 1 : # No damping case
@@ -264,7 +264,7 @@ def _predict_numba(
264
264
265
265
266
266
@njit (nogil = NOGIL , cache = CACHE )
267
- def _initialise (trend_type , seasonality_type , seasonal_period , data ):
267
+ def _initialise (trend_type : int , seasonality_type : int , seasonal_period : int , data ):
268
268
"""
269
269
Initialize level, trend, and seasonality values for the ETS model.
270
270
@@ -307,17 +307,17 @@ def _initialise(trend_type, seasonality_type, seasonal_period, data):
307
307
308
308
@njit (nogil = NOGIL , cache = CACHE )
309
309
def _update_states (
310
- error_type ,
311
- trend_type ,
312
- seasonality_type ,
313
- level ,
314
- trend ,
315
- seasonality ,
310
+ error_type : int ,
311
+ trend_type : int ,
312
+ seasonality_type : int ,
313
+ level : float ,
314
+ trend : float ,
315
+ seasonality : float ,
316
316
data_item : int ,
317
- alpha ,
318
- beta ,
319
- gamma ,
320
- phi ,
317
+ alpha : float ,
318
+ beta : float ,
319
+ gamma : float ,
320
+ phi : float ,
321
321
):
322
322
"""
323
323
Update level, trend, and seasonality components.
@@ -374,7 +374,14 @@ def _update_states(
374
374
375
375
376
376
@njit (nogil = NOGIL , cache = CACHE )
377
- def _predict_value (trend_type , seasonality_type , level , trend , seasonality , phi ):
377
+ def _predict_value (
378
+ trend_type : int ,
379
+ seasonality_type : int ,
380
+ level : float ,
381
+ trend : float ,
382
+ seasonality : float ,
383
+ phi : float ,
384
+ ):
378
385
"""
379
386
380
387
Generate various useful values, including the next fitted value.
0 commit comments