|
| 1 | +import torch |
| 2 | +from torch import Tensor |
| 3 | + |
| 4 | +import torchhd.functional as functional |
| 5 | + |
| 6 | + |
| 7 | +def plot_pair_similarity(memory: Tensor, ax=None, **kwargs): |
| 8 | + """Plots the pair-wise similarity of a hypervector set. |
| 9 | +
|
| 10 | + Args: |
| 11 | + memory (Tensor): The set of :math:`n` hypervectors whose pair-wise similarity is to be displayed. |
| 12 | + ax (matplotlib.axes, optional): Axes in which to draw the plot. |
| 13 | +
|
| 14 | + Other Parameters: |
| 15 | + **kwargs: `matplotlib.axes.Axes.pcolormesh <https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.pcolormesh.html>`_ arguments. |
| 16 | +
|
| 17 | + Returns: |
| 18 | + matplotlib.collections.QuadMesh: `matplotlib.collections.QuadMesh <https://matplotlib.org/stable/api/collections_api.html#matplotlib.collections.QuadMesh>`_. |
| 19 | +
|
| 20 | + Shapes: |
| 21 | + - Memory: :math:`(n, d)` |
| 22 | +
|
| 23 | + Examples:: |
| 24 | +
|
| 25 | + >>> import matplotlib.pyplot as plt |
| 26 | + >>> hv = torchhd.level_hv(10, 10000) |
| 27 | + >>> utils.plot_pair_similarity(hv) |
| 28 | + >>> plt.show() |
| 29 | + """ |
| 30 | + try: |
| 31 | + import matplotlib.pyplot as plt |
| 32 | + except ImportError: |
| 33 | + raise ImportError( |
| 34 | + "Install matplotlib to use plotting functionality. \ |
| 35 | + See https://matplotlib.org/stable/users/installing/index.html for more information." |
| 36 | + ) |
| 37 | + |
| 38 | + similarity = [] |
| 39 | + for vector in memory: |
| 40 | + similarity.append(functional.cosine_similarity(vector, memory).tolist()) |
| 41 | + |
| 42 | + if ax is None: |
| 43 | + ax = plt.gca() |
| 44 | + |
| 45 | + xy = torch.arange(memory.size(-2)) |
| 46 | + x, y = torch.meshgrid(xy, xy) |
| 47 | + |
| 48 | + ax.set_aspect("equal", adjustable="box") |
| 49 | + return ax.pcolormesh(x, y, similarity, **kwargs) |
| 50 | + |
| 51 | + |
| 52 | +def plot_similarity(input: Tensor, memory: Tensor, ax=None, **kwargs): |
| 53 | + """Plots the similarity of an one hypervector with a set of hypervectors. |
| 54 | +
|
| 55 | + Args: |
| 56 | + input (torch.Tensor): Hypervector to compare against the memory. |
| 57 | + memory (torch.Tensor): Set of :math:`n` hypervectors. |
| 58 | + ax (matplotlib.axes, optional): Axes in which to draw the plot. |
| 59 | +
|
| 60 | + Other Parameters: |
| 61 | + **kwargs: `matplotlib.axes.Axes.stem <https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.stem.html?highlight=stem#matplotlib.axes.Axes.stem>`_ arguments. |
| 62 | +
|
| 63 | + Returns: |
| 64 | + StemContainer: `matplotlib.container.StemContainer <https://matplotlib.org/stable/api/container_api.html#matplotlib.container.StemContainer>`_. |
| 65 | +
|
| 66 | + Shapes: |
| 67 | + - Input: :math:`(d)` |
| 68 | + - Memory: :math:`(n, d)` |
| 69 | +
|
| 70 | + Examples:: |
| 71 | +
|
| 72 | + >>> import matplotlib.pyplot as plt |
| 73 | + >>> hv = torchhd.level_hv(10, 10000) |
| 74 | + >>> utils.plot_similarity(hv[4], hv) |
| 75 | + >>> plt.show() |
| 76 | +
|
| 77 | + """ |
| 78 | + try: |
| 79 | + import matplotlib.pyplot as plt |
| 80 | + except ImportError: |
| 81 | + raise ImportError( |
| 82 | + "Install matplotlib to use plotting functionality. \ |
| 83 | + See https://matplotlib.org/stable/users/installing/index.html for more information." |
| 84 | + ) |
| 85 | + |
| 86 | + similarity = functional.cosine_similarity(input, memory).tolist() |
| 87 | + |
| 88 | + if ax is None: |
| 89 | + ax = plt.gca() |
| 90 | + |
| 91 | + return ax.stem(similarity, **kwargs) |
0 commit comments