Skip to content

Commit

Permalink
accept plot object in plot_series (#109)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmoralez committed Jul 26, 2024
1 parent 11871b9 commit e6eeb10
Show file tree
Hide file tree
Showing 2 changed files with 176 additions and 107 deletions.
151 changes: 94 additions & 57 deletions nbs/plotting.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
"source": [
"#| export\n",
"import re\n",
"from typing import Dict, List, Optional, Union\n",
"from typing import TYPE_CHECKING, Dict, List, Optional, Union\n",
"\n",
"try:\n",
" import matplotlib as mpl\n",
Expand All @@ -60,6 +60,8 @@
" )\n",
"import numpy as np\n",
"import pandas as pd\n",
"if TYPE_CHECKING:\n",
" import plotly\n",
"from packaging.version import Version, parse as parse_version\n",
"\n",
"import utilsforecast.processing as ufp\n",
Expand Down Expand Up @@ -115,6 +117,7 @@
" target_col: str = 'y',\n",
" seed: int = 0,\n",
" resampler_kwargs: Optional[Dict] = None,\n",
" ax: Optional[Union[plt.Axes, 'plotly.graph_objects.Figure']] = None,\n",
"):\n",
" \"\"\"Plot forecasts and insample values.\n",
"\n",
Expand Down Expand Up @@ -156,6 +159,8 @@
" For further custumization (\"show_dash\") call the method,\n",
" store the plotting object and add the extra arguments to\n",
" its `show_dash` method.\n",
" ax : matplotlib axes or plotly Figure, optional (default=None)\n",
" Object where plots will be added.\n",
"\n",
" Returns\n",
" -------\n",
Expand All @@ -166,6 +171,15 @@
" supported_engines = ['matplotlib', 'plotly', 'plotly-resampler']\n",
" if engine not in supported_engines:\n",
" raise ValueError(f\"engine must be one of {supported_engines}, got '{engine}'.\")\n",
" if engine.startswith('plotly'):\n",
" try:\n",
" import plotly.graph_objects as go\n",
" from plotly.subplots import make_subplots\n",
" except ImportError:\n",
" raise ImportError(\n",
" \"plotly is not installed. Please install it and try again.\\n\"\n",
" \"You can find detailed instructions at https://github.com/plotly/plotly.py#installation\"\n",
" )\n",
" if plot_anomalies:\n",
" if level is None:\n",
" raise ValueError('In order to plot anomalies you have to specify the `level` argument')\n",
Expand Down Expand Up @@ -205,11 +219,32 @@
" uids = forecasts_df[id_col].unique()\n",
" else:\n",
" uids = ids\n",
" if ax is not None:\n",
" if isinstance(ax, np.ndarray) and isinstance(ax.flat[0], plt.Axes):\n",
" gs = ax.flat[0].get_gridspec()\n",
" n_rows, n_cols = gs.nrows, gs.ncols\n",
" ax = ax.reshape(n_rows, n_cols)\n",
" elif engine.startswith('plotly') and isinstance(ax, go.Figure):\n",
" rows, cols = ax._get_subplot_rows_columns()\n",
" # rows and cols are ranges\n",
" n_rows = len(rows)\n",
" n_cols = len(cols)\n",
" else:\n",
" raise ValueError(f'Cannot process `ax` of type: {type(ax).__name__}.')\n",
" max_ids = n_rows * n_cols\n",
" if len(uids) > max_ids and plot_random:\n",
" rng = np.random.RandomState(seed)\n",
" uids = rng.choice(uids, size=max_ids, replace=False)\n",
" else:\n",
" uids = uids[:max_ids]\n",
" n_series = len(uids)\n",
" if ax is None:\n",
" if n_series == 1:\n",
" n_cols = 1 \n",
" else:\n",
" n_cols = 2\n",
" quot, resid = divmod(n_series, n_cols)\n",
" n_rows = quot + resid\n",
"\n",
" # filtering\n",
" if df is not None:\n",
Expand Down Expand Up @@ -238,14 +273,6 @@
" else:\n",
" df = pl.concat([df, forecasts_df], how='align')\n",
"\n",
" # common setup\n",
" n_series = len(uids)\n",
" if n_series == 1:\n",
" n_cols = 1 \n",
" else:\n",
" n_cols = 2\n",
" quot, resid = divmod(n_series, n_cols)\n",
" n_rows = quot + resid\n",
" xlabel = f'Time [{time_col}]'\n",
" ylabel = f'Target [{target_col}]'\n",
" if palette is not None:\n",
Expand All @@ -265,42 +292,42 @@
" colors = [cm.to_hex(color) for color in rgb_colors] \n",
"\n",
" # define plot grid\n",
" if engine.startswith('plotly'):\n",
" try:\n",
" import plotly.graph_objects as go\n",
" from plotly.subplots import make_subplots\n",
" except ImportError:\n",
" raise ImportError(\n",
" \"plotly is not installed. Please install it and try again.\\n\"\n",
" \"You can find detailed instructions at https://github.com/plotly/plotly.py#installation\"\n",
" if ax is None:\n",
" postprocess = True\n",
" if engine.startswith('plotly'):\n",
" fig = make_subplots(\n",
" rows=n_rows,\n",
" cols=n_cols,\n",
" vertical_spacing=0.15,\n",
" horizontal_spacing=0.07,\n",
" x_title=xlabel,\n",
" y_title=ylabel,\n",
" subplot_titles=[f'{id_col}={uid}' for uid in uids],\n",
" )\n",
" if engine == \"plotly-resampler\":\n",
" try:\n",
" from plotly_resampler import FigureResampler\n",
" except ImportError:\n",
" raise ImportError(\n",
" \"The 'plotly-resampler' package is required \"\n",
" \"when `engine='plotly-resampler'`.\"\n",
" )\n",
" resampler_kwargs = {} if resampler_kwargs is None else resampler_kwargs\n",
" fig = FigureResampler(fig, **resampler_kwargs)\n",
" else:\n",
" fig, ax = plt.subplots(\n",
" nrows=n_rows,\n",
" ncols=n_cols,\n",
" figsize=(16, 3.5 * n_rows),\n",
" squeeze=False,\n",
" constrained_layout=True\n",
" )\n",
" fig = make_subplots(\n",
" rows=n_rows,\n",
" cols=n_cols,\n",
" vertical_spacing=0.15,\n",
" horizontal_spacing=0.07,\n",
" x_title=xlabel,\n",
" y_title=ylabel,\n",
" subplot_titles=[f'{id_col}={uid}' for uid in uids],\n",
" )\n",
" if engine == \"plotly-resampler\":\n",
" try:\n",
" from plotly_resampler import FigureResampler\n",
" except ImportError:\n",
" raise ImportError(\n",
" \"plotly-resampler is not installed.\\n\"\n",
" \"Please install it with `pip install plotly-resampler` or `conda install -c conda-forge plotly-resampler`\"\n",
" )\n",
" resampler_kwargs = {} if resampler_kwargs is None else resampler_kwargs\n",
" fig = FigureResampler(fig, **resampler_kwargs)\n",
" else:\n",
" fig, ax = plt.subplots(\n",
" nrows=n_rows,\n",
" ncols=n_cols,\n",
" figsize=(16, 3.5 * n_rows),\n",
" squeeze=False,\n",
" constrained_layout=True\n",
" )\n",
" postprocess = False\n",
" if engine.startswith('plotly'):\n",
" fig = ax\n",
" else:\n",
" fig = plt.gcf()\n",
"\n",
" def _add_mpl_plot(axi, df, y_col, levels):\n",
" axi.plot(df[time_col], df[y_col], label=y_col, color=color)\n",
Expand Down Expand Up @@ -403,12 +430,13 @@
" uid_df = ufp.filter_with_mask(df, mask)\n",
" row, col = divmod(i, n_cols)\n",
" for y_col, color in zip([target_col] + models, colors):\n",
" if engine == 'matplotlib':\n",
" if isinstance(ax, np.ndarray):\n",
" _add_mpl_plot(ax[row, col], uid_df, y_col, level)\n",
" else:\n",
" _add_plotly_plot(fig, uid_df, y_col, level)\n",
" if engine == 'matplotlib':\n",
" ax[row, col].set_title(f\"{id_col}={uid}\")\n",
" title = f\"{id_col}={uid}\"\n",
" if isinstance(ax, np.ndarray):\n",
" ax[row, col].set_title(title)\n",
" if col == 0:\n",
" ax[row, col].set_ylabel(ylabel)\n",
" if row == n_rows - 1:\n",
Expand All @@ -421,13 +449,15 @@
" labels=xticklabels,\n",
" ha=\"right\",\n",
" )\n",
" else:\n",
" fig.update_annotations(selector={\"text\": str(i)}, text=title)\n",
"\n",
" if engine == 'matplotlib':\n",
" if isinstance(ax, np.ndarray):\n",
" handles, labels = ax[0, 0].get_legend_handles_labels()\n",
" fig.legend(\n",
" handles, \n",
" labels, \n",
" loc='upper left', \n",
" handles,\n",
" labels,\n",
" loc='upper left',\n",
" bbox_to_anchor=(1.01, 0.97),\n",
" )\n",
" plt.close(fig)\n",
Expand All @@ -436,10 +466,11 @@
" axi.set_axis_off()\n",
" else:\n",
" fig.update_xaxes(matches=None, showticklabels=True, visible=True)\n",
" fig.update_layout(margin=dict(l=60, r=10, t=20, b=50))\n",
" fig.update_layout(template=\"plotly_white\", font=dict(size=10))\n",
" fig.update_annotations(font_size=10)\n",
" fig.update_layout(autosize=True, height=200 * n_rows)\n",
" if postprocess:\n",
" fig.update_layout(margin=dict(l=60, r=10, t=20, b=50))\n",
" fig.update_layout(template=\"plotly_white\", font=dict(size=10))\n",
" fig.update_layout(autosize=True, height=200 * n_rows)\n",
" return fig"
]
},
Expand All @@ -453,7 +484,7 @@
"text/markdown": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/utilsforecast/blob/main/utilsforecast/plotting.py#L46){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/utilsforecast/blob/main/utilsforecast/plotting.py#L45){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### plot_series\n",
"\n",
Expand All @@ -467,7 +498,9 @@
"> plot_anomalies:bool=False, engine:str='matplotlib',\n",
"> palette:Optional[str]=None, id_col:str='unique_id',\n",
"> time_col:str='ds', target_col:str='y', seed:int=0,\n",
"> resampler_kwargs:Optional[Dict]=None)\n",
"> resampler_kwargs:Optional[Dict]=None, ax:Union[matplotlib.ax\n",
"> es._axes.Axes,ForwardRef('plotly.graph_objects.Figure'),None\n",
"> Type]=None)\n",
"\n",
"*Plot forecasts and insample values.*\n",
"\n",
Expand All @@ -489,12 +522,13 @@
"| target_col | str | y | Column that contains the target. |\n",
"| seed | int | 0 | Seed used for the random number generator. Only used if plot_random is True. |\n",
"| resampler_kwargs | Optional | None | Keyword arguments to be passed to plotly-resampler constructor.<br>For further custumization (\"show_dash\") call the method,<br>store the plotting object and add the extra arguments to<br>its `show_dash` method. |\n",
"| ax | Union | None | Object where plots will be added. |\n",
"| **Returns** | **matplotlib or plotly figure** | | **Plot's figure** |"
],
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/utilsforecast/blob/main/utilsforecast/plotting.py#L46){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/utilsforecast/blob/main/utilsforecast/plotting.py#L45){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### plot_series\n",
"\n",
Expand All @@ -508,7 +542,9 @@
"> plot_anomalies:bool=False, engine:str='matplotlib',\n",
"> palette:Optional[str]=None, id_col:str='unique_id',\n",
"> time_col:str='ds', target_col:str='y', seed:int=0,\n",
"> resampler_kwargs:Optional[Dict]=None)\n",
"> resampler_kwargs:Optional[Dict]=None, ax:Union[matplotlib.ax\n",
"> es._axes.Axes,ForwardRef('plotly.graph_objects.Figure'),None\n",
"> Type]=None)\n",
"\n",
"*Plot forecasts and insample values.*\n",
"\n",
Expand All @@ -530,6 +566,7 @@
"| target_col | str | y | Column that contains the target. |\n",
"| seed | int | 0 | Seed used for the random number generator. Only used if plot_random is True. |\n",
"| resampler_kwargs | Optional | None | Keyword arguments to be passed to plotly-resampler constructor.<br>For further custumization (\"show_dash\") call the method,<br>store the plotting object and add the extra arguments to<br>its `show_dash` method. |\n",
"| ax | Union | None | Object where plots will be added. |\n",
"| **Returns** | **matplotlib or plotly figure** | | **Plot's figure** |"
]
},
Expand Down
Loading

0 comments on commit e6eeb10

Please sign in to comment.