From e6eeb109e7f59324c5ce07f55d58cbab64fc99cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Morales?= Date: Fri, 26 Jul 2024 12:18:37 -0600 Subject: [PATCH] accept plot object in plot_series (#109) --- nbs/plotting.ipynb | 151 ++++++++++++++++++++++++-------------- utilsforecast/plotting.py | 132 ++++++++++++++++++++------------- 2 files changed, 176 insertions(+), 107 deletions(-) diff --git a/nbs/plotting.ipynb b/nbs/plotting.ipynb index e56e7d0..65e757e 100644 --- a/nbs/plotting.ipynb +++ b/nbs/plotting.ipynb @@ -47,7 +47,7 @@ "source": [ "#| export\n", "import re\n", - "from typing import Dict, List, Optional, Union\n", + "from typing import TYPE_CHECKING, Dict, List, Optional, Union\n", "\n", "try:\n", " import matplotlib as mpl\n", @@ -60,6 +60,8 @@ " )\n", "import numpy as np\n", "import pandas as pd\n", + "if TYPE_CHECKING:\n", + " import plotly\n", "from packaging.version import Version, parse as parse_version\n", "\n", "import utilsforecast.processing as ufp\n", @@ -115,6 +117,7 @@ " target_col: str = 'y',\n", " seed: int = 0,\n", " resampler_kwargs: Optional[Dict] = None,\n", + " ax: Optional[Union[plt.Axes, 'plotly.graph_objects.Figure']] = None,\n", "):\n", " \"\"\"Plot forecasts and insample values.\n", "\n", @@ -156,6 +159,8 @@ " For further custumization (\"show_dash\") call the method,\n", " store the plotting object and add the extra arguments to\n", " its `show_dash` method.\n", + " ax : matplotlib axes or plotly Figure, optional (default=None)\n", + " Object where plots will be added.\n", "\n", " Returns\n", " -------\n", @@ -166,6 +171,15 @@ " supported_engines = ['matplotlib', 'plotly', 'plotly-resampler']\n", " if engine not in supported_engines:\n", " raise ValueError(f\"engine must be one of {supported_engines}, got '{engine}'.\")\n", + " if engine.startswith('plotly'):\n", + " try:\n", + " import plotly.graph_objects as go\n", + " from plotly.subplots import make_subplots\n", + " except ImportError:\n", + " raise ImportError(\n", + " \"plotly is not installed. Please install it and try again.\\n\"\n", + " \"You can find detailed instructions at https://github.com/plotly/plotly.py#installation\"\n", + " )\n", " if plot_anomalies:\n", " if level is None:\n", " raise ValueError('In order to plot anomalies you have to specify the `level` argument')\n", @@ -205,11 +219,32 @@ " uids = forecasts_df[id_col].unique()\n", " else:\n", " uids = ids\n", + " if ax is not None:\n", + " if isinstance(ax, np.ndarray) and isinstance(ax.flat[0], plt.Axes):\n", + " gs = ax.flat[0].get_gridspec()\n", + " n_rows, n_cols = gs.nrows, gs.ncols\n", + " ax = ax.reshape(n_rows, n_cols)\n", + " elif engine.startswith('plotly') and isinstance(ax, go.Figure):\n", + " rows, cols = ax._get_subplot_rows_columns()\n", + " # rows and cols are ranges\n", + " n_rows = len(rows)\n", + " n_cols = len(cols)\n", + " else:\n", + " raise ValueError(f'Cannot process `ax` of type: {type(ax).__name__}.')\n", + " max_ids = n_rows * n_cols\n", " if len(uids) > max_ids and plot_random:\n", " rng = np.random.RandomState(seed)\n", " uids = rng.choice(uids, size=max_ids, replace=False)\n", " else:\n", " uids = uids[:max_ids]\n", + " n_series = len(uids)\n", + " if ax is None:\n", + " if n_series == 1:\n", + " n_cols = 1 \n", + " else:\n", + " n_cols = 2\n", + " quot, resid = divmod(n_series, n_cols)\n", + " n_rows = quot + resid\n", "\n", " # filtering\n", " if df is not None:\n", @@ -238,14 +273,6 @@ " else:\n", " df = pl.concat([df, forecasts_df], how='align')\n", "\n", - " # common setup\n", - " n_series = len(uids)\n", - " if n_series == 1:\n", - " n_cols = 1 \n", - " else:\n", - " n_cols = 2\n", - " quot, resid = divmod(n_series, n_cols)\n", - " n_rows = quot + resid\n", " xlabel = f'Time [{time_col}]'\n", " ylabel = f'Target [{target_col}]'\n", " if palette is not None:\n", @@ -265,42 +292,42 @@ " colors = [cm.to_hex(color) for color in rgb_colors] \n", "\n", " # define plot grid\n", - " if engine.startswith('plotly'):\n", - " try:\n", - " import plotly.graph_objects as go\n", - " from plotly.subplots import make_subplots\n", - " except ImportError:\n", - " raise ImportError(\n", - " \"plotly is not installed. Please install it and try again.\\n\"\n", - " \"You can find detailed instructions at https://github.com/plotly/plotly.py#installation\"\n", + " if ax is None:\n", + " postprocess = True\n", + " if engine.startswith('plotly'):\n", + " fig = make_subplots(\n", + " rows=n_rows,\n", + " cols=n_cols,\n", + " vertical_spacing=0.15,\n", + " horizontal_spacing=0.07,\n", + " x_title=xlabel,\n", + " y_title=ylabel,\n", + " subplot_titles=[f'{id_col}={uid}' for uid in uids],\n", + " )\n", + " if engine == \"plotly-resampler\":\n", + " try:\n", + " from plotly_resampler import FigureResampler\n", + " except ImportError:\n", + " raise ImportError(\n", + " \"The 'plotly-resampler' package is required \"\n", + " \"when `engine='plotly-resampler'`.\"\n", + " )\n", + " resampler_kwargs = {} if resampler_kwargs is None else resampler_kwargs\n", + " fig = FigureResampler(fig, **resampler_kwargs)\n", + " else:\n", + " fig, ax = plt.subplots(\n", + " nrows=n_rows,\n", + " ncols=n_cols,\n", + " figsize=(16, 3.5 * n_rows),\n", + " squeeze=False,\n", + " constrained_layout=True\n", " )\n", - " fig = make_subplots(\n", - " rows=n_rows,\n", - " cols=n_cols,\n", - " vertical_spacing=0.15,\n", - " horizontal_spacing=0.07,\n", - " x_title=xlabel,\n", - " y_title=ylabel,\n", - " subplot_titles=[f'{id_col}={uid}' for uid in uids],\n", - " )\n", - " if engine == \"plotly-resampler\":\n", - " try:\n", - " from plotly_resampler import FigureResampler\n", - " except ImportError:\n", - " raise ImportError(\n", - " \"plotly-resampler is not installed.\\n\"\n", - " \"Please install it with `pip install plotly-resampler` or `conda install -c conda-forge plotly-resampler`\"\n", - " )\n", - " resampler_kwargs = {} if resampler_kwargs is None else resampler_kwargs\n", - " fig = FigureResampler(fig, **resampler_kwargs)\n", " else:\n", - " fig, ax = plt.subplots(\n", - " nrows=n_rows,\n", - " ncols=n_cols,\n", - " figsize=(16, 3.5 * n_rows),\n", - " squeeze=False,\n", - " constrained_layout=True\n", - " )\n", + " postprocess = False\n", + " if engine.startswith('plotly'):\n", + " fig = ax\n", + " else:\n", + " fig = plt.gcf()\n", "\n", " def _add_mpl_plot(axi, df, y_col, levels):\n", " axi.plot(df[time_col], df[y_col], label=y_col, color=color)\n", @@ -403,12 +430,13 @@ " uid_df = ufp.filter_with_mask(df, mask)\n", " row, col = divmod(i, n_cols)\n", " for y_col, color in zip([target_col] + models, colors):\n", - " if engine == 'matplotlib':\n", + " if isinstance(ax, np.ndarray):\n", " _add_mpl_plot(ax[row, col], uid_df, y_col, level)\n", " else:\n", " _add_plotly_plot(fig, uid_df, y_col, level)\n", - " if engine == 'matplotlib':\n", - " ax[row, col].set_title(f\"{id_col}={uid}\")\n", + " title = f\"{id_col}={uid}\"\n", + " if isinstance(ax, np.ndarray):\n", + " ax[row, col].set_title(title)\n", " if col == 0:\n", " ax[row, col].set_ylabel(ylabel)\n", " if row == n_rows - 1:\n", @@ -421,13 +449,15 @@ " labels=xticklabels,\n", " ha=\"right\",\n", " )\n", + " else:\n", + " fig.update_annotations(selector={\"text\": str(i)}, text=title)\n", "\n", - " if engine == 'matplotlib':\n", + " if isinstance(ax, np.ndarray):\n", " handles, labels = ax[0, 0].get_legend_handles_labels()\n", " fig.legend(\n", - " handles, \n", - " labels, \n", - " loc='upper left', \n", + " handles,\n", + " labels,\n", + " loc='upper left',\n", " bbox_to_anchor=(1.01, 0.97),\n", " )\n", " plt.close(fig)\n", @@ -436,10 +466,11 @@ " axi.set_axis_off()\n", " else:\n", " fig.update_xaxes(matches=None, showticklabels=True, visible=True)\n", - " fig.update_layout(margin=dict(l=60, r=10, t=20, b=50))\n", - " fig.update_layout(template=\"plotly_white\", font=dict(size=10))\n", " fig.update_annotations(font_size=10)\n", - " fig.update_layout(autosize=True, height=200 * n_rows)\n", + " if postprocess:\n", + " fig.update_layout(margin=dict(l=60, r=10, t=20, b=50))\n", + " fig.update_layout(template=\"plotly_white\", font=dict(size=10))\n", + " fig.update_layout(autosize=True, height=200 * n_rows)\n", " return fig" ] }, @@ -453,7 +484,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/Nixtla/utilsforecast/blob/main/utilsforecast/plotting.py#L46){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/Nixtla/utilsforecast/blob/main/utilsforecast/plotting.py#L45){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### plot_series\n", "\n", @@ -467,7 +498,9 @@ "> plot_anomalies:bool=False, engine:str='matplotlib',\n", "> palette:Optional[str]=None, id_col:str='unique_id',\n", "> time_col:str='ds', target_col:str='y', seed:int=0,\n", - "> resampler_kwargs:Optional[Dict]=None)\n", + "> resampler_kwargs:Optional[Dict]=None, ax:Union[matplotlib.ax\n", + "> es._axes.Axes,ForwardRef('plotly.graph_objects.Figure'),None\n", + "> Type]=None)\n", "\n", "*Plot forecasts and insample values.*\n", "\n", @@ -489,12 +522,13 @@ "| target_col | str | y | Column that contains the target. |\n", "| seed | int | 0 | Seed used for the random number generator. Only used if plot_random is True. |\n", "| resampler_kwargs | Optional | None | Keyword arguments to be passed to plotly-resampler constructor.
For further custumization (\"show_dash\") call the method,
store the plotting object and add the extra arguments to
its `show_dash` method. |\n", + "| ax | Union | None | Object where plots will be added. |\n", "| **Returns** | **matplotlib or plotly figure** | | **Plot's figure** |" ], "text/plain": [ "---\n", "\n", - "[source](https://github.com/Nixtla/utilsforecast/blob/main/utilsforecast/plotting.py#L46){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/Nixtla/utilsforecast/blob/main/utilsforecast/plotting.py#L45){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### plot_series\n", "\n", @@ -508,7 +542,9 @@ "> plot_anomalies:bool=False, engine:str='matplotlib',\n", "> palette:Optional[str]=None, id_col:str='unique_id',\n", "> time_col:str='ds', target_col:str='y', seed:int=0,\n", - "> resampler_kwargs:Optional[Dict]=None)\n", + "> resampler_kwargs:Optional[Dict]=None, ax:Union[matplotlib.ax\n", + "> es._axes.Axes,ForwardRef('plotly.graph_objects.Figure'),None\n", + "> Type]=None)\n", "\n", "*Plot forecasts and insample values.*\n", "\n", @@ -530,6 +566,7 @@ "| target_col | str | y | Column that contains the target. |\n", "| seed | int | 0 | Seed used for the random number generator. Only used if plot_random is True. |\n", "| resampler_kwargs | Optional | None | Keyword arguments to be passed to plotly-resampler constructor.
For further custumization (\"show_dash\") call the method,
store the plotting object and add the extra arguments to
its `show_dash` method. |\n", + "| ax | Union | None | Object where plots will be added. |\n", "| **Returns** | **matplotlib or plotly figure** | | **Plot's figure** |" ] }, diff --git a/utilsforecast/plotting.py b/utilsforecast/plotting.py index 8a0fc6e..b420eae 100644 --- a/utilsforecast/plotting.py +++ b/utilsforecast/plotting.py @@ -5,7 +5,7 @@ # %% ../nbs/plotting.ipynb 4 import re -from typing import Dict, List, Optional, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Union try: import matplotlib as mpl @@ -18,6 +18,9 @@ ) import numpy as np import pandas as pd + +if TYPE_CHECKING: + import plotly from packaging.version import Version, parse as parse_version import utilsforecast.processing as ufp @@ -59,6 +62,7 @@ def plot_series( target_col: str = "y", seed: int = 0, resampler_kwargs: Optional[Dict] = None, + ax: Optional[Union[plt.Axes, "plotly.graph_objects.Figure"]] = None, ): """Plot forecasts and insample values. @@ -100,6 +104,8 @@ def plot_series( For further custumization ("show_dash") call the method, store the plotting object and add the extra arguments to its `show_dash` method. + ax : matplotlib axes or plotly Figure, optional (default=None) + Object where plots will be added. Returns ------- @@ -110,6 +116,15 @@ def plot_series( supported_engines = ["matplotlib", "plotly", "plotly-resampler"] if engine not in supported_engines: raise ValueError(f"engine must be one of {supported_engines}, got '{engine}'.") + if engine.startswith("plotly"): + try: + import plotly.graph_objects as go + from plotly.subplots import make_subplots + except ImportError: + raise ImportError( + "plotly is not installed. Please install it and try again.\n" + "You can find detailed instructions at https://github.com/plotly/plotly.py#installation" + ) if plot_anomalies: if level is None: raise ValueError( @@ -154,11 +169,32 @@ def plot_series( uids = forecasts_df[id_col].unique() else: uids = ids + if ax is not None: + if isinstance(ax, np.ndarray) and isinstance(ax.flat[0], plt.Axes): + gs = ax.flat[0].get_gridspec() + n_rows, n_cols = gs.nrows, gs.ncols + ax = ax.reshape(n_rows, n_cols) + elif engine.startswith("plotly") and isinstance(ax, go.Figure): + rows, cols = ax._get_subplot_rows_columns() + # rows and cols are ranges + n_rows = len(rows) + n_cols = len(cols) + else: + raise ValueError(f"Cannot process `ax` of type: {type(ax).__name__}.") + max_ids = n_rows * n_cols if len(uids) > max_ids and plot_random: rng = np.random.RandomState(seed) uids = rng.choice(uids, size=max_ids, replace=False) else: uids = uids[:max_ids] + n_series = len(uids) + if ax is None: + if n_series == 1: + n_cols = 1 + else: + n_cols = 2 + quot, resid = divmod(n_series, n_cols) + n_rows = quot + resid # filtering if df is not None: @@ -187,14 +223,6 @@ def plot_series( else: df = pl.concat([df, forecasts_df], how="align") - # common setup - n_series = len(uids) - if n_series == 1: - n_cols = 1 - else: - n_cols = 2 - quot, resid = divmod(n_series, n_cols) - n_rows = quot + resid xlabel = f"Time [{time_col}]" ylabel = f"Target [{target_col}]" if palette is not None: @@ -212,42 +240,42 @@ def plot_series( colors = [cm.to_hex(color) for color in rgb_colors] # define plot grid - if engine.startswith("plotly"): - try: - import plotly.graph_objects as go - from plotly.subplots import make_subplots - except ImportError: - raise ImportError( - "plotly is not installed. Please install it and try again.\n" - "You can find detailed instructions at https://github.com/plotly/plotly.py#installation" + if ax is None: + postprocess = True + if engine.startswith("plotly"): + fig = make_subplots( + rows=n_rows, + cols=n_cols, + vertical_spacing=0.15, + horizontal_spacing=0.07, + x_title=xlabel, + y_title=ylabel, + subplot_titles=[f"{id_col}={uid}" for uid in uids], + ) + if engine == "plotly-resampler": + try: + from plotly_resampler import FigureResampler + except ImportError: + raise ImportError( + "The 'plotly-resampler' package is required " + "when `engine='plotly-resampler'`." + ) + resampler_kwargs = {} if resampler_kwargs is None else resampler_kwargs + fig = FigureResampler(fig, **resampler_kwargs) + else: + fig, ax = plt.subplots( + nrows=n_rows, + ncols=n_cols, + figsize=(16, 3.5 * n_rows), + squeeze=False, + constrained_layout=True, ) - fig = make_subplots( - rows=n_rows, - cols=n_cols, - vertical_spacing=0.15, - horizontal_spacing=0.07, - x_title=xlabel, - y_title=ylabel, - subplot_titles=[f"{id_col}={uid}" for uid in uids], - ) - if engine == "plotly-resampler": - try: - from plotly_resampler import FigureResampler - except ImportError: - raise ImportError( - "plotly-resampler is not installed.\n" - "Please install it with `pip install plotly-resampler` or `conda install -c conda-forge plotly-resampler`" - ) - resampler_kwargs = {} if resampler_kwargs is None else resampler_kwargs - fig = FigureResampler(fig, **resampler_kwargs) else: - fig, ax = plt.subplots( - nrows=n_rows, - ncols=n_cols, - figsize=(16, 3.5 * n_rows), - squeeze=False, - constrained_layout=True, - ) + postprocess = False + if engine.startswith("plotly"): + fig = ax + else: + fig = plt.gcf() def _add_mpl_plot(axi, df, y_col, levels): axi.plot(df[time_col], df[y_col], label=y_col, color=color) @@ -348,12 +376,13 @@ def _add_plotly_plot(fig, df, y_col, levels): uid_df = ufp.filter_with_mask(df, mask) row, col = divmod(i, n_cols) for y_col, color in zip([target_col] + models, colors): - if engine == "matplotlib": + if isinstance(ax, np.ndarray): _add_mpl_plot(ax[row, col], uid_df, y_col, level) else: _add_plotly_plot(fig, uid_df, y_col, level) - if engine == "matplotlib": - ax[row, col].set_title(f"{id_col}={uid}") + title = f"{id_col}={uid}" + if isinstance(ax, np.ndarray): + ax[row, col].set_title(title) if col == 0: ax[row, col].set_ylabel(ylabel) if row == n_rows - 1: @@ -366,8 +395,10 @@ def _add_plotly_plot(fig, df, y_col, levels): labels=xticklabels, ha="right", ) + else: + fig.update_annotations(selector={"text": str(i)}, text=title) - if engine == "matplotlib": + if isinstance(ax, np.ndarray): handles, labels = ax[0, 0].get_legend_handles_labels() fig.legend( handles, @@ -381,8 +412,9 @@ def _add_plotly_plot(fig, df, y_col, levels): axi.set_axis_off() else: fig.update_xaxes(matches=None, showticklabels=True, visible=True) - fig.update_layout(margin=dict(l=60, r=10, t=20, b=50)) - fig.update_layout(template="plotly_white", font=dict(size=10)) fig.update_annotations(font_size=10) - fig.update_layout(autosize=True, height=200 * n_rows) + if postprocess: + fig.update_layout(margin=dict(l=60, r=10, t=20, b=50)) + fig.update_layout(template="plotly_white", font=dict(size=10)) + fig.update_layout(autosize=True, height=200 * n_rows) return fig