Skip to content

Add universal cosine and dot similarity support #82

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions torchhd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
unbind,
bundle,
permute,
cosine_similarity,
dot_similarity,
)

from torchhd.version import __version__
Expand All @@ -31,4 +33,6 @@
"unbind",
"bundle",
"permute",
"cosine_similarity",
"dot_similarity",
]
87 changes: 69 additions & 18 deletions torchhd/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import torch
from torch import BoolTensor, LongTensor, Tensor
import torch.nn.functional as F

from collections import deque


Expand Down Expand Up @@ -688,6 +687,8 @@ def hard_quantize(input: Tensor):
def dot_similarity(input: Tensor, others: Tensor) -> Tensor:
"""Dot product between the input vector and each vector in others.

Aliased as ``torchhd.dot_similarity``.

Args:
input (Tensor): hypervectors to compare against others
others (Tensor): hypervectors to compare with
Expand All @@ -697,6 +698,12 @@ def dot_similarity(input: Tensor, others: Tensor) -> Tensor:
- Others: :math:`(n, d)` or :math:`(d)`
- Output: :math:`(*, n)` or :math:`(*)`, depends on shape of others

.. note::

Output ``dtype`` for ``torch.bool`` is ``torch.long``,
for ``torch.complex64`` is ``torch.float``,
for ``torch.complex128`` is ``torch.double``, otherwise same as input ``dtype``.

Examples::

>>> x = functional.random_hv(3, 6)
Expand All @@ -720,6 +727,12 @@ def dot_similarity(input: Tensor, others: Tensor) -> Tensor:
[ 0.6771, -4.2506, 6.0000]])

"""
if input.dtype == torch.bool:
input_as_bipolar = torch.where(input, -1, 1)
others_as_bipolar = torch.where(others, -1, 1)

return F.linear(input_as_bipolar, others_as_bipolar)

if torch.is_complex(input):
return F.linear(input, others.conj()).real

Expand All @@ -729,6 +742,8 @@ def dot_similarity(input: Tensor, others: Tensor) -> Tensor:
def cosine_similarity(input: Tensor, others: Tensor, *, eps=1e-08) -> Tensor:
"""Cosine similarity between the input vector and each vector in others.

Aliased as ``torchhd.cosine_similarity``.

Args:
input (Tensor): hypervectors to compare against others
others (Tensor): hypervectors to compare with
Expand All @@ -738,6 +753,10 @@ def cosine_similarity(input: Tensor, others: Tensor, *, eps=1e-08) -> Tensor:
- Others: :math:`(n, d)` or :math:`(d)`
- Output: :math:`(*, n)` or :math:`(*)`, depends on shape of others

.. note::

Output ``dtype`` is ``torch.get_default_dtype()``.

Examples::

>>> x = functional.random_hv(3, 6)
Expand All @@ -761,43 +780,75 @@ def cosine_similarity(input: Tensor, others: Tensor, *, eps=1e-08) -> Tensor:
[0.1806, 0.2607, 1.0000]])

"""
if torch.is_complex(input):
input_mag = torch.real(input * input.conj()).sum(dim=-1).sqrt()
others_mag = torch.real(others * others.conj()).sum(dim=-1).sqrt()
out_dtype = torch.get_default_dtype()

# calculate vector magnitude
if input.dtype == torch.bool:
input_mag = torch.full(
input.shape[:-1],
math.sqrt(input.size(-1)),
dtype=out_dtype,
device=input.device,
)
others_mag = torch.full(
others.shape[:-1],
math.sqrt(others.size(-1)),
dtype=out_dtype,
device=others.device,
)

elif torch.is_complex(input):
input_dot = torch.real(input * input.conj()).sum(dim=-1, dtype=out_dtype)
input_mag = input_dot.sqrt()

others_dot = torch.real(others * others.conj()).sum(dim=-1, dtype=out_dtype)
others_mag = others_dot.sqrt()

else:
input_mag = torch.sum(input * input, dim=-1).sqrt()
others_mag = torch.sum(others * others, dim=-1).sqrt()
input_dot = torch.sum(input * input, dim=-1, dtype=out_dtype)
input_mag = input_dot.sqrt()

others_dot = torch.sum(others * others, dim=-1, dtype=out_dtype)
others_mag = others_dot.sqrt()

if input.dim() > 1:
magnitude = input_mag.unsqueeze(-1) * others_mag.unsqueeze(0)
else:
magnitude = input_mag * others_mag

return dot_similarity(input, others) / (magnitude + eps)
return dot_similarity(input, others).to(out_dtype) / (magnitude + eps)


def hamming_similarity(input: Tensor, others: Tensor) -> LongTensor:
"""Number of equal elements between the input vector and each vector in others.
"""Number of equal elements between the input vectors and each vector in others.

Args:
input (Tensor): one-dimensional tensor
others (Tensor): two-dimensional tensor
input (Tensor): hypervectors to compare against others
others (Tensor): hypervectors to compare with

Shapes:
- Input: :math:`(d)`
- Others: :math:`(n, d)`
- Output: :math:`(n)`
- Input: :math:`(*, d)`
- Others: :math:`(n, d)` or :math:`(d)`
- Output: :math:`(*, n)` or :math:`(*)`, depends on shape of others

Examples::

>>> x = functional.random_hv(2, 3)
>>> x = functional.random_hv(3, 6)
>>> x
tensor([[ 1., 1., -1.],
[-1., -1., -1.]])
>>> functional.hamming_similarity(x[0], x)
tensor([3., 1.])
tensor([[ 1., 1., -1., -1., 1., 1.],
[ 1., 1., 1., 1., -1., -1.],
[ 1., 1., -1., -1., -1., 1.]])
>>> functional.hamming_similarity(x, x)
tensor([[6, 2, 5],
[2, 6, 3],
[5, 3, 6]])

"""
if input.dim() > 1 and others.dim() > 1:
return torch.sum(
input.unsqueeze(-2) == others.unsqueeze(-3), dim=-1, dtype=torch.long
)

return torch.sum(input == others, dim=-1, dtype=torch.long)


Expand Down
2 changes: 1 addition & 1 deletion torchhd/tests/basis_hv/test_circular_hv.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def test_value(self, dtype):

abs_sims_diff = sims_diff.abs()
assert torch.all(
(0.248 < abs_sims_diff) & (abs_sims_diff < 0.252)
(0.247 < abs_sims_diff) & (abs_sims_diff < 0.253)
).item(), "similarity changes linearly"
else:
sims = functional.hamming_similarity(hv[0], hv).float() / 1000000
Expand Down
2 changes: 1 addition & 1 deletion torchhd/tests/basis_hv/test_level_hv.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def test_value(self, dtype):
sims = functional.cosine_similarity(hv[0], hv)
sims_diff = sims[:-1] - sims[1:]
assert torch.all(
(0.248 < sims_diff) & (sims_diff < 0.252)
(0.247 < sims_diff) & (sims_diff < 0.253)
).item(), "similarity decreases linearly"
else:
sims = functional.hamming_similarity(hv[0], hv).float() / 10000
Expand Down
Loading