
Commit 3d2fadb

Add universal cosine and dot similarity support (#82)
* Implement universal dot and cos similarities
* Add similarity testing
* Loosen test bounds
* Add device tests for similarities
* Alias common similarities, generalize hamming similarity
* Add tests for hamming distance
1 parent fd57dac commit 3d2fadb
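
As a quick, hedged illustration of what "universal" means here (this snippet is mine, not part of the commit, and assumes `functional.random_hv` accepts a `dtype` argument as it does elsewhere in the library): the same similarity calls now also cover non-real hypervectors, for example complex ones.

    import torch
    from torchhd import functional

    # Complex hypervectors; the same calls also accept bipolar and boolean ones.
    x = functional.random_hv(4, 10000, dtype=torch.complex64)

    functional.dot_similarity(x, x)     # real part of the Hermitian product, dtype torch.float
    functional.cosine_similarity(x, x)  # always returned in torch.get_default_dtype()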

File tree

5 files changed: 484 additions, 20 deletions

torchhd/__init__.py

Lines changed: 4 additions & 0 deletions

@@ -13,6 +13,8 @@
     unbind,
     bundle,
     permute,
+    cosine_similarity,
+    dot_similarity,
 )
 
 from torchhd.version import __version__
@@ -31,4 +33,6 @@
     "unbind",
     "bundle",
     "permute",
+    "cosine_similarity",
+    "dot_similarity",
 ]
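
A small sketch of what this re-export gives you (mine, not part of the commit): the top-level names and the ones in `torchhd.functional` resolve to the same functions.

    import torch
    import torchhd
    from torchhd import functional

    x = functional.random_hv(2, 1000)

    # The top-level alias and the functional module expose the same callable.
    assert torchhd.cosine_similarity is functional.cosine_similarity
    assert torch.equal(torchhd.dot_similarity(x, x), functional.dot_similarity(x, x))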

torchhd/functional.py

Lines changed: 69 additions & 18 deletions

@@ -2,7 +2,6 @@
 import torch
 from torch import BoolTensor, LongTensor, Tensor
 import torch.nn.functional as F
-
 from collections import deque
 
 
@@ -688,6 +687,8 @@ def hard_quantize(input: Tensor):
 def dot_similarity(input: Tensor, others: Tensor) -> Tensor:
     """Dot product between the input vector and each vector in others.
 
+    Aliased as ``torchhd.dot_similarity``.
+
     Args:
         input (Tensor): hypervectors to compare against others
         others (Tensor): hypervectors to compare with
@@ -697,6 +698,12 @@ def dot_similarity(input: Tensor, others: Tensor) -> Tensor:
         - Others: :math:`(n, d)` or :math:`(d)`
         - Output: :math:`(*, n)` or :math:`(*)`, depends on shape of others
 
+    .. note::
+
+        Output ``dtype`` for ``torch.bool`` is ``torch.long``,
+        for ``torch.complex64`` is ``torch.float``,
+        for ``torch.complex128`` is ``torch.double``, otherwise same as input ``dtype``.
+
     Examples::
 
         >>> x = functional.random_hv(3, 6)
@@ -720,6 +727,12 @@ def dot_similarity(input: Tensor, others: Tensor) -> Tensor:
                 [ 0.6771, -4.2506,  6.0000]])
 
     """
+    if input.dtype == torch.bool:
+        input_as_bipolar = torch.where(input, -1, 1)
+        others_as_bipolar = torch.where(others, -1, 1)
+
+        return F.linear(input_as_bipolar, others_as_bipolar)
+
     if torch.is_complex(input):
         return F.linear(input, others.conj()).real
 
@@ -729,6 +742,8 @@ def dot_similarity(input: Tensor, others: Tensor) -> Tensor:
 def cosine_similarity(input: Tensor, others: Tensor, *, eps=1e-08) -> Tensor:
     """Cosine similarity between the input vector and each vector in others.
 
+    Aliased as ``torchhd.cosine_similarity``.
+
     Args:
         input (Tensor): hypervectors to compare against others
         others (Tensor): hypervectors to compare with
@@ -738,6 +753,10 @@ def cosine_similarity(input: Tensor, others: Tensor, *, eps=1e-08) -> Tensor:
         - Others: :math:`(n, d)` or :math:`(d)`
         - Output: :math:`(*, n)` or :math:`(*)`, depends on shape of others
 
+    .. note::
+
+        Output ``dtype`` is ``torch.get_default_dtype()``.
+
     Examples::
 
         >>> x = functional.random_hv(3, 6)
@@ -761,43 +780,75 @@ def cosine_similarity(input: Tensor, others: Tensor, *, eps=1e-08) -> Tensor:
                 [0.1806, 0.2607, 1.0000]])
 
     """
-    if torch.is_complex(input):
-        input_mag = torch.real(input * input.conj()).sum(dim=-1).sqrt()
-        others_mag = torch.real(others * others.conj()).sum(dim=-1).sqrt()
+    out_dtype = torch.get_default_dtype()
+
+    # calculate vector magnitude
+    if input.dtype == torch.bool:
+        input_mag = torch.full(
+            input.shape[:-1],
+            math.sqrt(input.size(-1)),
+            dtype=out_dtype,
+            device=input.device,
+        )
+        others_mag = torch.full(
+            others.shape[:-1],
+            math.sqrt(others.size(-1)),
+            dtype=out_dtype,
+            device=others.device,
+        )
+
+    elif torch.is_complex(input):
+        input_dot = torch.real(input * input.conj()).sum(dim=-1, dtype=out_dtype)
+        input_mag = input_dot.sqrt()
+
+        others_dot = torch.real(others * others.conj()).sum(dim=-1, dtype=out_dtype)
+        others_mag = others_dot.sqrt()
+
     else:
-        input_mag = torch.sum(input * input, dim=-1).sqrt()
-        others_mag = torch.sum(others * others, dim=-1).sqrt()
+        input_dot = torch.sum(input * input, dim=-1, dtype=out_dtype)
+        input_mag = input_dot.sqrt()
+
+        others_dot = torch.sum(others * others, dim=-1, dtype=out_dtype)
+        others_mag = others_dot.sqrt()
 
     if input.dim() > 1:
         magnitude = input_mag.unsqueeze(-1) * others_mag.unsqueeze(0)
     else:
         magnitude = input_mag * others_mag
 
-    return dot_similarity(input, others) / (magnitude + eps)
+    return dot_similarity(input, others).to(out_dtype) / (magnitude + eps)
 
 
 def hamming_similarity(input: Tensor, others: Tensor) -> LongTensor:
-    """Number of equal elements between the input vector and each vector in others.
+    """Number of equal elements between the input vectors and each vector in others.
 
     Args:
-        input (Tensor): one-dimensional tensor
-        others (Tensor): two-dimensional tensor
+        input (Tensor): hypervectors to compare against others
+        others (Tensor): hypervectors to compare with
 
     Shapes:
-        - Input: :math:`(d)`
-        - Others: :math:`(n, d)`
-        - Output: :math:`(n)`
+        - Input: :math:`(*, d)`
+        - Others: :math:`(n, d)` or :math:`(d)`
+        - Output: :math:`(*, n)` or :math:`(*)`, depends on shape of others
 
     Examples::
 
-        >>> x = functional.random_hv(2, 3)
+        >>> x = functional.random_hv(3, 6)
         >>> x
-        tensor([[ 1.,  1., -1.],
-                [-1., -1., -1.]])
-        >>> functional.hamming_similarity(x[0], x)
-        tensor([3., 1.])
+        tensor([[ 1.,  1., -1., -1.,  1.,  1.],
+                [ 1.,  1.,  1.,  1., -1., -1.],
+                [ 1.,  1., -1., -1., -1.,  1.]])
+        >>> functional.hamming_similarity(x, x)
+        tensor([[6, 2, 5],
+                [2, 6, 3],
+                [5, 3, 6]])
 
     """
+    if input.dim() > 1 and others.dim() > 1:
+        return torch.sum(
+            input.unsqueeze(-2) == others.unsqueeze(-3), dim=-1, dtype=torch.long
+        )
+
     return torch.sum(input == others, dim=-1, dtype=torch.long)
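
To make the new behaviour concrete, here is a small, hedged sketch (it assumes `functional.random_hv` accepts `dtype=torch.bool`, as the library's binary hypervectors do; the variable names are mine, not the commit's): boolean inputs are treated as bipolar vectors with `True` mapped to -1 and `False` to +1, cosine similarity always comes back in the default floating-point dtype, and Hamming similarity now broadcasts over batched inputs.

    import torch
    from torchhd import functional

    # Random binary hypervectors, shape (3, 1000).
    x = functional.random_hv(3, 1000, dtype=torch.bool)

    # dot_similarity on bool inputs equals the bipolar dot product (True -> -1, False -> +1)
    # and is returned as torch.long, matching the new docstring note.
    bipolar = torch.where(x, -1, 1)
    assert torch.equal(functional.dot_similarity(x, x), bipolar @ bipolar.T)

    # cosine_similarity always reports in torch.get_default_dtype().
    assert functional.cosine_similarity(x, x).dtype == torch.get_default_dtype()

    # hamming_similarity now accepts batched inputs on both sides:
    # (m, d) against (n, d) yields an (m, n) matrix of matching-element counts.
    counts = functional.hamming_similarity(x, x)
    assert counts.shape == (3, 3) and bool(torch.all(counts.diagonal() == 1000))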

torchhd/tests/basis_hv/test_circular_hv.py

Lines changed: 1 addition & 1 deletion

@@ -68,7 +68,7 @@ def test_value(self, dtype):
 
             abs_sims_diff = sims_diff.abs()
             assert torch.all(
-                (0.248 < abs_sims_diff) & (abs_sims_diff < 0.252)
+                (0.247 < abs_sims_diff) & (abs_sims_diff < 0.253)
             ).item(), "similarity changes linearly"
         else:
             sims = functional.hamming_similarity(hv[0], hv).float() / 1000000

torchhd/tests/basis_hv/test_level_hv.py

Lines changed: 1 addition & 1 deletion

@@ -67,7 +67,7 @@ def test_value(self, dtype):
             sims = functional.cosine_similarity(hv[0], hv)
             sims_diff = sims[:-1] - sims[1:]
             assert torch.all(
-                (0.248 < sims_diff) & (sims_diff < 0.252)
+                (0.247 < sims_diff) & (sims_diff < 0.253)
             ).item(), "similarity decreases linearly"
         else:
             sims = functional.hamming_similarity(hv[0], hv).float() / 10000
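
Both test changes widen the same kind of check; as a hedged sketch of the pattern (the level count of 5 and the dimensionality of 10,000 are my assumptions, not taken from the test files), the assertion says the cosine similarity to the first level hypervector should drop by roughly 0.25 per level, within a small tolerance band.

    import torch
    from torchhd import functional

    hv = functional.level_hv(5, 10000)               # 5 level hypervectors (assumed sizes)
    sims = functional.cosine_similarity(hv[0], hv)   # roughly [1.0, 0.75, 0.5, 0.25, 0.0]
    sims_diff = sims[:-1] - sims[1:]                 # consecutive drops, each near 0.25

    # The widened tolerance band from this commit.
    assert torch.all((0.247 < sims_diff) & (sims_diff < 0.253)).item()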
