Skip to content

Commit f7eb84a

Browse files
Add Holographic Reduced Representations VSA model (#109)
* Add HRR vsa model * [github-action] formatting fixes Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent bfd95e6 commit f7eb84a

File tree

6 files changed

+262
-17
lines changed

6 files changed

+262
-17
lines changed

docs/torchhd.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ VSA Models
8080
VSA_Model
8181
BSC
8282
MAP
83-
.. HRR
83+
HRR
8484
FHRR
8585

8686

torchhd/functional.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from torchhd.base import VSA_Model
88
from torchhd.bsc import BSC
99
from torchhd.map import MAP
10+
from torchhd.hrr import HRR
1011
from torchhd.fhrr import FHRR
1112

1213

@@ -484,6 +485,11 @@ def circular_hv(
484485
[-0.887-0.460j, -0.906+0.421j, -0.727-0.686j, -0.271+0.962j, -0.705-0.709j, 0.562-0.827j]])
485486
486487
"""
488+
if model == HRR:
489+
raise ValueError(
490+
"The circular hypervectors don't currently work with the HRR model. We are not sure why, if you have any insight that could help please share it at: https://github.com/hyperdimensional-computing/torchhd/issues/108."
491+
)
492+
487493
# convert from normalized "randomness" variable r to
488494
# number of levels between orthogonal pairs or "span"
489495
levels_per_span = ((1 - randomness) * (num_vectors / 2) + randomness * 1) * 2

torchhd/hrr.py

Lines changed: 227 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,11 @@
88

99

1010
class HRR(VSA_Model):
11+
"""Holographic Reduced Representation
12+
13+
Proposed in `Holographic reduced representations <https://ieeexplore.ieee.org/document/377968>`_, this model uses real valued hypervectors.
14+
"""
15+
1116
supported_dtypes: Set[torch.dtype] = {torch.float32, torch.float64}
1217

1318
@classmethod
@@ -20,6 +25,30 @@ def empty_hv(
2025
device=None,
2126
requires_grad=False,
2227
) -> "HRR":
28+
"""Creates a set of hypervectors representing empty sets.
29+
30+
When bundled with a random-hypervector :math:`x`, the result is :math:`x`.
31+
The empty vector of the HRR model is simply a set of 0 values.
32+
33+
Args:
34+
num_vectors (int): the number of hypervectors to generate.
35+
dimensions (int): the dimensionality of the hypervectors.
36+
dtype (``torch.dtype``, optional): the desired data type of returned tensor. Default: if ``None`` is ``torch.get_default_dtype()``.
37+
device (``torch.device``, optional): the desired device of returned tensor. Default: if ``None``, uses the current device for the default tensor type (see torch.set_default_tensor_type()). ``device`` will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types.
38+
requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: ``False``.
39+
40+
Examples::
41+
42+
>>> torchhd.HRR.empty_hv(3, 6)
43+
HRR([[0., 0., 0., 0., 0., 0.],
44+
[0., 0., 0., 0., 0., 0.],
45+
[0., 0., 0., 0., 0., 0.]])
46+
>>> torchhd.HRR.empty_hv(3, 6, dtype=torch.float64)
47+
HRR([[0., 0., 0., 0., 0., 0.],
48+
[0., 0., 0., 0., 0., 0.],
49+
[0., 0., 0., 0., 0., 0.]], dtype=torch.float64)
50+
51+
"""
2352

2453
if dtype is None:
2554
dtype = torch.get_default_dtype()
@@ -48,6 +77,29 @@ def identity_hv(
4877
device=None,
4978
requires_grad=False,
5079
) -> "HRR":
80+
"""Creates a set of identity hypervectors.
81+
82+
When bound with a random-hypervector :math:`x`, the result is :math:`x`.
83+
84+
Args:
85+
num_vectors (int): the number of hypervectors to generate.
86+
dimensions (int): the dimensionality of the hypervectors.
87+
dtype (``torch.dtype``, optional): the desired data type of returned tensor. Default: if ``None`` is ``torch.get_default_dtype()``.
88+
device (``torch.device``, optional): the desired device of returned tensor. Default: if ``None``, uses the current device for the default tensor type (see torch.set_default_tensor_type()). ``device`` will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types.
89+
requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: ``False``.
90+
91+
Examples::
92+
93+
>>> torchhd.HRR.identity_hv(3, 6)
94+
HRR([[1., 0., 0., 0., 0., 0.],
95+
[1., 0., 0., 0., 0., 0.],
96+
[1., 0., 0., 0., 0., 0.]])
97+
>>> torchhd.HRR.identity_hv(3, 6, dtype=torch.float64)
98+
HRR([[1., 0., 0., 0., 0., 0.],
99+
[1., 0., 0., 0., 0., 0.],
100+
[1., 0., 0., 0., 0., 0.]], dtype=torch.float64)
101+
102+
"""
51103

52104
if dtype is None:
53105
dtype = torch.get_default_dtype()
@@ -57,13 +109,13 @@ def identity_hv(
57109
options = ", ".join([str(x) for x in cls.supported_dtypes])
58110
raise ValueError(f"{name} vectors must be one of dtype {options}.")
59111

60-
result = torch.ones(
112+
result = torch.zeros(
61113
num_vectors,
62114
dimensions,
63115
dtype=dtype,
64116
device=device,
65117
)
66-
result = torch.real(ifft(result))
118+
result[:, 0] = 1
67119
result.requires_grad = requires_grad
68120
return result.as_subclass(cls)
69121

@@ -78,6 +130,29 @@ def random_hv(
78130
device=None,
79131
requires_grad=False,
80132
) -> "HRR":
133+
"""Creates a set of random independent hypervectors.
134+
135+
The resulting hypervectors are sampled at random from a normal with mean 0 and standard deviation 1/dimensions.
136+
137+
Args:
138+
num_vectors (int): the number of hypervectors to generate.
139+
dimensions (int): the dimensionality of the hypervectors.
140+
generator (``torch.Generator``, optional): a pseudorandom number generator for sampling.
141+
dtype (``torch.dtype``, optional): the desired data type of returned tensor. Default: if ``None`` is ``torch.get_default_dtype()``.
142+
device (``torch.device``, optional): the desired device of returned tensor. Default: if ``None``, uses the current device for the default tensor type (see torch.set_default_tensor_type()). ``device`` will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types.
143+
requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: ``False``.
144+
145+
Examples::
146+
147+
>>> torchhd.HRR.random_hv(3, 6)
148+
HRR([[ 0.2520, -0.0048, -0.0351, 0.2067, 0.0638, -0.0729],
149+
[-0.2695, 0.0815, 0.0103, 0.2211, -0.1202, 0.2134],
150+
[ 0.0086, -0.1748, -0.1715, 0.3215, -0.1353, 0.0044]])
151+
>>> torchhd.HRR.random_hv(3, 6, dtype=torch.float64)
152+
HRR([[-0.1327, -0.0396, -0.0065, 0.0886, -0.4665, 0.2656],
153+
[-0.2879, -0.1070, -0.0851, -0.4366, -0.1311, 0.3976],
154+
[-0.0472, 0.2987, -0.1567, 0.1496, -0.0098, 0.0344]], dtype=torch.float64)
155+
"""
81156

82157
if dtype is None:
83158
dtype = torch.get_default_dtype()
@@ -91,40 +166,183 @@ def random_hv(
91166
result = torch.empty(size, dtype=dtype, device=device)
92167
result.normal_(0, 1.0 / dimensions, generator=generator)
93168

94-
# projection
95-
# f = torch.abs(fft(result))
96-
# p = ifft(fft(result) / f).real
97-
# result = torch.nan_to_num(p)
98169
result.requires_grad = requires_grad
99170
return result.as_subclass(cls)
100171

101172
def bundle(self, other: "HRR") -> "HRR":
173+
r"""Bundle the hypervector with other using element-wise sum.
174+
175+
This produces a hypervector maximally similar to both.
176+
177+
The bundling operation is used to aggregate information into a single hypervector.
178+
179+
Args:
180+
other (HRR): other input hypervector
181+
182+
Shapes:
183+
- Self: :math:`(*)`
184+
- Other: :math:`(*)`
185+
- Output: :math:`(*)`
186+
187+
Examples::
188+
189+
>>> a, b = torchhd.HRR.random_hv(2, 6)
190+
>>> a
191+
HRR([ 0.1916, -0.1451, -0.0678, 0.0829, 0.3816, -0.0906])
192+
>>> b
193+
HRR([-0.2825, 0.3788, 0.0885, -0.1269, -0.0481, -0.3029])
194+
>>> a.bundle(b)
195+
HRR([-0.0909, 0.2336, 0.0207, -0.0440, 0.3336, -0.3935])
196+
197+
>>> a, b = torchhd.HRR.random_hv(2, 6, dtype=torch.float64)
198+
>>> a
199+
HRR([ 0.3879, -0.0452, -0.0082, -0.2262, -0.2764, 0.0166], dtype=torch.float64)
200+
>>> b
201+
HRR([ 0.0738, -0.0306, 0.4948, 0.1209, 0.1482, 0.1268], dtype=torch.float64)
202+
>>> a.bundle(b)
203+
HRR([ 0.4618, -0.0758, 0.4866, -0.1053, -0.1281, 0.1434], dtype=torch.float64)
204+
205+
"""
102206
return self.add(other)
103207

104208
def multibundle(self) -> "HRR":
209+
"""Bundle multiple hypervectors"""
105210
return self.sum(dim=-2, dtype=self.dtype)
106211

107212
def bind(self, other: "HRR") -> "HRR":
213+
r"""Bind the hypervector with other using circular convolution.
214+
215+
This produces a hypervector dissimilar to both.
216+
217+
Binding is used to associate information, for instance, to assign values to variables.
218+
219+
Args:
220+
other (HRR): other input hypervector
221+
222+
Shapes:
223+
- Self: :math:`(*)`
224+
- Other: :math:`(*)`
225+
- Output: :math:`(*)`
226+
227+
Examples::
228+
229+
>>> a, b = torchhd.HRR.random_hv(2, 6)
230+
>>> a
231+
HRR([ 0.0101, -0.2474, -0.0097, -0.0788, 0.1541, -0.1766])
232+
>>> b
233+
HRR([-0.0338, 0.0340, 0.0289, -0.1498, 0.1178, -0.2822])
234+
>>> a.bind(b)
235+
HRR([ 0.0786, -0.0260, 0.0591, -0.0706, 0.0799, -0.0216])
236+
237+
>>> a, b = torchhd.HRR.random_hv(2, 6, dtype=torch.float64)
238+
>>> a
239+
HRR([ 0.0354, -0.0818, 0.0216, 0.0384, 0.2961, 0.1976], dtype=torch.float64)
240+
>>> b
241+
HRR([ 0.3640, -0.0640, -0.1033, -0.1454, 0.0999, 0.0299], dtype=torch.float64)
242+
>>> a.bind(b)
243+
HRR([-0.0362, -0.0910, 0.0114, 0.0445, 0.1244, 0.0388], dtype=torch.float64)
244+
245+
"""
108246
result = ifft(torch.mul(fft(self), fft(other)))
109-
return result.real
247+
return torch.real(result)
110248

111249
def multibind(self) -> "HRR":
250+
"""Bind multiple hypervectors"""
112251
result = ifft(torch.prod(fft(self), dim=-2, dtype=self.dtype))
113-
return result.real
252+
return torch.real(result)
253+
254+
def exact_inverse(self) -> "HRR":
255+
"""Unstable, but exact, inverse"""
256+
result = ifft(1.0 / fft(self).conj())
257+
result = torch.real(result)
258+
return torch.nan_to_num(result)
114259

115260
def inverse(self) -> "HRR":
116-
return self.flip(dims=(-1,)).roll(1, dims=-1)
261+
r"""Stable inversion of the hypervector for binding.
262+
263+
For HRR the stable inverse of hypervector is its conjugate in the frequency domain, this returns the conjugate of the hypervector.
264+
265+
Shapes:
266+
- Self: :math:`(*)`
267+
- Output: :math:`(*)`
268+
269+
Examples::
270+
271+
>>> a = torchhd.HRR.random_hv(1, 6)
272+
>>> a
273+
HRR([[ 0.1406, 0.0014, -0.0502, 0.2888, 0.2969, -0.2637]])
274+
>>> a.inverse()
275+
HRR([[ 0.1406, -0.2637, 0.2969, 0.2888, -0.0502, 0.0014]])
276+
277+
>>> a = torchhd.HRR.random_hv(1, 6, dtype=torch.float64)
278+
>>> a
279+
HRR([[ 0.0090, 0.2620, 0.0836, 0.0441, -0.2351, -0.1744]], dtype=torch.float64)
280+
>>> a.inverse()
281+
HRR([[ 0.0090, -0.1744, -0.2351, 0.0441, 0.0836, 0.2620]], dtype=torch.float64)
282+
283+
"""
284+
result = ifft(fft(self).conj())
285+
return torch.real(result)
117286

118287
def negative(self) -> "HRR":
288+
r"""Negate the hypervector for the bundling inverse.
289+
290+
Shapes:
291+
- Self: :math:`(*)`
292+
- Output: :math:`(*)`
293+
294+
Examples::
295+
296+
>>> a = torchhd.HRR.random_hv(1, 6)
297+
>>> a
298+
HRR([[ 0.2658, -0.2808, 0.1436, 0.1131, 0.1567, -0.1426]])
299+
>>> a.negative()
300+
HRR([[-0.2658, 0.2808, -0.1436, -0.1131, -0.1567, 0.1426]])
301+
302+
>>> a = torchhd.HRR.random_hv(1, 6, dtype=torch.float64)
303+
>>> a
304+
HRR([[ 0.0318, 0.1944, 0.1229, 0.0193, 0.0135, -0.2521]], dtype=torch.float64)
305+
>>> a.negative()
306+
HRR([[-0.0318, -0.1944, -0.1229, -0.0193, -0.0135, 0.2521]], dtype=torch.float64)
307+
308+
"""
119309
return torch.negative(self)
120310

121311
def permute(self, shifts: int = 1) -> "HRR":
312+
r"""Permute the hypervector.
313+
314+
The permutation operator is used to assign an order to hypervectors.
315+
316+
Args:
317+
shifts (int, optional): The number of places by which the elements of the tensor are shifted.
318+
319+
Shapes:
320+
- Self: :math:`(*)`
321+
- Output: :math:`(*)`
322+
323+
Examples::
324+
325+
>>> a = torchhd.HRR.random_hv(1, 6)
326+
>>> a
327+
HRR([[-0.2521, 0.1140, -0.1647, -0.1490, -0.2091, -0.0618]])
328+
>>> a.permute()
329+
HRR([[-0.0618, -0.2521, 0.1140, -0.1647, -0.1490, -0.2091]])
330+
331+
>>> a = torchhd.HRR.random_hv(1, 6, dtype=torch.float64)
332+
>>> a
333+
HRR([[-0.0495, -0.0318, 0.3923, -0.3205, 0.1587, 0.1926]], dtype=torch.float64)
334+
>>> a.permute()
335+
HRR([[ 0.1926, -0.0495, -0.0318, 0.3923, -0.3205, 0.1587]], dtype=torch.float64)
336+
337+
"""
122338
return self.roll(shifts=shifts, dims=-1)
123339

124340
def dot_similarity(self, others: "HRR") -> Tensor:
341+
"""Inner product with other hypervectors"""
125342
return F.linear(self, others)
126343

127344
def cos_similarity(self, others: "HRR", *, eps=1e-08) -> Tensor:
345+
"""Cosine similarity with other hypervectors"""
128346
self_dot = torch.sum(self * self, dim=-1)
129347
self_mag = self_dot.sqrt()
130348

0 commit comments

Comments
 (0)