From 6408ee0d51f2721ed53a9df937fb724b75d15f6d Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Sun, 11 Aug 2024 01:18:55 -0400 Subject: [PATCH] minor tweak to dim-reduce in utils --- ngclearn/utils/viz/dim_reduce.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/ngclearn/utils/viz/dim_reduce.py b/ngclearn/utils/viz/dim_reduce.py index 98646084..4f9095dd 100755 --- a/ngclearn/utils/viz/dim_reduce.py +++ b/ngclearn/utils/viz/dim_reduce.py @@ -1,6 +1,6 @@ import matplotlib import matplotlib.pyplot as plt -cmap = plt.cm.jet +default_cmap = plt.cm.jet import numpy as np from sklearn.decomposition import IncrementalPCA @@ -66,7 +66,8 @@ def extract_tsne_latents(vectors, perplexity=30, n_pca_comp=32): ## tSNE mapping z_2D = vectors return z_2D -def plot_latents(code_vectors, labels, plot_fname="2Dcode_plot.jpg", alpha=1.): +def plot_latents(code_vectors, labels, plot_fname="2Dcode_plot.jpg", alpha=1., + cmap=None): """ Produces a label-overlaid (label map to distinct colors) scatterplot for visualizing two-dimensional latent codes (produced by either PCA or t-SNE). @@ -80,7 +81,9 @@ def plot_latents(code_vectors, labels, plot_fname="2Dcode_plot.jpg", alpha=1.): plot_fname: /path/to/plot_fname. for saving the plot to disk - alpha: + alpha: alpha intensity level to present colors in scatterplot + + cmap: custom color-map to provide """ curr_backend = plt.rcParams["backend"] matplotlib.use('Agg') ## temporarily go in Agg plt backend for tsne plotting @@ -92,7 +95,11 @@ def plot_latents(code_vectors, labels, plot_fname="2Dcode_plot.jpg", alpha=1.): if lab.shape[1] > 1: ## extract integer class labels from a one-hot matrix lab = np.argmax(lab, 1) plt.figure(figsize=(8, 6)) - plt.scatter(code_vectors[:, 0], code_vectors[:, 1], c=lab, cmap=cmap, alpha=alpha) + _cmap = cmap + if _cmap is None: + _cmap = default_cmap + #print("> USING DEFAULT CMAP!") + plt.scatter(code_vectors[:, 0], code_vectors[:, 1], c=lab, cmap=_cmap, alpha=alpha) colorbar = plt.colorbar() #colorbar.set_alpha(1) #plt.draw_all()