Skip to content

[ENH]Type hints/forecasting #2737

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Apr 25, 2025
2 changes: 1 addition & 1 deletion aeon/base/_base_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def _preprocess_series(self, X, axis, store_metadata):
self.metadata_ = meta
return self._convert_X(X, axis)

def _check_X(self, X, axis):
def _check_X(self, X, axis: int = 0):
"""Check input X is valid.

Check if the input data is a compatible type, and that this estimator is
Expand Down
83 changes: 45 additions & 38 deletions aeon/forecasting/_ets.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,15 +83,15 @@ class ETSForecaster(BaseForecaster):

def __init__(
self,
error_type=ADDITIVE,
trend_type=NONE,
seasonality_type=NONE,
seasonal_period=1,
alpha=0.1,
beta=0.01,
gamma=0.01,
phi=0.99,
horizon=1,
error_type: int = ADDITIVE,
trend_type: int = NONE,
seasonality_type: int = NONE,
seasonal_period: int = 1,
alpha: float = 0.1,
beta: float = 0.01,
gamma: float = 0.01,
phi: float = 0.99,
horizon: int = 1,
):
self.error_type = error_type
self.trend_type = trend_type
Expand Down Expand Up @@ -190,14 +190,14 @@ def _predict(self, y=None, exog=None):
@njit(nogil=NOGIL, cache=CACHE)
def _fit_numba(
data,
error_type,
trend_type,
seasonality_type,
seasonal_period,
alpha,
beta,
gamma,
phi,
error_type: int,
trend_type: int,
seasonality_type: int,
seasonal_period: int,
alpha: float,
beta: float,
gamma: float,
phi: float,
):
n_timepoints = len(data)
level, trend, seasonality = _initialise(
Expand Down Expand Up @@ -236,15 +236,15 @@ def _fit_numba(


def _predict_numba(
trend_type,
seasonality_type,
level,
trend,
seasonality,
phi,
horizon,
n_timepoints,
seasonal_period,
trend_type: int,
seasonality_type: int,
level: float,
trend: float,
seasonality: float,
phi: float,
horizon: int,
n_timepoints: int,
seasonal_period: int,
):
# Generate forecasts based on the final values of level, trend, and seasonals
if phi == 1: # No damping case
Expand All @@ -264,7 +264,7 @@ def _predict_numba(


@njit(nogil=NOGIL, cache=CACHE)
def _initialise(trend_type, seasonality_type, seasonal_period, data):
def _initialise(trend_type: int, seasonality_type: int, seasonal_period: int, data):
"""
Initialize level, trend, and seasonality values for the ETS model.

Expand Down Expand Up @@ -307,17 +307,17 @@ def _initialise(trend_type, seasonality_type, seasonal_period, data):

@njit(nogil=NOGIL, cache=CACHE)
def _update_states(
error_type,
trend_type,
seasonality_type,
level,
trend,
seasonality,
error_type: int,
trend_type: int,
seasonality_type: int,
level: float,
trend: float,
seasonality: float,
data_item: int,
alpha,
beta,
gamma,
phi,
alpha: float,
beta: float,
gamma: float,
phi: float,
):
"""
Update level, trend, and seasonality components.
Expand Down Expand Up @@ -374,7 +374,14 @@ def _update_states(


@njit(nogil=NOGIL, cache=CACHE)
def _predict_value(trend_type, seasonality_type, level, trend, seasonality, phi):
def _predict_value(
trend_type: int,
seasonality_type: int,
level: float,
trend: float,
seasonality: float,
phi: float,
):
"""

Generate various useful values, including the next fitted value.
Expand Down
4 changes: 2 additions & 2 deletions aeon/forecasting/_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class RegressionForecaster(BaseForecaster):
with sklearn regressors.
"""

def __init__(self, window, horizon=1, regressor=None):
def __init__(self, window: int, horizon: int = 1, regressor=None):
self.window = window
self.regressor = regressor
super().__init__(horizon=horizon, axis=1)
Expand Down Expand Up @@ -123,7 +123,7 @@ def _forecast(self, y, exog=None):
return self.predict()

@classmethod
def _get_test_params(cls, parameter_set="default"):
def _get_test_params(cls, parameter_set: str = "default"):
"""Return testing parameter settings for the estimator.

Parameters
Expand Down
2 changes: 1 addition & 1 deletion aeon/forecasting/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class BaseForecaster(BaseSeriesEstimator):
"y_inner_type": "np.ndarray",
}

def __init__(self, horizon, axis):
def __init__(self, horizon: int, axis: int):
self.horizon = horizon
self.meta_ = None # Meta data related to y on the last fit
super().__init__(axis)
Expand Down