From 1705e6af74fa899fc636b4f94b9bf6b9ed5d6969 Mon Sep 17 00:00:00 2001 From: Kin Date: Wed, 21 Jun 2023 09:54:54 -0400 Subject: [PATCH] [Doc] `aggregate` showdoc + external reconciliation tutorials' improvements (#214) * Improvements to GluonTS compatibility tutorial * Tutorials' sidebar order * Improved tutorials' references and introduction * Improved utils' documentation, particularly external forecast adapters --- hierarchicalforecast/utils.py | 216 +- .../HierarchicalForecast-GluonTS.ipynb | 2381 ++++++++--------- nbs/examples/MLFrameworksExample.ipynb | 41 +- nbs/sidebar.yml | 2 +- nbs/utils.ipynb | 452 ++-- 5 files changed, 1548 insertions(+), 1544 deletions(-) diff --git a/hierarchicalforecast/utils.py b/hierarchicalforecast/utils.py index 5af6906..eed380b 100644 --- a/hierarchicalforecast/utils.py +++ b/hierarchicalforecast/utils.py @@ -1,9 +1,9 @@ # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/utils.ipynb. # %% auto 0 -__all__ = ['HierarchicalPlot'] +__all__ = ['aggregate', 'HierarchicalPlot'] -# %% ../nbs/utils.ipynb 2 +# %% ../nbs/utils.ipynb 3 import sys import timeit from itertools import chain @@ -16,7 +16,7 @@ plt.rcParams['font.family'] = 'serif' -# %% ../nbs/utils.ipynb 4 +# %% ../nbs/utils.ipynb 5 class CodeTimer: def __init__(self, name=None, verbose=True): self.name = " '" + name + "'" if name else '' @@ -31,7 +31,7 @@ def __exit__(self, exc_type, exc_value, traceback): print('Code block' + self.name + \ ' took:\t{0:.5f}'.format(self.took) + ' seconds') -# %% ../nbs/utils.ipynb 5 +# %% ../nbs/utils.ipynb 6 def is_strictly_hierarchical(S: np.ndarray, tags: Dict[str, np.ndarray]): # main idea: @@ -49,7 +49,7 @@ def is_strictly_hierarchical(S: np.ndarray, nodes = levels_.popitem()[1].size return paths == nodes -# %% ../nbs/utils.ipynb 6 +# %% ../nbs/utils.ipynb 7 def cov2corr(cov, return_std=False): """ convert covariance matrix to correlation matrix @@ -68,105 +68,7 @@ def cov2corr(cov, return_std=False): else: return corr -# %% ../nbs/utils.ipynb 7 -# convert levels to output quantile names -def level_to_outputs(level:Iterable[int]): - """ Converts list of levels into output names matching StatsForecast and NeuralForecast methods. - - **Parameters:**
- `level`: int list [0,100]. Probability levels for prediction intervals.
- - **Returns:**
- `output_names`: str list. String list with output column names. - """ - qs = sum([[50-l/2, 50+l/2] for l in level], []) - output_names = sum([[f'-lo-{l}', f'-hi-{l}'] for l in level], []) - - sort_idx = np.argsort(qs) - quantiles = np.array(qs)[sort_idx] - - # Add default median - quantiles = np.concatenate([np.array([50]), quantiles]) / 100 - output_names = list(np.array(output_names)[sort_idx]) - output_names.insert(0, '-median') - - return quantiles, output_names - -# convert quantiles to output quantile names -def quantiles_to_outputs(quantiles:Iterable[float]): - """Converts list of quantiles into output names matching StatsForecast and NeuralForecast methods. - - **Parameters:**
- `quantiles`: float list [0., 1.]. Alternative to level, quantiles to estimate from y distribution.
- - **Returns:**
- `output_names`: str list. String list with output column names. - """ - output_names = [] - for q in quantiles: - if q<.50: - output_names.append(f'-lo-{np.round(100-200*q,2)}') - elif q>.50: - output_names.append(f'-hi-{np.round(100-200*(1-q),2)}') - else: - output_names.append('-median') - return quantiles, output_names - -# %% ../nbs/utils.ipynb 8 -# given input array of sample forecasts and inptut quantiles/levels, -# output a Pandas Dataframe with columns of quantile predictions -def samples_to_quantiles_df(samples:np.ndarray, - unique_ids:Iterable[str], - dates:Iterable, - quantiles:Optional[Iterable[float]] = None, - level:Optional[Iterable[int]] = None, - model_name:Optional[str] = "model"): - """ Transform Samples into HierarchicalForecast input. - Auxiliary function to create compatible HierarchicalForecast input Y_hat_df dataframe. - - **Parameters:**
- `samples`: numpy array. Samples from forecast distribution of shape [n_series, n_samples, horizon].
- `unique_ids`: string list. Unique identifiers for each time series.
- `dates`: datetime list. List of forecast dates.
- `quantiles`: float list in [0., 1.]. Alternative to level, quantiles to estimate from y distribution.
- `level`: int list in [0,100]. Probability levels for prediction intervals.
- `model_name`: string. Name of forecasting model.
- - **Returns:**
- `quantiles`: float list in [0., 1.]. Quantiles to estimate from the y distribution.
- `Y_hat_df`: pd.DataFrame. With base quantile forecasts with columns ds and models to reconcile indexed by unique_id. - """ - - # Get the shape of the array - n_series, n_samples, horizon = samples.shape - - assert n_series == len(unique_ids) - assert horizon == len(dates) - assert (quantiles is not None) ^ (level is not None) #check exactly one of quantiles/levels has been input - - #create initial dictionary - forecasts_mean = np.mean(samples, axis=1).flatten() - unique_ids = np.repeat(unique_ids, horizon) - ds = np.tile(dates, n_series) - data = pd.DataFrame({"unique_id":unique_ids, "ds":ds, model_name:forecasts_mean}) - - #create quantiles and quantile names - quantiles, quantile_names = level_to_outputs(level) if level is not None else quantiles_to_outputs(quantiles) - percentiles = [quantile * 100 for quantile in quantiles] - col_names = np.array([model_name + quantile_name for quantile_name in quantile_names]) - - #add quantiles to dataframe - forecasts_quantiles = np.percentile(samples, percentiles, axis=1) - - forecasts_quantiles = np.transpose(forecasts_quantiles, (1,2,0)) # [Q,H,N] -> [N,H,Q] - forecasts_quantiles = forecasts_quantiles.reshape(-1,len(quantiles)) - - df = pd.DataFrame(data=forecasts_quantiles, - columns=col_names) - - return quantiles, pd.concat([data,df], axis=1).set_index('unique_id') - -# %% ../nbs/utils.ipynb 11 +# %% ../nbs/utils.ipynb 9 def _to_summing_matrix(S_df: pd.DataFrame): """Transforms the DataFrame `df` of hierarchies to a summing matrix S.""" categories = [S_df[col].unique() for col in S_df.columns] @@ -179,7 +81,7 @@ def _to_summing_matrix(S_df: pd.DataFrame): tags = dict(zip(S_df.columns, categories)) return S, tags -# %% ../nbs/utils.ipynb 12 +# %% ../nbs/utils.ipynb 10 def aggregate_before(df: pd.DataFrame, spec: List[List[str]], agg_fn: Callable = np.sum): @@ -221,7 +123,7 @@ def aggregate_before(df: pd.DataFrame, S, tags = _to_summing_matrix(S_df.loc[bottom_hier, hiers_cols]) return Y_df, S, tags -# %% ../nbs/utils.ipynb 13 +# %% ../nbs/utils.ipynb 11 def numpy_balance(*arrs): """ Fast NumPy implementation of balance function. @@ -287,12 +189,14 @@ def _to_summing_dataframe(df: pd.DataFrame, Y_bottom_df.unique_id = Y_bottom_df.unique_id.cat.set_categories(S_df.columns) return Y_bottom_df, S_df, tags +# %% ../nbs/utils.ipynb 12 def aggregate(df: pd.DataFrame, spec: List[List[str]], is_balanced: bool=False): """ Utils Aggregation Function. Aggregates bottom level series contained in the pd.DataFrame `df` according to levels defined in the `spec` list applying the `agg_fn` (sum, mean). + **Parameters:**
`df`: pd.DataFrame with columns `['ds', 'y']` and columns to aggregate.
`spec`: List of levels. Each element of the list contains a list of columns of `df` to aggregate.
@@ -349,7 +253,7 @@ def aggregate(df: pd.DataFrame, Y_df = Y_df.set_index('unique_id').dropna() return Y_df, S_df, tags -# %% ../nbs/utils.ipynb 22 +# %% ../nbs/utils.ipynb 19 class HierarchicalPlot: """ Hierarchical Plot @@ -542,3 +446,101 @@ def plot_hierarchical_predictions_gap(self, plt.legend() plt.grid() plt.show() + +# %% ../nbs/utils.ipynb 34 +# convert levels to output quantile names +def level_to_outputs(level:Iterable[int]): + """ Converts list of levels into output names matching StatsForecast and NeuralForecast methods. + + **Parameters:**
+ `level`: int list [0,100]. Probability levels for prediction intervals.
+ + **Returns:**
+ `output_names`: str list. String list with output column names. + """ + qs = sum([[50-l/2, 50+l/2] for l in level], []) + output_names = sum([[f'-lo-{l}', f'-hi-{l}'] for l in level], []) + + sort_idx = np.argsort(qs) + quantiles = np.array(qs)[sort_idx] + + # Add default median + quantiles = np.concatenate([np.array([50]), quantiles]) / 100 + output_names = list(np.array(output_names)[sort_idx]) + output_names.insert(0, '-median') + + return quantiles, output_names + +# convert quantiles to output quantile names +def quantiles_to_outputs(quantiles:Iterable[float]): + """Converts list of quantiles into output names matching StatsForecast and NeuralForecast methods. + + **Parameters:**
+ `quantiles`: float list [0., 1.]. Alternative to level, quantiles to estimate from y distribution.
+ + **Returns:**
+ `output_names`: str list. String list with output column names. + """ + output_names = [] + for q in quantiles: + if q<.50: + output_names.append(f'-lo-{np.round(100-200*q,2)}') + elif q>.50: + output_names.append(f'-hi-{np.round(100-200*(1-q),2)}') + else: + output_names.append('-median') + return quantiles, output_names + +# %% ../nbs/utils.ipynb 35 +# given input array of sample forecasts and inptut quantiles/levels, +# output a Pandas Dataframe with columns of quantile predictions +def samples_to_quantiles_df(samples:np.ndarray, + unique_ids:Iterable[str], + dates:Iterable, + quantiles:Optional[Iterable[float]] = None, + level:Optional[Iterable[int]] = None, + model_name:Optional[str] = "model"): + """ Transform Random Samples into HierarchicalForecast input. + Auxiliary function to create compatible HierarchicalForecast input `Y_hat_df` dataframe. + + **Parameters:**
+ `samples`: numpy array. Samples from forecast distribution of shape [n_series, n_samples, horizon].
+ `unique_ids`: string list. Unique identifiers for each time series.
+ `dates`: datetime list. List of forecast dates.
+ `quantiles`: float list in [0., 1.]. Alternative to level, quantiles to estimate from y distribution.
+ `level`: int list in [0,100]. Probability levels for prediction intervals.
+ `model_name`: string. Name of forecasting model.
+ + **Returns:**
+ `quantiles`: float list in [0., 1.]. Quantiles to estimate from the y distribution.
+ `Y_hat_df`: pd.DataFrame. With base quantile forecasts with columns ds and models to reconcile indexed by unique_id. + """ + + # Get the shape of the array + n_series, n_samples, horizon = samples.shape + + assert n_series == len(unique_ids) + assert horizon == len(dates) + assert (quantiles is not None) ^ (level is not None) #check exactly one of quantiles/levels has been input + + #create initial dictionary + forecasts_mean = np.mean(samples, axis=1).flatten() + unique_ids = np.repeat(unique_ids, horizon) + ds = np.tile(dates, n_series) + data = pd.DataFrame({"unique_id":unique_ids, "ds":ds, model_name:forecasts_mean}) + + #create quantiles and quantile names + quantiles, quantile_names = level_to_outputs(level) if level is not None else quantiles_to_outputs(quantiles) + percentiles = [quantile * 100 for quantile in quantiles] + col_names = np.array([model_name + quantile_name for quantile_name in quantile_names]) + + #add quantiles to dataframe + forecasts_quantiles = np.percentile(samples, percentiles, axis=1) + + forecasts_quantiles = np.transpose(forecasts_quantiles, (1,2,0)) # [Q,H,N] -> [N,H,Q] + forecasts_quantiles = forecasts_quantiles.reshape(-1,len(quantiles)) + + df = pd.DataFrame(data=forecasts_quantiles, + columns=col_names) + + return quantiles, pd.concat([data,df], axis=1).set_index('unique_id') diff --git a/nbs/examples/HierarchicalForecast-GluonTS.ipynb b/nbs/examples/HierarchicalForecast-GluonTS.ipynb index b0553f5..73da0ae 100644 --- a/nbs/examples/HierarchicalForecast-GluonTS.ipynb +++ b/nbs/examples/HierarchicalForecast-GluonTS.ipynb @@ -1,1238 +1,1173 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "provenance": [], - "gpuType": "T4" - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "language_info": { - "name": "python" - }, - "accelerator": "GPU" + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# GluonTS" + ] }, - "cells": [ - { - "cell_type": "markdown", - "source": [ - "# HierarchicalForecast with GluonTS Example Notebook\n" - ], - "metadata": { - "id": "OoAaBoCmi7dY" - } - }, - { - "cell_type": "markdown", - "source": [ - "This is an example notebook which shows how HierarchicalForecast's reconciliation capabilities can be integrated with other popular machine learning libraries, in this case GluonTS. \n", - "\n", - "It trains the GluonTS DeepAREstimator on the TourismLarge Hierarchical Dataset, then uses the `samples_to_quantiles_df` util function to transform the output forecasts into a dataframe compatible with HierarchicalForecast's reconciliation functions." - ], - "metadata": { - "id": "b4mhyk4ZjGo-" - } - }, - { - "cell_type": "markdown", - "source": [ - "## 1. 
Installing packages" - ], - "metadata": { - "id": "DvTLPRB-kKUJ" - } - }, - { - "cell_type": "code", - "source": [ - "%%capture\n", - "!pip install gluonts\n", - "!pip install pytorch_lightning\n", - "!pip install datasetsforecast\n", - "!pip install git+https://github.com/Nixtla/hierarchicalforecast.git" - ], - "metadata": { - "id": "b5nueBJ8dKvp" - }, - "execution_count": 1, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "!pip install mxnet-cu112\n", - "import mxnet as mx\n", - "mx.context.num_gpus()" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "3iiqozaiQaXW", - "outputId": "2f5d864d-2e76-461e-862a-df612837b9ee" - }, - "execution_count": 2, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", - "Requirement already satisfied: mxnet-cu112 in /usr/local/lib/python3.10/dist-packages (1.9.1)\n", - "Requirement already satisfied: numpy<2.0.0,>1.16.0 in /usr/local/lib/python3.10/dist-packages (from mxnet-cu112) (1.22.4)\n", - "Requirement already satisfied: requests<3,>=2.20.0 in /usr/local/lib/python3.10/dist-packages (from mxnet-cu112) (2.27.1)\n", - "Requirement already satisfied: graphviz<0.9.0,>=0.8.1 in /usr/local/lib/python3.10/dist-packages (from mxnet-cu112) (0.8.4)\n", - "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.20.0->mxnet-cu112) (1.26.15)\n", - "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.20.0->mxnet-cu112) (2022.12.7)\n", - "Requirement already satisfied: charset-normalizer~=2.0.0 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.20.0->mxnet-cu112) (2.0.12)\n", - "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.20.0->mxnet-cu112) (3.4)\n" - ] - }, - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "1" - ] - }, - "metadata": {}, - "execution_count": 2 - } - ] - }, - { - "cell_type": "code", - "source": [ - "from datasetsforecast.hierarchical import HierarchicalData\n", - "from gluonts.dataset.pandas import PandasDataset\n", - "from gluonts.mx.model.deepar import DeepAREstimator\n", - "from gluonts.mx.trainer import Trainer\n", - "from gluonts.evaluation import make_evaluation_predictions\n", - "\n", - "from hierarchicalforecast.methods import BottomUp, MinTrace\n", - "from hierarchicalforecast.core import HierarchicalReconciliation\n", - "from hierarchicalforecast.evaluation import scaled_crps\n", - "from hierarchicalforecast.utils import samples_to_quantiles_df\n", - "\n", - "import pandas as pd\n", - "import numpy as np" - ], - "metadata": { - "id": "MkMwR-KwLMXh", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "ab60b8d9-8c8f-4bc7-d9a3-f6f9af1e07fa" - }, - "execution_count": 3, - "outputs": [ - { - "output_type": "stream", - "name": "stderr", - "text": [ - "/usr/local/lib/python3.10/dist-packages/gluonts/json.py:101: UserWarning: Using `json`-module for json-handling. Consider installing one of `orjson`, `ujson` to speed up serialization and deserialization.\n", - " warnings.warn(\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "## 2. 
Load hierarchical dataset\n" - ], - "metadata": { - "id": "LAe0OXiAkMvT" - } - }, - { - "cell_type": "markdown", - "source": [ - "This detailed Australian Tourism Dataset comes from the National Visitor Survey, managed by the Tourism Research Australia, it is composed of 555 monthly series from 1998 to 2016, it is organized geographically, and purpose of travel. The natural geographical hierarchy comprises seven states, divided further in 27 zones and 76 regions. The purpose of travel categories are holiday, visiting friends and relatives (VFR), business and other. The MinT (Wickramasuriya et al., 2019), among other hierarchical forecasting studies has used the dataset it in the past. The dataset can be accessed in the [MinT reconciliation webpage](https://robjhyndman.com/publications/mint/), although other sources are available.\n", - "\n", - "| Geographical Division | Number of series per division | Number of series per purpose | Total |\n", - "| --- | --- | --- | --- |\n", - "| Australia | 1 | 4 | 5 |\n", - "| States | 7 | 28 | 35 |\n", - "| Zones | 27 | 108 | 135 |\n", - "| Regions | 76 | 304 | 380 |\n", - "| Total | 111 | 444 | 555 |\n" - ], - "metadata": { - "id": "Tmiq65mjkOAG" - } - }, - { - "cell_type": "code", - "source": [ - "dataset = 'TourismLarge'\n", - "Y_df, S_df, tags = HierarchicalData.load(directory = \"./data\", group=dataset)\n", - "Y_df['ds'] = pd.to_datetime(Y_df['ds'])" - ], - "metadata": { - "id": "G0eihW4F7ujp" - }, - "execution_count": 4, - "outputs": [] - }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This example notebook demonstrates the compatibility of HierarchicalForecast's reconciliation methods with popular machine-learning libraries, specifically [GluonTS](https://ts.gluon.ai/stable/). \n", + "\n", + "The notebook utilizes the GluonTS DeepAREstimator to create base forecasts for the TourismLarge Hierarchical Dataset. We make the base forecasts compatible with HierarchicalForecast's reconciliation functions via the `samples_to_quantiles_df` utility function that transforms GluonTS' output forecasts into a compatible data frame format. After that, we use HierarchicalForecast to reconcile the base predictions.\n", + "\n", + "**References**
\n", + "- [David Salinas, Valentin Flunkert, Jan Gasthaus, Tim Januschowski (2020). \"DeepAR: Probabilistic forecasting with autoregressive recurrent networks\". International Journal of Forecasting.](https://www.sciencedirect.com/science/article/pii/S0169207019301888)
\n", + "- [Alexander Alexandrov et. al (2020). \"GluonTS: Probabilistic and Neural Time Series Modeling in Python\". Journal of Machine Learning Research.](https://www.jmlr.org/papers/v21/19-820.html)
\n", + "\n", + "You can run these experiments using CPU or GPU with Google Colab.\n", + "\n", + "\"Open" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Installing packages" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%capture\n", + "!pip install mxnet-cu112" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import mxnet as mx\n", + "\n", + "assert mx.context.num_gpus()>0" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%capture\n", + "!pip install gluonts\n", + "!pip install datasetsforecast\n", + "!pip install git+https://github.com/Nixtla/hierarchicalforecast.git" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ { - "cell_type": "code", - "source": [ - "def sort_hier_df(Y_df, S_df):\n", - " # sorts unique_id lexicographically\n", - " Y_df.unique_id = Y_df.unique_id.astype('category')\n", - " Y_df.unique_id = Y_df.unique_id.cat.set_categories(S_df.index)\n", - " Y_df = Y_df.sort_values(by=['unique_id', 'ds'])\n", - " return Y_df\n", - "\n", - "Y_df = sort_hier_df(Y_df, S_df)" - ], - "metadata": { - "id": "F46PY79_SS-I" - }, - "execution_count": 5, - "outputs": [] - }, + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.10/dist-packages/gluonts/json.py:101: UserWarning: Using `json`-module for json-handling. Consider installing one of `orjson`, `ujson` to speed up serialization and deserialization.\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "\n", + "from datasetsforecast.hierarchical import HierarchicalData\n", + "\n", + "from gluonts.mx.trainer import Trainer\n", + "from gluonts.dataset.pandas import PandasDataset\n", + "from gluonts.mx.model.deepar import DeepAREstimator\n", + "\n", + "from hierarchicalforecast.methods import BottomUp, MinTrace\n", + "from hierarchicalforecast.core import HierarchicalReconciliation\n", + "from hierarchicalforecast.evaluation import scaled_crps\n", + "from hierarchicalforecast.utils import samples_to_quantiles_df" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Load hierarchical dataset\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This detailed Australian Tourism Dataset comes from the National Visitor Survey, managed by the Tourism Research Australia, it is composed of 555 monthly series from 1998 to 2016, it is organized geographically, and purpose of travel. The natural geographical hierarchy comprises seven states, divided further in 27 zones and 76 regions. The purpose of travel categories are holiday, visiting friends and relatives (VFR), business and other. The MinT (Wickramasuriya et al., 2019), among other hierarchical forecasting studies has used the dataset it in the past. 
The dataset can be accessed in the [MinT reconciliation webpage](https://robjhyndman.com/publications/mint/), although other sources are available.\n", + "\n", + "| Geographical Division | Number of series per division | Number of series per purpose | Total |\n", + "| --- | --- | --- | --- |\n", + "| Australia | 1 | 4 | 5 |\n", + "| States | 7 | 28 | 35 |\n", + "| Zones | 27 | 108 | 135 |\n", + "| Regions | 76 | 304 | 380 |\n", + "| Total | 111 | 444 | 555 |\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dataset = 'TourismLarge'\n", + "Y_df, S_df, tags = HierarchicalData.load(directory = \"./data\", group=dataset)\n", + "Y_df['ds'] = pd.to_datetime(Y_df['ds'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def sort_hier_df(Y_df, S_df):\n", + " # sorts unique_id lexicographically\n", + " Y_df.unique_id = Y_df.unique_id.astype('category')\n", + " Y_df.unique_id = Y_df.unique_id.cat.set_categories(S_df.index)\n", + " Y_df = Y_df.sort_values(by=['unique_id', 'ds'])\n", + " return Y_df\n", + "\n", + "Y_df = sort_hier_df(Y_df, S_df)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ { - "cell_type": "code", - "source": [ - "horizon = 12\n", - "\n", - "Y_test_df = Y_df.groupby('unique_id').tail(horizon)\n", - "Y_train_df = Y_df.drop(Y_test_df.index)\n", - "Y_train_df" + "data": { + "text/html": [ + "\n", + "
\n", + "
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + " \n", + " \n", + "\n", + " \n", + "
\n", + "
\n", + " " ], - "metadata": { - "id": "tiDBBOzC76MD", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 424 - }, - "outputId": "d76e2080-7605-4dd8-896d-7b64600827e6" - }, - "execution_count": 6, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - " unique_id ds y\n", - "0 TotalAll 1998-01-01 45151.071280\n", - "1 TotalAll 1998-02-01 17294.699551\n", - "2 TotalAll 1998-03-01 20725.114184\n", - "3 TotalAll 1998-04-01 25388.612353\n", - "4 TotalAll 1998-05-01 20330.035211\n", - "... ... ... ...\n", - "126523 GBDOth 2015-08-01 17.683774\n", - "126524 GBDOth 2015-09-01 0.000000\n", - "126525 GBDOth 2015-10-01 0.000000\n", - "126526 GBDOth 2015-11-01 0.000000\n", - "126527 GBDOth 2015-12-01 0.000000\n", - "\n", - "[119880 rows x 3 columns]" - ], - "text/html": [ - "\n", - "
\n", - "
\n", - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - " \n", - " \n", - "\n", - " \n", - "
\n", - "
\n", - " " - ] - }, - "metadata": {}, - "execution_count": 6 - } + "text/plain": [ + " unique_id ds y\n", + "0 TotalAll 1998-01-01 45151.071280\n", + "1 TotalAll 1998-02-01 17294.699551\n", + "2 TotalAll 1998-03-01 20725.114184\n", + "3 TotalAll 1998-04-01 25388.612353\n", + "4 TotalAll 1998-05-01 20330.035211\n", + "... ... ... ...\n", + "126523 GBDOth 2015-08-01 17.683774\n", + "126524 GBDOth 2015-09-01 0.000000\n", + "126525 GBDOth 2015-10-01 0.000000\n", + "126526 GBDOth 2015-11-01 0.000000\n", + "126527 GBDOth 2015-12-01 0.000000\n", + "\n", + "[119880 rows x 3 columns]" ] - }, - { - "cell_type": "code", - "source": [ - "ds = PandasDataset.from_long_dataframe(Y_train_df, target=\"y\", item_id=\"unique_id\")" - ], - "metadata": { - "id": "MMMFu9fYD_mN" - }, - "execution_count": 7, - "outputs": [] - }, + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "horizon = 12\n", + "\n", + "Y_test_df = Y_df.groupby('unique_id').tail(horizon)\n", + "Y_train_df = Y_df.drop(Y_test_df.index)\n", + "Y_train_df" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ds = PandasDataset.from_long_dataframe(Y_train_df, target=\"y\", item_id=\"unique_id\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Fit and Predict Model\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ { - "cell_type": "markdown", - "source": [ - "## 3. Fit and Predict Model\n" - ], - "metadata": { - "id": "gCfgHus-kTYg" - } + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 50/50 [00:11<00:00, 4.39it/s, epoch=1/20, avg_epoch_loss=5.35]\n", + "100%|██████████| 50/50 [00:05<00:00, 8.75it/s, epoch=2/20, avg_epoch_loss=5.22]\n", + "100%|██████████| 50/50 [00:03<00:00, 14.41it/s, epoch=3/20, avg_epoch_loss=5.17]\n", + "100%|██████████| 50/50 [00:02<00:00, 20.76it/s, epoch=4/20, avg_epoch_loss=5.02]\n", + "100%|██████████| 50/50 [00:02<00:00, 19.27it/s, epoch=5/20, avg_epoch_loss=5.05]\n", + "100%|██████████| 50/50 [00:04<00:00, 11.52it/s, epoch=6/20, avg_epoch_loss=5.12]\n", + "100%|██████████| 50/50 [00:03<00:00, 16.59it/s, epoch=7/20, avg_epoch_loss=4.97]\n", + "100%|██████████| 50/50 [00:03<00:00, 16.27it/s, epoch=8/20, avg_epoch_loss=4.97]\n", + "100%|██████████| 50/50 [00:02<00:00, 19.96it/s, epoch=9/20, avg_epoch_loss=5.11]\n", + "100%|██████████| 50/50 [00:04<00:00, 11.36it/s, epoch=10/20, avg_epoch_loss=4.97]\n", + "100%|██████████| 50/50 [00:03<00:00, 16.62it/s, epoch=11/20, avg_epoch_loss=5.05]\n", + "100%|██████████| 50/50 [00:02<00:00, 17.76it/s, epoch=12/20, avg_epoch_loss=5.04]\n", + "100%|██████████| 50/50 [00:02<00:00, 21.56it/s, epoch=13/20, avg_epoch_loss=4.99]\n", + "100%|██████████| 50/50 [00:02<00:00, 20.64it/s, epoch=14/20, avg_epoch_loss=5.03]\n", + "100%|██████████| 50/50 [00:03<00:00, 13.22it/s, epoch=15/20, avg_epoch_loss=4.97]\n", + "100%|██████████| 50/50 [00:02<00:00, 17.79it/s, epoch=16/20, avg_epoch_loss=4.95]\n", + "100%|██████████| 50/50 [00:02<00:00, 18.29it/s, epoch=17/20, avg_epoch_loss=5.02]\n", + "100%|██████████| 50/50 [00:02<00:00, 17.73it/s, epoch=18/20, avg_epoch_loss=5.02]\n", + "100%|██████████| 50/50 [00:02<00:00, 19.10it/s, epoch=19/20, avg_epoch_loss=5.02]\n", + "100%|██████████| 50/50 [00:03<00:00, 13.29it/s, epoch=20/20, avg_epoch_loss=5]\n" + ] }, { - "cell_type": "code", - "source": [ - "estimator = DeepAREstimator(\n", - " freq=\"M\",\n", - " 
prediction_length=horizon,\n", - " trainer=Trainer(ctx = mx.context.gpu(),\n", - " epochs=20),\n", - ")\n", - "predictor = estimator.train(ds)\n", - "\n", - "forecast_it = predictor.predict(ds, num_samples=1000)\n", - "\n", - "forecasts = list(forecast_it)\n", - "forecasts = np.array([arr.samples for arr in forecasts])\n", - "forecasts.shape" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Qo9DVHHffb_f", - "outputId": "b52346b9-2ab0-4648-a93e-41972823710e" - }, - "execution_count": 8, - "outputs": [ - { - "output_type": "stream", - "name": "stderr", - "text": [ - "100%|██████████| 50/50 [00:11<00:00, 4.39it/s, epoch=1/20, avg_epoch_loss=5.35]\n", - "100%|██████████| 50/50 [00:05<00:00, 8.75it/s, epoch=2/20, avg_epoch_loss=5.22]\n", - "100%|██████████| 50/50 [00:03<00:00, 14.41it/s, epoch=3/20, avg_epoch_loss=5.17]\n", - "100%|██████████| 50/50 [00:02<00:00, 20.76it/s, epoch=4/20, avg_epoch_loss=5.02]\n", - "100%|██████████| 50/50 [00:02<00:00, 19.27it/s, epoch=5/20, avg_epoch_loss=5.05]\n", - "100%|██████████| 50/50 [00:04<00:00, 11.52it/s, epoch=6/20, avg_epoch_loss=5.12]\n", - "100%|██████████| 50/50 [00:03<00:00, 16.59it/s, epoch=7/20, avg_epoch_loss=4.97]\n", - "100%|██████████| 50/50 [00:03<00:00, 16.27it/s, epoch=8/20, avg_epoch_loss=4.97]\n", - "100%|██████████| 50/50 [00:02<00:00, 19.96it/s, epoch=9/20, avg_epoch_loss=5.11]\n", - "100%|██████████| 50/50 [00:04<00:00, 11.36it/s, epoch=10/20, avg_epoch_loss=4.97]\n", - "100%|██████████| 50/50 [00:03<00:00, 16.62it/s, epoch=11/20, avg_epoch_loss=5.05]\n", - "100%|██████████| 50/50 [00:02<00:00, 17.76it/s, epoch=12/20, avg_epoch_loss=5.04]\n", - "100%|██████████| 50/50 [00:02<00:00, 21.56it/s, epoch=13/20, avg_epoch_loss=4.99]\n", - "100%|██████████| 50/50 [00:02<00:00, 20.64it/s, epoch=14/20, avg_epoch_loss=5.03]\n", - "100%|██████████| 50/50 [00:03<00:00, 13.22it/s, epoch=15/20, avg_epoch_loss=4.97]\n", - "100%|██████████| 50/50 [00:02<00:00, 17.79it/s, epoch=16/20, avg_epoch_loss=4.95]\n", - "100%|██████████| 50/50 [00:02<00:00, 18.29it/s, epoch=17/20, avg_epoch_loss=5.02]\n", - "100%|██████████| 50/50 [00:02<00:00, 17.73it/s, epoch=18/20, avg_epoch_loss=5.02]\n", - "100%|██████████| 50/50 [00:02<00:00, 19.10it/s, epoch=19/20, avg_epoch_loss=5.02]\n", - "100%|██████████| 50/50 [00:03<00:00, 13.29it/s, epoch=20/20, avg_epoch_loss=5]\n" - ] - }, - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "(555, 1000, 12)" - ] - }, - "metadata": {}, - "execution_count": 8 - } + "data": { + "text/plain": [ + "(555, 1000, 12)" ] - }, - { - "cell_type": "markdown", - "source": [ - "## 4. 
Reconciliation\n" - ], - "metadata": { - "id": "W8pbMYBCkWiu" - } - }, - { - "cell_type": "code", - "source": [ - "level = np.arange(1, 100, 2)\n", - "\n", - "#transform the output of DeepAREstimator to a form that is compatible with HierarchicalForecast\n", - "quantiles, forecast_df = samples_to_quantiles_df(samples=forecasts, \n", - " unique_ids=S_df.index, \n", - " dates=Y_test_df['ds'].unique(), \n", - " level=level,\n", - " model_name='DeepAREstimator')\n", - "\n", - "#reconcile forecasts\n", - "reconcilers = [\n", - " BottomUp(),\n", - " MinTrace('ols')\n", - "]\n", - "hrec = HierarchicalReconciliation(reconcilers=reconcilers)\n", - "\n", - "forecast_rec = hrec.reconcile(Y_hat_df=forecast_df, S=S_df, tags=tags, level=level)" - ], - "metadata": { - "id": "5C1fQhy93RWm" - }, - "execution_count": 9, - "outputs": [] - }, + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "estimator = DeepAREstimator(\n", + " freq=\"M\",\n", + " prediction_length=horizon,\n", + " trainer=Trainer(ctx = mx.context.gpu(),\n", + " epochs=20),\n", + ")\n", + "predictor = estimator.train(ds)\n", + "\n", + "forecast_it = predictor.predict(ds, num_samples=1000)\n", + "\n", + "forecasts = list(forecast_it)\n", + "forecasts = np.array([arr.samples for arr in forecasts])\n", + "forecasts.shape" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Reconciliation\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "level = np.arange(1, 100, 2)\n", + "\n", + "#transform the output of DeepAREstimator to a form that is compatible with HierarchicalForecast\n", + "quantiles, forecast_df = samples_to_quantiles_df(samples=forecasts, \n", + " unique_ids=S_df.index, \n", + " dates=Y_test_df['ds'].unique(), \n", + " level=level,\n", + " model_name='DeepAREstimator')\n", + "\n", + "#reconcile forecasts\n", + "reconcilers = [\n", + " BottomUp(),\n", + " MinTrace('ols')\n", + "]\n", + "hrec = HierarchicalReconciliation(reconcilers=reconcilers)\n", + "\n", + "forecast_rec = hrec.reconcile(Y_hat_df=forecast_df, S=S_df, tags=tags, level=level)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ { - "cell_type": "code", - "source": [ - "forecast_rec" + "data": { + "text/html": [ + "\n", + "
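In this tutorial, `samples_to_quantiles_df` adapts GluonTS sample forecasts into the quantile data frame that the reconcilers consume. A minimal, self-contained sketch of that adapter (not part of the patch; the tiny random array, the series names, and the `level` values are illustrative assumptions):

```python
# Illustrative sketch only: adapting a [n_series, n_samples, horizon] sample array.
import numpy as np
import pandas as pd
from hierarchicalforecast.utils import samples_to_quantiles_df

np.random.seed(0)
samples = np.random.rand(2, 500, 3)                      # 2 series, 500 sample paths, horizon 3
dates = pd.date_range("2016-01-01", periods=3, freq="MS")

quantiles, Y_hat_df = samples_to_quantiles_df(
    samples=samples,
    unique_ids=["seriesA", "seriesB"],                   # hypothetical series names
    dates=dates,
    level=[80, 90],                                      # prediction-interval levels
    model_name="AnyModel",
)
# quantiles -> [0.5, 0.05, 0.1, 0.9, 0.95]
# Y_hat_df is indexed by unique_id, with a 'ds' column, the sample mean in 'AnyModel',
# and one column per quantile ('AnyModel-median', 'AnyModel-lo-90', ..., 'AnyModel-hi-90'),
# which is the Y_hat_df format that HierarchicalReconciliation.reconcile expects.
print(Y_hat_df.columns.tolist())
```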
\n", + "
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + " \n", + " \n", + "\n", + " \n", + "
\n", + "
\n", + " " ], - "metadata": { - "id": "V-COhzk_jqk3", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 710 - }, - "outputId": "e3542619-366a-41bf-b120-146f91dc8954" - }, - "execution_count": 10, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - " ds DeepAREstimator DeepAREstimator-median \\\n", - "unique_id \n", - "TotalAll 2016-01-01 43165.929688 43002.058594 \n", - "TotalAll 2016-02-01 20326.796875 20469.210938 \n", - "TotalAll 2016-03-01 24362.203125 24237.250977 \n", - "TotalAll 2016-04-01 29131.662109 29236.008789 \n", - "TotalAll 2016-05-01 22587.779297 22638.541016 \n", - "... ... ... ... \n", - "GBDOth 2016-08-01 -0.300811 -0.316894 \n", - "GBDOth 2016-09-01 -0.089410 -0.079164 \n", - "GBDOth 2016-10-01 -0.196041 -0.207104 \n", - "GBDOth 2016-11-01 -0.315826 -0.274183 \n", - "GBDOth 2016-12-01 -0.291579 -0.268462 \n", - "\n", - " DeepAREstimator-lo-99 DeepAREstimator-lo-97 \\\n", - "unique_id \n", - "TotalAll 27712.297979 30371.243516 \n", - "TotalAll 13156.550879 15086.488257 \n", - "TotalAll 17340.837197 18470.071582 \n", - "TotalAll 19923.623740 21814.112246 \n", - "TotalAll 14453.285947 16236.985869 \n", - "... ... ... \n", - "GBDOth -2.994549 -2.208182 \n", - "GBDOth -2.981229 -2.356738 \n", - "GBDOth -2.829650 -2.270969 \n", - "GBDOth -2.461571 -1.829249 \n", - "GBDOth -3.987842 -2.078746 \n", - "\n", - " DeepAREstimator-lo-95 DeepAREstimator-lo-93 \\\n", - "unique_id \n", - "TotalAll 32741.458740 33305.429492 \n", - "TotalAll 15738.457031 16134.386343 \n", - "TotalAll 19132.180615 19658.168945 \n", - "TotalAll 22685.987500 23350.113418 \n", - "TotalAll 17163.251807 17894.046758 \n", - "... ... ... \n", - "GBDOth -2.005075 -1.725068 \n", - "GBDOth -1.812428 -1.499515 \n", - "GBDOth -1.674091 -1.289834 \n", - "GBDOth -1.535889 -1.329642 \n", - "GBDOth -1.619226 -1.385310 \n", - "\n", - " DeepAREstimator-lo-91 DeepAREstimator-lo-89 \\\n", - "unique_id \n", - "TotalAll 34446.465957 35164.380410 \n", - "TotalAll 16696.160010 16828.676436 \n", - "TotalAll 19974.223359 20339.483584 \n", - "TotalAll 23721.056963 24168.286201 \n", - "TotalAll 18559.204453 18789.053066 \n", - "... ... ... \n", - "GBDOth -1.620723 -1.501304 \n", - "GBDOth -1.365453 -1.199702 \n", - "GBDOth -1.153728 -1.078916 \n", - "GBDOth -1.260961 -1.134465 \n", - "GBDOth -1.253607 -1.156472 \n", - "\n", - " DeepAREstimator-lo-87 ... \\\n", - "unique_id ... \n", - "TotalAll 35732.592422 ... \n", - "TotalAll 17139.442129 ... \n", - "TotalAll 20519.382959 ... \n", - "TotalAll 24513.198066 ... \n", - "TotalAll 19055.381455 ... \n", - "... ... ... \n", - "GBDOth -1.355108 ... \n", - "GBDOth -1.120727 ... \n", - "GBDOth -1.029915 ... \n", - "GBDOth -1.007276 ... \n", - "GBDOth -1.092625 ... \n", - "\n", - " DeepAREstimator/MinTrace_method-ols-hi-81 \\\n", - "unique_id \n", - "TotalAll 48703.132046 \n", - "TotalAll 22902.635244 \n", - "TotalAll 26759.166634 \n", - "TotalAll 32277.209370 \n", - "TotalAll 25400.976716 \n", - "... ... \n", - "GBDOth 27.151595 \n", - "GBDOth 24.912080 \n", - "GBDOth 25.423958 \n", - "GBDOth 25.125960 \n", - "GBDOth 26.216098 \n", - "\n", - " DeepAREstimator/MinTrace_method-ols-hi-83 \\\n", - "unique_id \n", - "TotalAll 48956.752480 \n", - "TotalAll 23019.412684 \n", - "TotalAll 26873.896338 \n", - "TotalAll 32427.584386 \n", - "TotalAll 25532.984575 \n", - "... ... 
\n", - "GBDOth 28.293141 \n", - "GBDOth 26.035044 \n", - "GBDOth 26.550973 \n", - "GBDOth 26.257991 \n", - "GBDOth 27.310535 \n", - "\n", - " DeepAREstimator/MinTrace_method-ols-hi-85 \\\n", - "unique_id \n", - "TotalAll 49233.843836 \n", - "TotalAll 23146.997118 \n", - "TotalAll 26999.243530 \n", - "TotalAll 32591.875632 \n", - "TotalAll 25677.208902 \n", - "... ... \n", - "GBDOth 29.540330 \n", - "GBDOth 27.261932 \n", - "GBDOth 27.782287 \n", - "GBDOth 27.494784 \n", - "GBDOth 28.506255 \n", - "\n", - " DeepAREstimator/MinTrace_method-ols-hi-87 \\\n", - "unique_id \n", - "TotalAll 49540.743219 \n", - "TotalAll 23288.306411 \n", - "TotalAll 27138.074912 \n", - "TotalAll 32773.840464 \n", - "TotalAll 25836.948122 \n", - "... ... \n", - "GBDOth 30.921685 \n", - "GBDOth 28.620801 \n", - "GBDOth 29.146059 \n", - "GBDOth 28.864625 \n", - "GBDOth 29.830604 \n", - "\n", - " DeepAREstimator/MinTrace_method-ols-hi-89 \\\n", - "unique_id \n", - "TotalAll 49886.826218 \n", - "TotalAll 23447.657478 \n", - "TotalAll 27294.631699 \n", - "TotalAll 32979.037796 \n", - "TotalAll 26017.082170 \n", - "... ... \n", - "GBDOth 32.479405 \n", - "GBDOth 30.153165 \n", - "GBDOth 30.683952 \n", - "GBDOth 30.409361 \n", - "GBDOth 31.324041 \n", - "\n", - " DeepAREstimator/MinTrace_method-ols-hi-91 \\\n", - "unique_id \n", - "TotalAll 50286.877928 \n", - "TotalAll 23631.857993 \n", - "TotalAll 27475.602189 \n", - "TotalAll 33216.233913 \n", - "TotalAll 26225.306596 \n", - "... ... \n", - "GBDOth 34.280039 \n", - "GBDOth 31.924489 \n", - "GBDOth 32.461666 \n", - "GBDOth 32.194986 \n", - "GBDOth 33.050366 \n", - "\n", - " DeepAREstimator/MinTrace_method-ols-hi-93 \\\n", - "unique_id \n", - "TotalAll 50766.394577 \n", - "TotalAll 23852.647485 \n", - "TotalAll 27692.520055 \n", - "TotalAll 33500.545877 \n", - "TotalAll 26474.892029 \n", - "... ... \n", - "GBDOth 36.438344 \n", - "GBDOth 34.047662 \n", - "GBDOth 34.592499 \n", - "GBDOth 34.335301 \n", - "GBDOth 35.119603 \n", - "\n", - " DeepAREstimator/MinTrace_method-ols-hi-95 \\\n", - "unique_id \n", - "TotalAll 51375.717577 \n", - "TotalAll 24133.205242 \n", - "TotalAll 27968.158127 \n", - "TotalAll 33861.821798 \n", - "TotalAll 26792.040860 \n", - "... ... \n", - "GBDOth 39.180908 \n", - "GBDOth 36.745584 \n", - "GBDOth 37.300154 \n", - "GBDOth 37.055005 \n", - "GBDOth 37.748987 \n", - "\n", - " DeepAREstimator/MinTrace_method-ols-hi-97 \\\n", - "unique_id \n", - "TotalAll 52240.506366 \n", - "TotalAll 24531.390118 \n", - "TotalAll 28359.360682 \n", - "TotalAll 34374.566871 \n", - "TotalAll 27242.158045 \n", - "... ... \n", - "GBDOth 43.073324 \n", - "GBDOth 40.574640 \n", - "GBDOth 41.143025 \n", - "GBDOth 40.914977 \n", - "GBDOth 41.480772 \n", - "\n", - " DeepAREstimator/MinTrace_method-ols-hi-99 \n", - "unique_id \n", - "TotalAll 53910.351214 \n", - "TotalAll 25300.256426 \n", - "TotalAll 29114.744632 \n", - "TotalAll 35364.640665 \n", - "TotalAll 28111.301901 \n", - "... ... \n", - "GBDOth 50.589300 \n", - "GBDOth 47.968273 \n", - "GBDOth 48.563331 \n", - "GBDOth 48.368305 \n", - "GBDOth 48.686579 \n", - "\n", - "[6660 rows x 305 columns]" - ], - "text/html": [ - "\n", - "
\n", - "
\n", - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - " \n", - " \n", - "\n", - " \n", - "
\n", - "
\n", - " " - ] - }, - "metadata": {}, - "execution_count": 10 - } + "text/plain": [ + " ds DeepAREstimator DeepAREstimator-median \\\n", + "unique_id \n", + "TotalAll 2016-01-01 43165.929688 43002.058594 \n", + "TotalAll 2016-02-01 20326.796875 20469.210938 \n", + "TotalAll 2016-03-01 24362.203125 24237.250977 \n", + "TotalAll 2016-04-01 29131.662109 29236.008789 \n", + "TotalAll 2016-05-01 22587.779297 22638.541016 \n", + "... ... ... ... \n", + "GBDOth 2016-08-01 -0.300811 -0.316894 \n", + "GBDOth 2016-09-01 -0.089410 -0.079164 \n", + "GBDOth 2016-10-01 -0.196041 -0.207104 \n", + "GBDOth 2016-11-01 -0.315826 -0.274183 \n", + "GBDOth 2016-12-01 -0.291579 -0.268462 \n", + "\n", + " DeepAREstimator-lo-99 DeepAREstimator-lo-97 \\\n", + "unique_id \n", + "TotalAll 27712.297979 30371.243516 \n", + "TotalAll 13156.550879 15086.488257 \n", + "TotalAll 17340.837197 18470.071582 \n", + "TotalAll 19923.623740 21814.112246 \n", + "TotalAll 14453.285947 16236.985869 \n", + "... ... ... \n", + "GBDOth -2.994549 -2.208182 \n", + "GBDOth -2.981229 -2.356738 \n", + "GBDOth -2.829650 -2.270969 \n", + "GBDOth -2.461571 -1.829249 \n", + "GBDOth -3.987842 -2.078746 \n", + "\n", + " DeepAREstimator-lo-95 DeepAREstimator-lo-93 \\\n", + "unique_id \n", + "TotalAll 32741.458740 33305.429492 \n", + "TotalAll 15738.457031 16134.386343 \n", + "TotalAll 19132.180615 19658.168945 \n", + "TotalAll 22685.987500 23350.113418 \n", + "TotalAll 17163.251807 17894.046758 \n", + "... ... ... \n", + "GBDOth -2.005075 -1.725068 \n", + "GBDOth -1.812428 -1.499515 \n", + "GBDOth -1.674091 -1.289834 \n", + "GBDOth -1.535889 -1.329642 \n", + "GBDOth -1.619226 -1.385310 \n", + "\n", + " DeepAREstimator-lo-91 DeepAREstimator-lo-89 \\\n", + "unique_id \n", + "TotalAll 34446.465957 35164.380410 \n", + "TotalAll 16696.160010 16828.676436 \n", + "TotalAll 19974.223359 20339.483584 \n", + "TotalAll 23721.056963 24168.286201 \n", + "TotalAll 18559.204453 18789.053066 \n", + "... ... ... \n", + "GBDOth -1.620723 -1.501304 \n", + "GBDOth -1.365453 -1.199702 \n", + "GBDOth -1.153728 -1.078916 \n", + "GBDOth -1.260961 -1.134465 \n", + "GBDOth -1.253607 -1.156472 \n", + "\n", + " DeepAREstimator-lo-87 ... \\\n", + "unique_id ... \n", + "TotalAll 35732.592422 ... \n", + "TotalAll 17139.442129 ... \n", + "TotalAll 20519.382959 ... \n", + "TotalAll 24513.198066 ... \n", + "TotalAll 19055.381455 ... \n", + "... ... ... \n", + "GBDOth -1.355108 ... \n", + "GBDOth -1.120727 ... \n", + "GBDOth -1.029915 ... \n", + "GBDOth -1.007276 ... \n", + "GBDOth -1.092625 ... \n", + "\n", + " DeepAREstimator/MinTrace_method-ols-hi-81 \\\n", + "unique_id \n", + "TotalAll 48703.132046 \n", + "TotalAll 22902.635244 \n", + "TotalAll 26759.166634 \n", + "TotalAll 32277.209370 \n", + "TotalAll 25400.976716 \n", + "... ... \n", + "GBDOth 27.151595 \n", + "GBDOth 24.912080 \n", + "GBDOth 25.423958 \n", + "GBDOth 25.125960 \n", + "GBDOth 26.216098 \n", + "\n", + " DeepAREstimator/MinTrace_method-ols-hi-83 \\\n", + "unique_id \n", + "TotalAll 48956.752480 \n", + "TotalAll 23019.412684 \n", + "TotalAll 26873.896338 \n", + "TotalAll 32427.584386 \n", + "TotalAll 25532.984575 \n", + "... ... \n", + "GBDOth 28.293141 \n", + "GBDOth 26.035044 \n", + "GBDOth 26.550973 \n", + "GBDOth 26.257991 \n", + "GBDOth 27.310535 \n", + "\n", + " DeepAREstimator/MinTrace_method-ols-hi-85 \\\n", + "unique_id \n", + "TotalAll 49233.843836 \n", + "TotalAll 23146.997118 \n", + "TotalAll 26999.243530 \n", + "TotalAll 32591.875632 \n", + "TotalAll 25677.208902 \n", + "... ... 
\n", + "GBDOth 29.540330 \n", + "GBDOth 27.261932 \n", + "GBDOth 27.782287 \n", + "GBDOth 27.494784 \n", + "GBDOth 28.506255 \n", + "\n", + " DeepAREstimator/MinTrace_method-ols-hi-87 \\\n", + "unique_id \n", + "TotalAll 49540.743219 \n", + "TotalAll 23288.306411 \n", + "TotalAll 27138.074912 \n", + "TotalAll 32773.840464 \n", + "TotalAll 25836.948122 \n", + "... ... \n", + "GBDOth 30.921685 \n", + "GBDOth 28.620801 \n", + "GBDOth 29.146059 \n", + "GBDOth 28.864625 \n", + "GBDOth 29.830604 \n", + "\n", + " DeepAREstimator/MinTrace_method-ols-hi-89 \\\n", + "unique_id \n", + "TotalAll 49886.826218 \n", + "TotalAll 23447.657478 \n", + "TotalAll 27294.631699 \n", + "TotalAll 32979.037796 \n", + "TotalAll 26017.082170 \n", + "... ... \n", + "GBDOth 32.479405 \n", + "GBDOth 30.153165 \n", + "GBDOth 30.683952 \n", + "GBDOth 30.409361 \n", + "GBDOth 31.324041 \n", + "\n", + " DeepAREstimator/MinTrace_method-ols-hi-91 \\\n", + "unique_id \n", + "TotalAll 50286.877928 \n", + "TotalAll 23631.857993 \n", + "TotalAll 27475.602189 \n", + "TotalAll 33216.233913 \n", + "TotalAll 26225.306596 \n", + "... ... \n", + "GBDOth 34.280039 \n", + "GBDOth 31.924489 \n", + "GBDOth 32.461666 \n", + "GBDOth 32.194986 \n", + "GBDOth 33.050366 \n", + "\n", + " DeepAREstimator/MinTrace_method-ols-hi-93 \\\n", + "unique_id \n", + "TotalAll 50766.394577 \n", + "TotalAll 23852.647485 \n", + "TotalAll 27692.520055 \n", + "TotalAll 33500.545877 \n", + "TotalAll 26474.892029 \n", + "... ... \n", + "GBDOth 36.438344 \n", + "GBDOth 34.047662 \n", + "GBDOth 34.592499 \n", + "GBDOth 34.335301 \n", + "GBDOth 35.119603 \n", + "\n", + " DeepAREstimator/MinTrace_method-ols-hi-95 \\\n", + "unique_id \n", + "TotalAll 51375.717577 \n", + "TotalAll 24133.205242 \n", + "TotalAll 27968.158127 \n", + "TotalAll 33861.821798 \n", + "TotalAll 26792.040860 \n", + "... ... \n", + "GBDOth 39.180908 \n", + "GBDOth 36.745584 \n", + "GBDOth 37.300154 \n", + "GBDOth 37.055005 \n", + "GBDOth 37.748987 \n", + "\n", + " DeepAREstimator/MinTrace_method-ols-hi-97 \\\n", + "unique_id \n", + "TotalAll 52240.506366 \n", + "TotalAll 24531.390118 \n", + "TotalAll 28359.360682 \n", + "TotalAll 34374.566871 \n", + "TotalAll 27242.158045 \n", + "... ... \n", + "GBDOth 43.073324 \n", + "GBDOth 40.574640 \n", + "GBDOth 41.143025 \n", + "GBDOth 40.914977 \n", + "GBDOth 41.480772 \n", + "\n", + " DeepAREstimator/MinTrace_method-ols-hi-99 \n", + "unique_id \n", + "TotalAll 53910.351214 \n", + "TotalAll 25300.256426 \n", + "TotalAll 29114.744632 \n", + "TotalAll 35364.640665 \n", + "TotalAll 28111.301901 \n", + "... ... \n", + "GBDOth 50.589300 \n", + "GBDOth 47.968273 \n", + "GBDOth 48.563331 \n", + "GBDOth 48.368305 \n", + "GBDOth 48.686579 \n", + "\n", + "[6660 rows x 305 columns]" ] - }, - { - "cell_type": "markdown", - "source": [ - "## 5. Evaluation" - ], - "metadata": { - "id": "dC6RcXFcknc4" - } - }, + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "forecast_rec" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. 
Evaluation\n", + "\n", + "To evaluate we use a scaled variation of the CRPS, as proposed by Rangapuram (2021), to measure the accuracy of predicted quantiles `y_hat` compared to the observation `y`.\n", + "\n", + "$$ \\mathrm{sCRPS}(\\hat{F}_{\\tau}, \\mathbf{y}_{\\tau}) = \\frac{2}{N} \\sum_{i}\n", + "\\int^{1}_{0}\n", + "\\frac{\\mathrm{QL}(\\hat{F}_{i,\\tau}, y_{i,\\tau})_{q}}{\\sum_{i} | y_{i,\\tau} |} dq $$\n", + "\n", + "As you can see, HierarchicalForecast results improve on the results of specialized algorithms like [HierE2E](https://proceedings.mlr.press/v139/rangapuram21a.html)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ { - "cell_type": "code", - "source": [ - "rec_model_names = ['DeepAREstimator/MinTrace_method-ols', 'DeepAREstimator/BottomUp']\n", - "\n", - "quantiles = np.array(quantiles[1:]) #remove first quantile (median)\n", - "n_quantiles = len(quantiles)\n", - "n_series = len(S_df)\n", - "\n", - "for name in rec_model_names:\n", - " quantile_columns = [col for col in forecast_rec.columns if (name+'-') in col]\n", - " y_rec = forecast_rec[quantile_columns].values \n", - " y_test = Y_test_df['y'].values\n", - "\n", - " y_rec = y_rec.reshape(n_series, horizon, n_quantiles)\n", - " y_test = y_test.reshape(n_series, horizon)\n", - " scrps = scaled_crps(y=y_test, y_hat=y_rec, quantiles=quantiles)\n", - " print(\"{:<40} {:.5f}\".format(name+\":\", scrps))" - ], - "metadata": { - "id": "NvAo1A3WMPqF", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "2d90487e-3820-4632-88be-0b747fae5018" - }, - "execution_count": 11, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "DeepAREstimator/MinTrace_method-ols: 0.12632\n", - "DeepAREstimator/BottomUp: 0.13933\n" - ] - } - ] + "name": "stdout", + "output_type": "stream", + "text": [ + "DeepAREstimator/MinTrace_method-ols: 0.12632\n", + "DeepAREstimator/BottomUp: 0.13933\n" + ] } - ] -} \ No newline at end of file + ], + "source": [ + "rec_model_names = ['DeepAREstimator/MinTrace_method-ols', 'DeepAREstimator/BottomUp']\n", + "\n", + "quantiles = np.array(quantiles[1:]) #remove first quantile (median)\n", + "n_quantiles = len(quantiles)\n", + "n_series = len(S_df)\n", + "\n", + "for name in rec_model_names:\n", + " quantile_columns = [col for col in forecast_rec.columns if (name+'-') in col]\n", + " y_rec = forecast_rec[quantile_columns].values \n", + " y_test = Y_test_df['y'].values\n", + "\n", + " y_rec = y_rec.reshape(n_series, horizon, n_quantiles)\n", + " y_test = y_test.reshape(n_series, horizon)\n", + " scrps = scaled_crps(y=y_test, y_hat=y_rec, quantiles=quantiles)\n", + " print(\"{:<40} {:.5f}\".format(name+\":\", scrps))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "python3", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/nbs/examples/MLFrameworksExample.ipynb b/nbs/examples/MLFrameworksExample.ipynb index be80497..0cc9575 100644 --- a/nbs/examples/MLFrameworksExample.ipynb +++ b/nbs/examples/MLFrameworksExample.ipynb @@ -1,20 +1,34 @@ { "cells": [ { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ - "# HierarchicalForecast with ML Frameworks" + "# Neural/MLForecast" ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ - "This is an example notebook which shows how HierarchicalForecast's reconciliation capabilities can be integrated with other libraries, such as NeuralForecast and MLForecast. 
It first trains some base forecasts with the libraries ML model's and then reconciles the output predictions with HierarchicalForecast's methods." + "This example notebook demonstrates the compatibility of HierarchicalForecast's reconciliation methods with popular machine-learning libraries, specifically [NeuralForecast](https://github.com/Nixtla/neuralforecast) and [MLForecast](https://github.com/Nixtla/mlforecast).\n", + "\n", + "The notebook utilizes NBEATS and XGBRegressor models to create base forecasts for the TourismLarge Hierarchical Dataset. After that, we use HierarchicalForecast to reconcile the base predictions.\n", + "\n", + "**References**
\n", + "- [Boris N. Oreshkin, Dmitri Carpov, Nicolas Chapados, Yoshua Bengio (2019). \"N-BEATS: Neural basis expansion analysis for interpretable time series forecasting\". url: https://arxiv.org/abs/1905.10437](https://arxiv.org/abs/1905.10437)
\n", + "- [Tianqi Chen and Carlos Guestrin. “XGBoost: A Scalable Tree Boosting System”. In: Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining. KDD ’16. San Francisco, California, USA: Association for Computing Machinery, 2016, pp. 785–794. isbn: 9781450342322. doi: 10.1145/2939672.2939785. url: https://doi.org/10.1145/2939672.2939785 (cit. on p. 26).](https://doi.org/10.1145/2939672.2939785)
\n", + "\n", + "\n", + "You can run these experiments using CPU or GPU with Google Colab.\n", + "\n", + "\"Open" ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -62,6 +76,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -69,6 +84,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -257,6 +273,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -285,6 +302,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -319,6 +337,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -1714,6 +1733,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -1737,6 +1757,20 @@ "Y_rec_mf = hrec.reconcile(Y_hat_df=Y_hat_mf, Y_df = Y_train_df, S=S_df, tags=tags, level=level)" ] }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. Evaluation\n", + "\n", + "To evaluate we use a scaled variation of the CRPS, as proposed by Rangapuram (2021), to measure the accuracy of predicted quantiles `y_hat` compared to the observation `y`.\n", + "\n", + "$$ \\mathrm{sCRPS}(\\hat{F}_{\\tau}, \\mathbf{y}_{\\tau}) = \\frac{2}{N} \\sum_{i}\n", + "\\int^{1}_{0}\n", + "\\frac{\\mathrm{QL}(\\hat{F}_{i,\\tau}, y_{i,\\tau})_{q}}{\\sum_{i} | y_{i,\\tau} |} dq $$" + ] + }, { "cell_type": "code", "execution_count": null, @@ -1782,10 +1816,11 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ - "## 5. Visualizations" + "## 6. Visualizations" ] }, { diff --git a/nbs/sidebar.yml b/nbs/sidebar.yml index 9cba170..eae3d48 100644 --- a/nbs/sidebar.yml +++ b/nbs/sidebar.yml @@ -24,8 +24,8 @@ website: - examples/TourismLarge-Evaluation.ipynb - section: ML Forecast Reconciliation contents: - - examples/MLFrameworksExample.ipynb - examples/HierarchicalForecast-GluonTS.ipynb + - examples/MLFrameworksExample.ipynb - Methods - section: API Reference contents: diff --git a/nbs/utils.ipynb b/nbs/utils.ipynb index da0ae9c..16028cd 100644 --- a/nbs/utils.ipynb +++ b/nbs/utils.ipynb @@ -11,16 +11,26 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "5955e6c8-f4cd-49a6-b6c8-b91c4392a6d3", "metadata": {}, "source": [ - "# Aggregation/Visualization Utils\n", - "\n", - "> The `HierarchicalForecast` package contains utility functions to wrangle and visualize \n", + "# Aggregation/Visualization Utils" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "bfc33c1e", + "metadata": {}, + "source": [ + "The `HierarchicalForecast` package contains utility functions to wrangle and visualize \n", "hierarchical series datasets. The `aggregate` function of the module allows you to create\n", "a hierarchy from categorical variables representing the structure levels, returning also\n", - "the aggregation contraints matrix $\\mathbf{S}$." + "the aggregation contraints matrix $\\mathbf{S}$.\n", + "\n", + "In addition, `HierarchicalForecast` ensures compatibility of its reconciliation methods with other popular machine-learning libraries via its external forecast adapters that transform output base forecasts from external libraries into a compatible data frame format." 
] }, { @@ -53,7 +63,9 @@ "source": [ "#| hide\n", "from nbdev.showdoc import add_docs, show_doc\n", - "from fastcore.test import test_eq, test_close, test_fail" + "from fastcore.test import test_eq, test_close, test_fail\n", + "\n", + "from statsforecast.utils import generate_series" ] }, { @@ -133,132 +145,7 @@ ] }, { - "cell_type": "code", - "execution_count": null, - "id": "0665290c", - "metadata": {}, - "outputs": [], - "source": [ - "#| exporti\n", - "\n", - "# convert levels to output quantile names\n", - "def level_to_outputs(level:Iterable[int]):\n", - " \"\"\" Converts list of levels into output names matching StatsForecast and NeuralForecast methods.\n", - "\n", - " **Parameters:**
\n", - " `level`: int list [0,100]. Probability levels for prediction intervals.
\n", - "\n", - " **Returns:**
\n", - " `output_names`: str list. String list with output column names.\n", - " \"\"\"\n", - " qs = sum([[50-l/2, 50+l/2] for l in level], [])\n", - " output_names = sum([[f'-lo-{l}', f'-hi-{l}'] for l in level], [])\n", - "\n", - " sort_idx = np.argsort(qs)\n", - " quantiles = np.array(qs)[sort_idx]\n", - "\n", - " # Add default median\n", - " quantiles = np.concatenate([np.array([50]), quantiles]) / 100\n", - " output_names = list(np.array(output_names)[sort_idx])\n", - " output_names.insert(0, '-median')\n", - " \n", - " return quantiles, output_names\n", - "\n", - "# convert quantiles to output quantile names\n", - "def quantiles_to_outputs(quantiles:Iterable[float]):\n", - " \"\"\"Converts list of quantiles into output names matching StatsForecast and NeuralForecast methods.\n", - "\n", - " **Parameters:**
\n", - " `quantiles`: float list [0., 1.]. Alternative to level, quantiles to estimate from y distribution.
\n", - "\n", - " **Returns:**
\n", - " `output_names`: str list. String list with output column names.\n", - " \"\"\"\n", - " output_names = []\n", - " for q in quantiles:\n", - " if q<.50:\n", - " output_names.append(f'-lo-{np.round(100-200*q,2)}')\n", - " elif q>.50:\n", - " output_names.append(f'-hi-{np.round(100-200*(1-q),2)}')\n", - " else:\n", - " output_names.append('-median')\n", - " return quantiles, output_names" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d4ffbe55", - "metadata": {}, - "outputs": [], - "source": [ - "#| exporti\n", - "\n", - "# given input array of sample forecasts and inptut quantiles/levels, \n", - "# output a Pandas Dataframe with columns of quantile predictions\n", - "def samples_to_quantiles_df(samples:np.ndarray, \n", - " unique_ids:Iterable[str], \n", - " dates:Iterable, \n", - " quantiles:Optional[Iterable[float]] = None,\n", - " level:Optional[Iterable[int]] = None, \n", - " model_name:Optional[str] = \"model\"):\n", - " \"\"\" Transform Samples into HierarchicalForecast input.\n", - " Auxiliary function to create compatible HierarchicalForecast input Y_hat_df dataframe.\n", - "\n", - " **Parameters:**
\n", - " `samples`: numpy array. Samples from forecast distribution of shape [n_series, n_samples, horizon].
\n", - " `unique_ids`: string list. Unique identifiers for each time series.
\n", - " `dates`: datetime list. List of forecast dates.
\n", - " `quantiles`: float list in [0., 1.]. Alternative to level, quantiles to estimate from y distribution.
\n", - " `level`: int list in [0,100]. Probability levels for prediction intervals.
\n", - " `model_name`: string. Name of forecasting model.
\n", - "\n", - " **Returns:**
\n", - " `quantiles`: float list in [0., 1.]. quantiles to estimate from y distribution .
\n", - " `Y_hat_df`: pd.DataFrame. With base quantile forecasts with columns ds and models to reconcile indexed by unique_id.\n", - " \"\"\"\n", - " \n", - " # Get the shape of the array\n", - " n_series, n_samples, horizon = samples.shape\n", - "\n", - " assert n_series == len(unique_ids)\n", - " assert horizon == len(dates)\n", - " assert (quantiles is not None) ^ (level is not None) #check exactly one of quantiles/levels has been input\n", - "\n", - " #create initial dictionary\n", - " forecasts_mean = np.mean(samples, axis=1).flatten()\n", - " unique_ids = np.repeat(unique_ids, horizon)\n", - " ds = np.tile(dates, n_series)\n", - " data = pd.DataFrame({\"unique_id\":unique_ids, \"ds\":ds, model_name:forecasts_mean})\n", - "\n", - " #create quantiles and quantile names\n", - " quantiles, quantile_names = level_to_outputs(level) if level is not None else quantiles_to_outputs(quantiles)\n", - " percentiles = [quantile * 100 for quantile in quantiles]\n", - " col_names = np.array([model_name + quantile_name for quantile_name in quantile_names])\n", - " \n", - " #add quantiles to dataframe\n", - " forecasts_quantiles = np.percentile(samples, percentiles, axis=1)\n", - "\n", - " forecasts_quantiles = np.transpose(forecasts_quantiles, (1,2,0)) # [Q,H,N] -> [N,H,Q]\n", - " forecasts_quantiles = forecasts_quantiles.reshape(-1,len(quantiles))\n", - "\n", - " df = pd.DataFrame(data=forecasts_quantiles, \n", - " columns=col_names)\n", - " \n", - " return quantiles, pd.concat([data,df], axis=1).set_index('unique_id')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "940693d0", - "metadata": {}, - "outputs": [], - "source": [ - "show_doc(samples_to_quantiles_df, title_level=3)" - ] - }, - { + "attachments": {}, "cell_type": "markdown", "id": "3a1f4267", "metadata": {}, @@ -408,14 +295,24 @@ " Y_bottom_df = Y_bottom_df.groupby(['unique_id', 'ds'])['y'].sum().reset_index()\n", " Y_bottom_df.unique_id = Y_bottom_df.unique_id.astype('category')\n", " Y_bottom_df.unique_id = Y_bottom_df.unique_id.cat.set_categories(S_df.columns)\n", - " return Y_bottom_df, S_df, tags\n", - "\n", + " return Y_bottom_df, S_df, tags" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b73d1052", + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", "def aggregate(df: pd.DataFrame,\n", " spec: List[List[str]],\n", " is_balanced: bool=False):\n", " \"\"\" Utils Aggregation Function.\n", " Aggregates bottom level series contained in the pd.DataFrame `df` according \n", " to levels defined in the `spec` list applying the `agg_fn` (sum, mean).\n", + "\n", " **Parameters:**
\n", " `df`: pd.DataFrame with columns `['ds', 'y']` and columns to aggregate.
\n", " `spec`: List of levels. Each element of the list contains a list of columns of `df` to aggregate.
\n", @@ -476,12 +373,11 @@ { "cell_type": "code", "execution_count": null, - "id": "f12dc10c-a8e1-41be-8230-7eef645a2550", + "id": "75cea2f7", "metadata": {}, "outputs": [], "source": [ - "#| hide\n", - "from statsforecast.utils import generate_series" + "show_doc(aggregate, title_level=3)" ] }, { @@ -532,80 +428,6 @@ "test_eq(hier_df, hier_df_before)" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "d15e9834", - "metadata": {}, - "outputs": [], - "source": [ - "#| hide\n", - "\n", - "#level_to_outputs unit tests\n", - "test_eq(\n", - " level_to_outputs([80, 90]),\n", - " ([0.5 , 0.05, 0.1 , 0.9 , 0.95], ['-median', '-lo-90', '-lo-80', '-hi-80', '-hi-90'])\n", - ")\n", - "test_eq(\n", - " level_to_outputs([30]),\n", - " ([0.5 , 0.35, 0.65], ['-median', '-lo-30', '-hi-30'])\n", - ")\n", - "#quantiles_to_outputs unit tests\n", - "test_eq(\n", - " quantiles_to_outputs([0.2, 0.4, 0.6, 0.8]),\n", - " ([0.2,0.4,0.6, 0.8], ['-lo-60.0', '-lo-20.0', '-hi-20.0', '-hi-60.0'])\n", - ")\n", - "test_eq(\n", - " quantiles_to_outputs([0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9]),\n", - " ([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], \n", - " ['-lo-80.0', '-lo-60.0', '-lo-40.0', '-lo-20.0', '-median', '-hi-20.0', '-hi-40.0', '-hi-60.0', '-hi-80.0'])\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3c833f6a", - "metadata": {}, - "outputs": [], - "source": [ - "#| hide\n", - "\n", - "#samples_to_quantiles_df unit tests\n", - "start_date = pd.Timestamp(\"2023-06-01\")\n", - "end_date = pd.Timestamp(\"2023-06-10\")\n", - "frequency = \"D\" # Daily frequency\n", - "dates = pd.date_range(start=start_date, end=end_date, freq=frequency).tolist()\n", - "samples = np.random.rand(3, 200, 10)\n", - "unique_ids = ['id1', 'id2', 'id3']\n", - "level = np.array([10, 50, 90])\n", - "quantiles=np.array([0.5, 0.05, 0.25, 0.45, 0.55, 0.75, 0.95])\n", - "\n", - "ret_quantiles_1, ret_df_1 = samples_to_quantiles_df(samples, unique_ids, dates, level=level)\n", - "ret_quantiles_2, ret_df_2 = samples_to_quantiles_df(samples, unique_ids, dates, quantiles=quantiles)\n", - "\n", - "test_eq(\n", - " ret_quantiles_1,\n", - " quantiles\n", - ")\n", - "test_eq(\n", - " ret_df_1.columns,\n", - " ['ds', 'model', 'model-median', 'model-lo-90', 'model-lo-50', 'model-lo-10', 'model-hi-10', 'model-hi-50', 'model-hi-90']\n", - ")\n", - "test_eq(\n", - " ret_df_1.index,\n", - " ['id1', 'id1', 'id1', 'id1', 'id1', 'id1', 'id1', 'id1', 'id1', 'id1',\n", - " 'id2', 'id2', 'id2', 'id2', 'id2', 'id2', 'id2', 'id2', 'id2', 'id2',\n", - " 'id3', 'id3', 'id3', 'id3', 'id3', 'id3', 'id3', 'id3', 'id3', 'id3']\n", - ")\n", - "test_eq(\n", - " ret_quantiles_1, ret_quantiles_2\n", - ")\n", - "test_eq(\n", - " ret_df_1.index, ret_df_2.index\n", - ")" - ] - }, { "cell_type": "code", "execution_count": null, @@ -720,6 +542,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "22febc26-1901-4bef-a181-09ae2f52453b", "metadata": {}, @@ -1120,6 +943,215 @@ " xlabel='Month', ylabel='Predictions',\n", ")" ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "43fcfbc2", + "metadata": {}, + "source": [ + "# External Forecast Adapters " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9c629ab2", + "metadata": {}, + "outputs": [], + "source": [ + "#| exporti\n", + "\n", + "# convert levels to output quantile names\n", + "def level_to_outputs(level:Iterable[int]):\n", + " \"\"\" Converts list of levels into output names matching StatsForecast and NeuralForecast methods.\n", + 
"\n", + " **Parameters:**
\n", + " `level`: int list [0,100]. Probability levels for prediction intervals.
\n", + "\n", + " **Returns:**
\n", + " `output_names`: str list. String list with output column names.\n", + " \"\"\"\n", + " qs = sum([[50-l/2, 50+l/2] for l in level], [])\n", + " output_names = sum([[f'-lo-{l}', f'-hi-{l}'] for l in level], [])\n", + "\n", + " sort_idx = np.argsort(qs)\n", + " quantiles = np.array(qs)[sort_idx]\n", + "\n", + " # Add default median\n", + " quantiles = np.concatenate([np.array([50]), quantiles]) / 100\n", + " output_names = list(np.array(output_names)[sort_idx])\n", + " output_names.insert(0, '-median')\n", + " \n", + " return quantiles, output_names\n", + "\n", + "# convert quantiles to output quantile names\n", + "def quantiles_to_outputs(quantiles:Iterable[float]):\n", + " \"\"\"Converts list of quantiles into output names matching StatsForecast and NeuralForecast methods.\n", + "\n", + " **Parameters:**
\n", + " `quantiles`: float list [0., 1.]. Alternative to level, quantiles to estimate from y distribution.
\n", + "\n", + " **Returns:**
\n", + " `output_names`: str list. String list with output column names.\n", + " \"\"\"\n", + " output_names = []\n", + " for q in quantiles:\n", + " if q<.50:\n", + " output_names.append(f'-lo-{np.round(100-200*q,2)}')\n", + " elif q>.50:\n", + " output_names.append(f'-hi-{np.round(100-200*(1-q),2)}')\n", + " else:\n", + " output_names.append('-median')\n", + " return quantiles, output_names" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "27b52fff", + "metadata": {}, + "outputs": [], + "source": [ + "#| exporti\n", + "\n", + "# given input array of sample forecasts and inptut quantiles/levels, \n", + "# output a Pandas Dataframe with columns of quantile predictions\n", + "def samples_to_quantiles_df(samples:np.ndarray, \n", + " unique_ids:Iterable[str], \n", + " dates:Iterable, \n", + " quantiles:Optional[Iterable[float]] = None,\n", + " level:Optional[Iterable[int]] = None, \n", + " model_name:Optional[str] = \"model\"):\n", + " \"\"\" Transform Random Samples into HierarchicalForecast input.\n", + " Auxiliary function to create compatible HierarchicalForecast input `Y_hat_df` dataframe.\n", + "\n", + " **Parameters:**
\n", + " `samples`: numpy array. Samples from forecast distribution of shape [n_series, n_samples, horizon].
\n", + " `unique_ids`: string list. Unique identifiers for each time series.
\n", + " `dates`: datetime list. List of forecast dates.
\n", + " `quantiles`: float list in [0., 1.]. Alternative to level, quantiles to estimate from y distribution.
\n", + " `level`: int list in [0,100]. Probability levels for prediction intervals.
\n", + " `model_name`: string. Name of forecasting model.
\n", + "\n", + " **Returns:**
\n", + " `quantiles`: float list in [0., 1.]. quantiles to estimate from y distribution .
\n", + " `Y_hat_df`: pd.DataFrame. With base quantile forecasts with columns ds and models to reconcile indexed by unique_id.\n", + " \"\"\"\n", + " \n", + " # Get the shape of the array\n", + " n_series, n_samples, horizon = samples.shape\n", + "\n", + " assert n_series == len(unique_ids)\n", + " assert horizon == len(dates)\n", + " assert (quantiles is not None) ^ (level is not None) #check exactly one of quantiles/levels has been input\n", + "\n", + " #create initial dictionary\n", + " forecasts_mean = np.mean(samples, axis=1).flatten()\n", + " unique_ids = np.repeat(unique_ids, horizon)\n", + " ds = np.tile(dates, n_series)\n", + " data = pd.DataFrame({\"unique_id\":unique_ids, \"ds\":ds, model_name:forecasts_mean})\n", + "\n", + " #create quantiles and quantile names\n", + " quantiles, quantile_names = level_to_outputs(level) if level is not None else quantiles_to_outputs(quantiles)\n", + " percentiles = [quantile * 100 for quantile in quantiles]\n", + " col_names = np.array([model_name + quantile_name for quantile_name in quantile_names])\n", + " \n", + " #add quantiles to dataframe\n", + " forecasts_quantiles = np.percentile(samples, percentiles, axis=1)\n", + "\n", + " forecasts_quantiles = np.transpose(forecasts_quantiles, (1,2,0)) # [Q,H,N] -> [N,H,Q]\n", + " forecasts_quantiles = forecasts_quantiles.reshape(-1,len(quantiles))\n", + "\n", + " df = pd.DataFrame(data=forecasts_quantiles, \n", + " columns=col_names)\n", + " \n", + " return quantiles, pd.concat([data,df], axis=1).set_index('unique_id')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "18e4fcc4", + "metadata": {}, + "outputs": [], + "source": [ + "show_doc(samples_to_quantiles_df, title_level=3)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "500ad055", + "metadata": {}, + "outputs": [], + "source": [ + "#| hide\n", + "\n", + "#level_to_outputs unit tests\n", + "test_eq(\n", + " level_to_outputs([80, 90]),\n", + " ([0.5 , 0.05, 0.1 , 0.9 , 0.95], ['-median', '-lo-90', '-lo-80', '-hi-80', '-hi-90'])\n", + ")\n", + "test_eq(\n", + " level_to_outputs([30]),\n", + " ([0.5 , 0.35, 0.65], ['-median', '-lo-30', '-hi-30'])\n", + ")\n", + "#quantiles_to_outputs unit tests\n", + "test_eq(\n", + " quantiles_to_outputs([0.2, 0.4, 0.6, 0.8]),\n", + " ([0.2,0.4,0.6, 0.8], ['-lo-60.0', '-lo-20.0', '-hi-20.0', '-hi-60.0'])\n", + ")\n", + "test_eq(\n", + " quantiles_to_outputs([0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9]),\n", + " ([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], \n", + " ['-lo-80.0', '-lo-60.0', '-lo-40.0', '-lo-20.0', '-median', '-hi-20.0', '-hi-40.0', '-hi-60.0', '-hi-80.0'])\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "49f4f2ce", + "metadata": {}, + "outputs": [], + "source": [ + "#| hide\n", + "\n", + "#samples_to_quantiles_df unit tests\n", + "start_date = pd.Timestamp(\"2023-06-01\")\n", + "end_date = pd.Timestamp(\"2023-06-10\")\n", + "frequency = \"D\" # Daily frequency\n", + "dates = pd.date_range(start=start_date, end=end_date, freq=frequency).tolist()\n", + "samples = np.random.rand(3, 200, 10)\n", + "unique_ids = ['id1', 'id2', 'id3']\n", + "level = np.array([10, 50, 90])\n", + "quantiles=np.array([0.5, 0.05, 0.25, 0.45, 0.55, 0.75, 0.95])\n", + "\n", + "ret_quantiles_1, ret_df_1 = samples_to_quantiles_df(samples, unique_ids, dates, level=level)\n", + "ret_quantiles_2, ret_df_2 = samples_to_quantiles_df(samples, unique_ids, dates, quantiles=quantiles)\n", + "\n", + "test_eq(\n", + " ret_quantiles_1,\n", + " quantiles\n", + ")\n", 
+ "test_eq(\n", + " ret_df_1.columns,\n", + " ['ds', 'model', 'model-median', 'model-lo-90', 'model-lo-50', 'model-lo-10', 'model-hi-10', 'model-hi-50', 'model-hi-90']\n", + ")\n", + "test_eq(\n", + " ret_df_1.index,\n", + " ['id1', 'id1', 'id1', 'id1', 'id1', 'id1', 'id1', 'id1', 'id1', 'id1',\n", + " 'id2', 'id2', 'id2', 'id2', 'id2', 'id2', 'id2', 'id2', 'id2', 'id2',\n", + " 'id3', 'id3', 'id3', 'id3', 'id3', 'id3', 'id3', 'id3', 'id3', 'id3']\n", + ")\n", + "test_eq(\n", + " ret_quantiles_1, ret_quantiles_2\n", + ")\n", + "test_eq(\n", + " ret_df_1.index, ret_df_2.index\n", + ")" + ] } ], "metadata": {