|
| 1 | +import numpy as np |
| 2 | +import matplotlib.pyplot as plt |
| 3 | +from ..utils import multirank, index_l_of_d |
| 4 | + |
| 5 | +__all__ = ['plot', 'plot_std_of_each_dim', 'plot_norm_of_each_vec'] |
| 6 | + |
| 7 | + |
| 8 | +def plot(X, ax=None, bins=50, **kwargs): |
| 9 | + if ax is None: |
| 10 | + fig, ax = plt.subplots() |
| 11 | + |
| 12 | + if multirank.is_matrix(X): # X [C, N] |
| 13 | + C = len(X) |
| 14 | + for c in range(C): |
| 15 | + _kwargs = index_l_of_d(kwargs, c) |
| 16 | + ax.hist(X[c], bins=bins, **_kwargs) |
| 17 | + |
| 18 | + else: # X [N] |
| 19 | + ax.hist(X, bins=bins, **kwargs) |
| 20 | + return ax |
| 21 | + |
| 22 | + |
| 23 | +def plot_std_of_each_dim(X, ax=None, bins=50, **kwargs): |
| 24 | + if multirank.is_multirank(X): # X [C, N, D] |
| 25 | + C = len(X) |
| 26 | + stds = [np.std(X[c], axis=0) for c in range(C)] |
| 27 | + else: # X [C, N] |
| 28 | + stds = np.std(X, axis=0) |
| 29 | + |
| 30 | + return plot(stds, ax=ax, bins=bins, **kwargs) |
| 31 | + |
| 32 | + |
| 33 | +def plot_norm_of_each_vec(X, ax=None, bins=50, **kwargs): |
| 34 | + if multirank.is_multirank(X): # X [C, N, D] |
| 35 | + C = len(X) |
| 36 | + norms = [np.linalg.norm(X[c], axis=-1) for c in range(C)] |
| 37 | + else: # X [C, N] |
| 38 | + norms = np.linalg.norm(X, axis=-1) |
| 39 | + |
| 40 | + return plot(norms, ax=ax, bins=bins, **kwargs) |
0 commit comments