@@ -2,7 +2,6 @@
 import torch
 from torch import BoolTensor, LongTensor, Tensor
 import torch.nn.functional as F
-
 from collections import deque


@@ -688,6 +687,8 @@ def hard_quantize(input: Tensor):
 def dot_similarity(input: Tensor, others: Tensor) -> Tensor:
     """Dot product between the input vector and each vector in others.

+    Aliased as ``torchhd.dot_similarity``.
+
     Args:
         input (Tensor): hypervectors to compare against others
         others (Tensor): hypervectors to compare with
@@ -697,6 +698,12 @@ def dot_similarity(input: Tensor, others: Tensor) -> Tensor:
         - Others: :math:`(n, d)` or :math:`(d)`
         - Output: :math:`(*, n)` or :math:`(*)`, depends on shape of others

+    .. note::
+
+        Output ``dtype`` for ``torch.bool`` is ``torch.long``,
+        for ``torch.complex64`` is ``torch.float``,
+        for ``torch.complex128`` is ``torch.double``, otherwise same as input ``dtype``.
+
     Examples::

         >>> x = functional.random_hv(3, 6)
@@ -720,6 +727,12 @@ def dot_similarity(input: Tensor, others: Tensor) -> Tensor:
                 [ 0.6771, -4.2506,  6.0000]])

     """
+    if input.dtype == torch.bool:
+        input_as_bipolar = torch.where(input, -1, 1)
+        others_as_bipolar = torch.where(others, -1, 1)
+
+        return F.linear(input_as_bipolar, others_as_bipolar)
+
     if torch.is_complex(input):
         return F.linear(input, others.conj()).real

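For quick reference, a minimal sketch of what the new boolean branch in ``dot_similarity`` computes; the tensors and values below are illustrative only, not taken from the project's tests::

    import torch
    import torch.nn.functional as F

    # Boolean hypervectors are mapped to bipolar form first: True -> -1, False -> +1.
    x = torch.tensor([[True, False, True], [True, True, False]])
    x_bipolar = torch.where(x, -1, 1)      # tensor([[-1,  1, -1], [-1, -1,  1]])

    # F.linear(a, b) computes a @ b.T, i.e. all pairwise dot products.
    sims = F.linear(x_bipolar, x_bipolar)
    # tensor([[ 3, -1],
    #         [-1,  3]]) -- the diagonal equals d (every position matches itself),
    # off-diagonal entries count matches minus mismatches, and the result dtype is
    # torch.long, consistent with the note added to the docstring above.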
@@ -729,6 +742,8 @@ def dot_similarity(input: Tensor, others: Tensor) -> Tensor:
 def cosine_similarity(input: Tensor, others: Tensor, *, eps=1e-08) -> Tensor:
     """Cosine similarity between the input vector and each vector in others.

+    Aliased as ``torchhd.cosine_similarity``.
+
     Args:
         input (Tensor): hypervectors to compare against others
         others (Tensor): hypervectors to compare with
@@ -738,6 +753,10 @@ def cosine_similarity(input: Tensor, others: Tensor, *, eps=1e-08) -> Tensor:
         - Others: :math:`(n, d)` or :math:`(d)`
         - Output: :math:`(*, n)` or :math:`(*)`, depends on shape of others

+    .. note::
+
+        Output ``dtype`` is ``torch.get_default_dtype()``.
+
     Examples::

         >>> x = functional.random_hv(3, 6)
@@ -761,43 +780,75 @@ def cosine_similarity(input: Tensor, others: Tensor, *, eps=1e-08) -> Tensor:
                 [0.1806, 0.2607, 1.0000]])

     """
-    if torch.is_complex(input):
-        input_mag = torch.real(input * input.conj()).sum(dim=-1).sqrt()
-        others_mag = torch.real(others * others.conj()).sum(dim=-1).sqrt()
+    out_dtype = torch.get_default_dtype()
+
+    # calculate vector magnitude
+    if input.dtype == torch.bool:
+        input_mag = torch.full(
+            input.shape[:-1],
+            math.sqrt(input.size(-1)),
+            dtype=out_dtype,
+            device=input.device,
+        )
+        others_mag = torch.full(
+            others.shape[:-1],
+            math.sqrt(others.size(-1)),
+            dtype=out_dtype,
+            device=others.device,
+        )
+
+    elif torch.is_complex(input):
+        input_dot = torch.real(input * input.conj()).sum(dim=-1, dtype=out_dtype)
+        input_mag = input_dot.sqrt()
+
+        others_dot = torch.real(others * others.conj()).sum(dim=-1, dtype=out_dtype)
+        others_mag = others_dot.sqrt()
+
     else:
-        input_mag = torch.sum(input * input, dim=-1).sqrt()
-        others_mag = torch.sum(others * others, dim=-1).sqrt()
+        input_dot = torch.sum(input * input, dim=-1, dtype=out_dtype)
+        input_mag = input_dot.sqrt()
+
+        others_dot = torch.sum(others * others, dim=-1, dtype=out_dtype)
+        others_mag = others_dot.sqrt()

     if input.dim() > 1:
         magnitude = input_mag.unsqueeze(-1) * others_mag.unsqueeze(0)
     else:
         magnitude = input_mag * others_mag

-    return dot_similarity(input, others) / (magnitude + eps)
+    return dot_similarity(input, others).to(out_dtype) / (magnitude + eps)


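The boolean branch above never builds the bipolar tensors: every bipolar element squares to 1, so the magnitude of a ``d``-dimensional boolean hypervector is exactly ``sqrt(d)``. A small check of that equivalence, with made-up values and assuming the default ``float32`` dtype::

    import math
    import torch

    v = torch.tensor([True, False, True, True, False, False])  # d = 6

    # Analytic magnitude used by the new boolean branch:
    analytic = math.sqrt(v.size(-1))                  # sqrt(6) ~= 2.4495

    # Same value computed explicitly from the bipolar representation:
    bipolar = torch.where(v, -1, 1).to(torch.get_default_dtype())
    explicit = bipolar.pow(2).sum().sqrt()            # tensor(2.4495)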
 def hamming_similarity(input: Tensor, others: Tensor) -> LongTensor:
-    """Number of equal elements between the input vector and each vector in others.
+    """Number of equal elements between the input vectors and each vector in others.

     Args:
-        input (Tensor): one-dimensional tensor
-        others (Tensor): two-dimensional tensor
+        input (Tensor): hypervectors to compare against others
+        others (Tensor): hypervectors to compare with

     Shapes:
-        - Input: :math:`(d)`
-        - Others: :math:`(n, d)`
-        - Output: :math:`(n)`
+        - Input: :math:`(*, d)`
+        - Others: :math:`(n, d)` or :math:`(d)`
+        - Output: :math:`(*, n)` or :math:`(*)`, depends on shape of others

     Examples::

-        >>> x = functional.random_hv(2, 3)
+        >>> x = functional.random_hv(3, 6)
         >>> x
-        tensor([[ 1.,  1., -1.],
-                [-1., -1., -1.]])
-        >>> functional.hamming_similarity(x[0], x)
-        tensor([3., 1.])
+        tensor([[ 1.,  1., -1., -1.,  1.,  1.],
+                [ 1.,  1.,  1.,  1., -1., -1.],
+                [ 1.,  1., -1., -1., -1.,  1.]])
+        >>> functional.hamming_similarity(x, x)
+        tensor([[6, 2, 5],
+                [2, 6, 3],
+                [5, 3, 6]])

     """
+    if input.dim() > 1 and others.dim() > 1:
+        return torch.sum(
+            input.unsqueeze(-2) == others.unsqueeze(-3), dim=-1, dtype=torch.long
+        )
+
     return torch.sum(input == others, dim=-1, dtype=torch.long)


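The batched branch added to ``hamming_similarity`` relies on broadcasting: unsqueezing the operands to ``(m, 1, d)`` and ``(1, n, d)`` compares every input hypervector with every vector in ``others`` in one pass. A minimal sketch with illustrative values::

    import torch

    a = torch.tensor([[ 1.,  1., -1.],
                      [ 1., -1., -1.]])               # shape (2, 3)
    b = torch.tensor([[ 1.,  1.,  1.],
                      [ 1., -1., -1.],
                      [-1., -1., -1.]])               # shape (3, 3)

    # (2, 1, 3) == (1, 3, 3) broadcasts to a (2, 3, 3) boolean match table.
    matches = a.unsqueeze(-2) == b.unsqueeze(-3)
    sims = torch.sum(matches, dim=-1, dtype=torch.long)
    # tensor([[2, 2, 1],
    #         [1, 3, 2]]) -- entry (i, j) counts positions where a[i] == b[j]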