From e32b0a72c8b0cfeba763b52b56bfad5308deb3f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Morales?= Date: Tue, 9 Apr 2024 14:11:22 -0600 Subject: [PATCH] fix scaled_crps for pandas (#74) --- nbs/losses.ipynb | 1 + settings.ini | 2 +- utilsforecast/__init__.py | 2 +- utilsforecast/losses.py | 1 + 4 files changed, 4 insertions(+), 2 deletions(-) diff --git a/nbs/losses.ipynb b/nbs/losses.ipynb index 5b51bb4..6bc3be9 100644 --- a/nbs/losses.ipynb +++ b/nbs/losses.ipynb @@ -2056,6 +2056,7 @@ " sizes = ufp.counts_by_id(df, id_col)\n", " if isinstance(loss, pd.DataFrame):\n", " loss = loss.set_index(id_col)\n", + " sizes = sizes.set_index(id_col)\n", " assert isinstance(df, pd.DataFrame)\n", " norm = df[target_col].abs().groupby(df[id_col], observed=True).sum()\n", " res = 2 * loss.mul(sizes['counts'], axis=0).div(norm + eps, axis=0)\n", diff --git a/settings.ini b/settings.ini index 06ebad7..814b89b 100644 --- a/settings.ini +++ b/settings.ini @@ -1,7 +1,7 @@ [DEFAULT] repo = utilsforecast lib_name = utilsforecast -version = 0.1.2 +version = 0.1.3 min_python = 3.8 license = apache2 black_formatting = True diff --git a/utilsforecast/__init__.py b/utilsforecast/__init__.py index b3f4756..ae73625 100644 --- a/utilsforecast/__init__.py +++ b/utilsforecast/__init__.py @@ -1 +1 @@ -__version__ = "0.1.2" +__version__ = "0.1.3" diff --git a/utilsforecast/losses.py b/utilsforecast/losses.py index 8d34ea2..2b80d7f 100644 --- a/utilsforecast/losses.py +++ b/utilsforecast/losses.py @@ -627,6 +627,7 @@ def scaled_crps( sizes = ufp.counts_by_id(df, id_col) if isinstance(loss, pd.DataFrame): loss = loss.set_index(id_col) + sizes = sizes.set_index(id_col) assert isinstance(df, pd.DataFrame) norm = df[target_col].abs().groupby(df[id_col], observed=True).sum() res = 2 * loss.mul(sizes["counts"], axis=0).div(norm + eps, axis=0)