Skip to content

Commit

Permalink
use TypeVar to get correct output type (#117)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmoralez authored Aug 5, 2024
1 parent 12e7678 commit 8188dab
Show file tree
Hide file tree
Showing 16 changed files with 326 additions and 308 deletions.
5 changes: 4 additions & 1 deletion nbs/compat.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
"#| export\n",
"import warnings\n",
"from functools import wraps\n",
"from typing import Union\n",
"from typing import TypeVar, Union\n",
"\n",
"import pandas as pd"
]
Expand All @@ -34,11 +34,13 @@
"source": [
"#| export\n",
"try:\n",
" import polars\n",
" import polars as pl\n",
" from polars import DataFrame as pl_DataFrame\n",
" from polars import Expr as pl_Expr\n",
" from polars import Series as pl_Series\n",
"\n",
" DFType = TypeVar(\"DFType\", pd.DataFrame, polars.DataFrame)\n",
" POLARS_INSTALLED = True\n",
"except ImportError:\n",
" pl = None\n",
Expand All @@ -52,6 +54,7 @@
" class pl_Series:\n",
" ...\n",
"\n",
" DFType = pd.DataFrame\n",
" POLARS_INSTALLED = False\n",
"\n",
"try:\n",
Expand Down
58 changes: 44 additions & 14 deletions nbs/data.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,12 @@
"outputs": [],
"source": [
"#| export\n",
"from typing import List, Optional\n",
"from typing import List, Literal, Optional, overload\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"from utilsforecast.compat import DataFrame, pl"
"from utilsforecast.compat import DataFrame, pl, pl_DataFrame"
]
},
{
Expand All @@ -66,6 +66,7 @@
"outputs": [],
"source": [
"#| export\n",
"@overload\n",
"def generate_series(\n",
" n_series: int,\n",
" freq: str = 'D',\n",
Expand All @@ -77,7 +78,36 @@
" static_as_categorical: bool = True,\n",
" n_models: int = 0,\n",
" level: Optional[List[float]] = None,\n",
" engine: str = 'pandas',\n",
" engine: Literal['pandas'] = 'pandas',\n",
") -> pd.DataFrame: ...\n",
"\n",
"@overload\n",
"def generate_series(\n",
" n_series: int,\n",
" freq: str = 'D',\n",
" min_length: int = 50,\n",
" max_length: int = 500,\n",
" n_static_features: int = 0,\n",
" equal_ends: bool = False,\n",
" with_trend: bool = False,\n",
" static_as_categorical: bool = True,\n",
" n_models: int = 0,\n",
" level: Optional[List[float]] = None,\n",
" engine: Literal['polars'] = 'polars',\n",
") -> pl_DataFrame: ...\n",
"\n",
"def generate_series(\n",
" n_series: int,\n",
" freq: str = 'D',\n",
" min_length: int = 50,\n",
" max_length: int = 500,\n",
" n_static_features: int = 0,\n",
" equal_ends: bool = False,\n",
" with_trend: bool = False,\n",
" static_as_categorical: bool = True,\n",
" n_models: int = 0,\n",
" level: Optional[List[float]] = None,\n",
" engine: Literal['pandas', 'polars'] = 'pandas',\n",
" seed: int = 0,\n",
") -> DataFrame:\n",
" \"\"\"Generate Synthetic Panel Series.\n",
Expand Down Expand Up @@ -116,7 +146,7 @@
" Synthetic panel with columns [`unique_id`, `ds`, `y`] and exogenous features.\n",
" \"\"\"\n",
" available_engines = ['pandas', 'polars']\n",
" engine = engine.lower()\n",
" engine = engine.lower() # type: ignore\n",
" if engine not in available_engines:\n",
" raise ValueError(\n",
" f\"{engine} is not a correct engine; available options: {available_engines}\"\n",
Expand Down Expand Up @@ -188,18 +218,18 @@
"text/markdown": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/utilsforecast/blob/main/utilsforecast/data.py#L15){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/utilsforecast/blob/main/utilsforecast/data.py#L47){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### generate_series\n",
"\n",
"> generate_series (n_series:int, freq:str='D', min_length:int=50,\n",
"> max_length:int=500, n_static_features:int=0,\n",
"> equal_ends:bool=False, with_trend:bool=False,\n",
"> static_as_categorical:bool=True, n_models:int=0,\n",
"> level:Optional[List[float]]=None, engine:str='pandas',\n",
"> seed:int=0)\n",
"> level:Optional[List[float]]=None,\n",
"> engine:Literal['pandas','polars']='pandas', seed:int=0)\n",
"\n",
"Generate Synthetic Panel Series.\n",
"*Generate Synthetic Panel Series.*\n",
"\n",
"| | **Type** | **Default** | **Details** |\n",
"| -- | -------- | ----------- | ----------- |\n",
Expand All @@ -213,25 +243,25 @@
"| static_as_categorical | bool | True | Static features should have a categorical data type. |\n",
"| n_models | int | 0 | Number of models predictions to simulate. |\n",
"| level | Optional | None | Confidence level for intervals to simulate for each model. |\n",
"| engine | str | pandas | Output Dataframe type. |\n",
"| engine | Literal | pandas | Output Dataframe type. |\n",
"| seed | int | 0 | Random seed used for generating the data. |\n",
"| **Returns** | **Union** | | **Synthetic panel with columns [`unique_id`, `ds`, `y`] and exogenous features.** |"
],
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/utilsforecast/blob/main/utilsforecast/data.py#L15){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/utilsforecast/blob/main/utilsforecast/data.py#L47){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### generate_series\n",
"\n",
"> generate_series (n_series:int, freq:str='D', min_length:int=50,\n",
"> max_length:int=500, n_static_features:int=0,\n",
"> equal_ends:bool=False, with_trend:bool=False,\n",
"> static_as_categorical:bool=True, n_models:int=0,\n",
"> level:Optional[List[float]]=None, engine:str='pandas',\n",
"> seed:int=0)\n",
"> level:Optional[List[float]]=None,\n",
"> engine:Literal['pandas','polars']='pandas', seed:int=0)\n",
"\n",
"Generate Synthetic Panel Series.\n",
"*Generate Synthetic Panel Series.*\n",
"\n",
"| | **Type** | **Default** | **Details** |\n",
"| -- | -------- | ----------- | ----------- |\n",
Expand All @@ -245,7 +275,7 @@
"| static_as_categorical | bool | True | Static features should have a categorical data type. |\n",
"| n_models | int | 0 | Number of models predictions to simulate. |\n",
"| level | Optional | None | Confidence level for intervals to simulate for each model. |\n",
"| engine | str | pandas | Output Dataframe type. |\n",
"| engine | Literal | pandas | Output Dataframe type. |\n",
"| seed | int | 0 | Random seed used for generating the data. |\n",
"| **Returns** | **Union** | | **Synthetic panel with columns [`unique_id`, `ds`, `y`] and exogenous features.** |"
]
Expand Down
36 changes: 16 additions & 20 deletions nbs/evaluation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
"import pandas as pd\n",
"\n",
"import utilsforecast.processing as ufp\n",
"from utilsforecast.compat import DataFrame, pl"
"from utilsforecast.compat import DFType, pl"
]
},
{
Expand Down Expand Up @@ -90,16 +90,16 @@
"source": [
"#| export\n",
"def evaluate(\n",
" df: DataFrame,\n",
" df: DFType,\n",
" metrics: List[Callable],\n",
" models: Optional[List[str]] = None,\n",
" train_df: Optional[DataFrame] = None,\n",
" train_df: Optional[DFType] = None,\n",
" level: Optional[List[int]] = None,\n",
" id_col: str = 'unique_id',\n",
" time_col: str = 'ds',\n",
" target_col: str = 'y',\n",
" agg_fn: Optional[str] = None,\n",
") -> DataFrame:\n",
") -> DFType:\n",
" \"\"\"Evaluate forecast using different metrics.\n",
" \n",
" Parameters\n",
Expand Down Expand Up @@ -272,11 +272,9 @@
"\n",
"### evaluate\n",
"\n",
"> evaluate\n",
"> (df:Union[pandas.core.frame.DataFrame,polars.dataframe.frame.Da\n",
"> taFrame], metrics:List[Callable],\n",
"> models:Optional[List[str]]=None, train_df:Union[pandas.core.fra\n",
"> me.DataFrame,polars.dataframe.frame.DataFrame,NoneType]=None,\n",
"> evaluate (df:~DFType, metrics:List[Callable],\n",
"> models:Optional[List[str]]=None,\n",
"> train_df:Optional[~DFType]=None,\n",
"> level:Optional[List[int]]=None, id_col:str='unique_id',\n",
"> time_col:str='ds', target_col:str='y',\n",
"> agg_fn:Optional[str]=None)\n",
Expand All @@ -285,16 +283,16 @@
"\n",
"| | **Type** | **Default** | **Details** |\n",
"| -- | -------- | ----------- | ----------- |\n",
"| df | Union | | Forecasts to evaluate.<br>Must have `id_col`, `time_col`, `target_col` and models' predictions. |\n",
"| df | DFType | | Forecasts to evaluate.<br>Must have `id_col`, `time_col`, `target_col` and models' predictions. |\n",
"| metrics | List | | Functions with arguments `df`, `models`, `id_col`, `target_col` and optionally `train_df`. |\n",
"| models | Optional | None | Names of the models to evaluate.<br>If `None` will use every column in the dataframe after removing id, time and target. |\n",
"| train_df | Union | None | Training set. Used to evaluate metrics such as `mase`. |\n",
"| train_df | Optional | None | Training set. Used to evaluate metrics such as `mase`. |\n",
"| level | Optional | None | Prediction interval levels. Used to compute losses that rely on quantiles. |\n",
"| id_col | str | unique_id | Column that identifies each serie. |\n",
"| time_col | str | ds | Column that identifies each timestep, its values can be timestamps or integers. |\n",
"| target_col | str | y | Column that contains the target. |\n",
"| agg_fn | Optional | None | Statistic to compute on the scores by id to reduce them to a single number. |\n",
"| **Returns** | **Union** | | **Metrics with one row per (id, metric) combination and one column per model.<br>If `agg_fn` is not `None`, there is only one row per metric.** |"
"| **Returns** | **DFType** | | **Metrics with one row per (id, metric) combination and one column per model.<br>If `agg_fn` is not `None`, there is only one row per metric.** |"
],
"text/plain": [
"---\n",
Expand All @@ -303,11 +301,9 @@
"\n",
"### evaluate\n",
"\n",
"> evaluate\n",
"> (df:Union[pandas.core.frame.DataFrame,polars.dataframe.frame.Da\n",
"> taFrame], metrics:List[Callable],\n",
"> models:Optional[List[str]]=None, train_df:Union[pandas.core.fra\n",
"> me.DataFrame,polars.dataframe.frame.DataFrame,NoneType]=None,\n",
"> evaluate (df:~DFType, metrics:List[Callable],\n",
"> models:Optional[List[str]]=None,\n",
"> train_df:Optional[~DFType]=None,\n",
"> level:Optional[List[int]]=None, id_col:str='unique_id',\n",
"> time_col:str='ds', target_col:str='y',\n",
"> agg_fn:Optional[str]=None)\n",
Expand All @@ -316,16 +312,16 @@
"\n",
"| | **Type** | **Default** | **Details** |\n",
"| -- | -------- | ----------- | ----------- |\n",
"| df | Union | | Forecasts to evaluate.<br>Must have `id_col`, `time_col`, `target_col` and models' predictions. |\n",
"| df | DFType | | Forecasts to evaluate.<br>Must have `id_col`, `time_col`, `target_col` and models' predictions. |\n",
"| metrics | List | | Functions with arguments `df`, `models`, `id_col`, `target_col` and optionally `train_df`. |\n",
"| models | Optional | None | Names of the models to evaluate.<br>If `None` will use every column in the dataframe after removing id, time and target. |\n",
"| train_df | Union | None | Training set. Used to evaluate metrics such as `mase`. |\n",
"| train_df | Optional | None | Training set. Used to evaluate metrics such as `mase`. |\n",
"| level | Optional | None | Prediction interval levels. Used to compute losses that rely on quantiles. |\n",
"| id_col | str | unique_id | Column that identifies each serie. |\n",
"| time_col | str | ds | Column that identifies each timestep, its values can be timestamps or integers. |\n",
"| target_col | str | y | Column that contains the target. |\n",
"| agg_fn | Optional | None | Statistic to compute on the scores by id to reduce them to a single number. |\n",
"| **Returns** | **Union** | | **Metrics with one row per (id, metric) combination and one column per model.<br>If `agg_fn` is not `None`, there is only one row per metric.** |"
"| **Returns** | **DFType** | | **Metrics with one row per (id, metric) combination and one column per model.<br>If `agg_fn` is not `None`, there is only one row per metric.** |"
]
},
"execution_count": null,
Expand Down
26 changes: 13 additions & 13 deletions nbs/feature_engineering.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
"import pandas as pd\n",
"\n",
"import utilsforecast.processing as ufp\n",
"from utilsforecast.compat import DataFrame, pl, pl_DataFrame, pl_Expr\n",
"from utilsforecast.compat import DFType, DataFrame, pl, pl_DataFrame, pl_Expr\n",
"from utilsforecast.validation import validate_format, validate_freq"
]
},
Expand All @@ -61,13 +61,13 @@
"_Features = Tuple[List[str], np.ndarray, np.ndarray]\n",
"\n",
"def _add_features(\n",
" df: DataFrame,\n",
" df: DFType,\n",
" freq: str,\n",
" h: int,\n",
" id_col: str,\n",
" time_col: str,\n",
" f: Callable[[np.ndarray, int], _Features],\n",
") -> Tuple[DataFrame, DataFrame]:\n",
") -> Tuple[DFType, DFType]:\n",
" # validations\n",
" if not isinstance(h, int) or h < 0:\n",
" raise ValueError('`h` must be a non-negative integer')\n",
Expand Down Expand Up @@ -156,14 +156,14 @@
"source": [
"#| export\n",
"def fourier(\n",
" df: DataFrame,\n",
" df: DFType,\n",
" freq: str,\n",
" season_length: int,\n",
" k: int,\n",
" h: int = 0,\n",
" id_col: str = 'unique_id',\n",
" time_col: str = 'ds',\n",
") -> Tuple[DataFrame, DataFrame]:\n",
") -> Tuple[DFType, DFType]:\n",
" \"\"\"Compute fourier seasonal terms for training and forecasting\n",
"\n",
" Parameters\n",
Expand Down Expand Up @@ -558,12 +558,12 @@
"source": [
"#| export\n",
"def trend(\n",
" df: DataFrame,\n",
" df: DFType,\n",
" freq: str,\n",
" h: int = 0,\n",
" id_col: str = 'unique_id',\n",
" time_col: str = 'ds',\n",
") -> Tuple[DataFrame, DataFrame]:\n",
") -> Tuple[DFType, DFType]:\n",
" \"\"\"Add a trend column with consecutive integers for training and forecasting\n",
"\n",
" Parameters\n",
Expand Down Expand Up @@ -855,10 +855,10 @@
" return feat_name, feat_vals\n",
"\n",
"def _add_time_features(\n",
" df: DataFrame,\n",
" df: DFType,\n",
" features: List[Union[str, Callable]],\n",
" time_col: str = 'ds',\n",
") -> DataFrame:\n",
") -> DFType:\n",
" df = ufp.copy_if_pandas(df, deep=False)\n",
" unique_times = df[time_col].unique()\n",
" if isinstance(df, pd.DataFrame):\n",
Expand Down Expand Up @@ -891,13 +891,13 @@
"source": [
"#| export\n",
"def time_features(\n",
" df: DataFrame,\n",
" df: DFType,\n",
" freq: str,\n",
" features: List[Union[str, Callable]],\n",
" h: int = 0,\n",
" id_col: str = 'unique_id',\n",
" time_col: str = 'ds', \n",
") -> Tuple[DataFrame, DataFrame]:\n",
") -> Tuple[DFType, DFType]:\n",
" \"\"\"Compute timestamp-based features for training and forecasting\n",
"\n",
" Parameters\n",
Expand Down Expand Up @@ -1194,13 +1194,13 @@
"source": [
"#| export\n",
"def pipeline(\n",
" df: DataFrame,\n",
" df: DFType,\n",
" features: List[Callable],\n",
" freq: str,\n",
" h: int = 0,\n",
" id_col: str = 'unique_id',\n",
" time_col: str = 'ds',\n",
") -> Tuple[DataFrame, DataFrame]:\n",
") -> Tuple[DFType, DFType]:\n",
" \"\"\"Compute several features for training and forecasting\n",
"\n",
" Parameters\n",
Expand Down
Loading

0 comments on commit 8188dab

Please sign in to comment.