From 00a45f52a1407594cdb2c9ea574f09dfc1be006a Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Mon, 22 Jul 2024 15:57:10 +0800 Subject: [PATCH 1/4] refactor: code refactoring; --- pypots/data/saving/pickle.py | 10 ++++++---- pypots/nn/modules/etsformer/layers.py | 5 +++-- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/pypots/data/saving/pickle.py b/pypots/data/saving/pickle.py index f9049b1b..4d4e9c93 100644 --- a/pypots/data/saving/pickle.py +++ b/pypots/data/saving/pickle.py @@ -32,12 +32,13 @@ def pickle_dump(data: object, path: str) -> None: create_dir_if_not_exist(extract_parent_dir(path)) with open(path, "wb") as f: pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL) + logger.info(f"Successfully saved to {path}") except Exception as e: logger.error( - f"❌ Pickling failed. No cache data saved. Please investigate the error below.\n{e}" + f"❌ Pickling failed. No cache data saved. Investigate the error below:\n{e}" ) - return None - logger.info(f"Successfully saved to {path}") + + return None def pickle_load(path: str) -> object: @@ -58,6 +59,7 @@ def pickle_load(path: str) -> object: with open(path, "rb") as f: data = pickle.load(f) except Exception as e: - logger.error(f"❌ Loading data failed. Operation aborted. See info below:\n{e}") + logger.error(f"❌ Loading data failed. Operation aborted. Investigate the error below:\n{e}" return None + return data diff --git a/pypots/nn/modules/etsformer/layers.py b/pypots/nn/modules/etsformer/layers.py index 7fe446cf..1a36ed51 100644 --- a/pypots/nn/modules/etsformer/layers.py +++ b/pypots/nn/modules/etsformer/layers.py @@ -160,8 +160,9 @@ def forward(self, x): f = fft.rfftfreq(t)[self.low_freq :] x_freq, index_tuple = self.topk_freq(x_freq) - f = repeat(f, "f -> b f d", b=x_freq.size(0), d=x_freq.size(2)).to(x_freq.device) - f = rearrange(f[index_tuple], "b f d -> b f () d").to(x_freq.device) + device = x_freq.device + f = repeat(f, "f -> b f d", b=x_freq.size(0), d=x_freq.size(2)).to(device) + f = rearrange(f[index_tuple], "b f d -> b f () d").to(device) return self.extrapolate(x_freq, f, t) From fba4a5e78a01ffa626d57a552945941fb1d4926d Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Mon, 22 Jul 2024 16:05:29 +0800 Subject: [PATCH 2/4] fix: unclosed '('; --- pypots/data/saving/pickle.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pypots/data/saving/pickle.py b/pypots/data/saving/pickle.py index 4d4e9c93..c8ef9129 100644 --- a/pypots/data/saving/pickle.py +++ b/pypots/data/saving/pickle.py @@ -59,7 +59,9 @@ def pickle_load(path: str) -> object: with open(path, "rb") as f: data = pickle.load(f) except Exception as e: - logger.error(f"❌ Loading data failed. Operation aborted. Investigate the error below:\n{e}" + logger.error( + f"❌ Loading data failed. Operation aborted. Investigate the error below:\n{e}" + ) return None return data From b6556376023e57a3bd76dbf109af21b30bf3ab1e Mon Sep 17 00:00:00 2001 From: gugababa <93213311+gugababa@users.noreply.github.com> Date: Thu, 25 Jul 2024 00:33:36 -0400 Subject: [PATCH 3/4] Visualize attention matrix in SAITS (#302) * Update visualizeAttention.py Changed typos leading to errors in the function. Returns the figure object instead of axis object. Added a parameter for setting the font scale in the heatmap * Update visualizeAttention.py --------- Co-authored-by: Wenjie Du --- pypots/utils/visual/visualizeAttention.py | 52 +++++++++++++++++++++++ 1 file changed, 52 insertions(+) create mode 100644 pypots/utils/visual/visualizeAttention.py diff --git a/pypots/utils/visual/visualizeAttention.py b/pypots/utils/visual/visualizeAttention.py new file mode 100644 index 00000000..3c6b6899 --- /dev/null +++ b/pypots/utils/visual/visualizeAttention.py @@ -0,0 +1,52 @@ +import matplotlib.pyplot as plt +import numpy as np +import seaborn as sns +from numpy.typing import ArrayLike + + +def visualize_attention(timeSteps: ArrayLike, attention: np.ndarray, fontscale = None): + """Visualize the map of attention weights from Transformer-based models + + Parameters + --------------- + timeSteps: 1D array-like object, preferable list of strings + A vector containing the time steps of the input. + The time steps will be converted to a list of strings if they are not already. + + attention: 2D array-like object + A 2D matrix representing the attention weights + + fontscale: float/int + Sets the scale for fonts in the Seaborn heatmap (applied to sns.set_theme(font_scale = _) + + + Return + --------------- + ax: Matplotlib axes object + + """ + + if not all(isinstance(ele, str) for ele in timeSteps): + timeSteps = [str(step) for step in timeSteps] + + if fontscale is not None: + sns.set_theme(font_scale = fontscale) + + fig, ax = plt.subplots() + ax.tick_params(left=True, bottom=True, labelsize=10) + ax.set_xticks(ax.get_xticks()[::2]) + ax.set_yticks(ax.get_yticks()[::2]) + + assert attention.ndim == 2, "The attention matrix is not two-dimensional" + sns.heatmap( + attention, + ax=ax, + xticklabels=timeSteps, + yticklabels=timeSteps, + linewidths=0, + cbar=True, + ) + cb = ax.collections[0].colorbar + cb.ax.tick_params(labelsize=10) + + return fig From 2ddd2887abe3d378430acfa2b54a5740987f2069 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Thu, 25 Jul 2024 23:45:18 +0800 Subject: [PATCH 4/4] refactor: rename into attention_map; --- ...visualizeAttention.py => attention_map.py} | 21 ++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) rename pypots/utils/visual/{visualizeAttention.py => attention_map.py} (75%) diff --git a/pypots/utils/visual/visualizeAttention.py b/pypots/utils/visual/attention_map.py similarity index 75% rename from pypots/utils/visual/visualizeAttention.py rename to pypots/utils/visual/attention_map.py index 3c6b6899..7eaf6dd4 100644 --- a/pypots/utils/visual/visualizeAttention.py +++ b/pypots/utils/visual/attention_map.py @@ -1,16 +1,27 @@ +""" +Utilities for attention map visualization. +""" + +# Created by Anshuman Swain and Wenjie Du +# License: BSD-3-Clause + import matplotlib.pyplot as plt import numpy as np -import seaborn as sns from numpy.typing import ArrayLike +try: + import seaborn as sns +except Exception: + pass + -def visualize_attention(timeSteps: ArrayLike, attention: np.ndarray, fontscale = None): - """Visualize the map of attention weights from Transformer-based models +def plot_attention(timeSteps: ArrayLike, attention: np.ndarray, fontscale=None): + """Visualize the map of attention weights from Transformer-based models. Parameters --------------- timeSteps: 1D array-like object, preferable list of strings - A vector containing the time steps of the input. + A vector containing the time steps of the input. The time steps will be converted to a list of strings if they are not already. attention: 2D array-like object @@ -30,7 +41,7 @@ def visualize_attention(timeSteps: ArrayLike, attention: np.ndarray, fontscale = timeSteps = [str(step) for step in timeSteps] if fontscale is not None: - sns.set_theme(font_scale = fontscale) + sns.set_theme(font_scale=fontscale) fig, ax = plt.subplots() ax.tick_params(left=True, bottom=True, labelsize=10)