Skip to content

Commit b979599

Browse files
authored
Add randsel bundling of hypervectors (#86)
* Add randsel bundling implementation * Add randsel to documentation * Add randsel testing
1 parent 81160a5 commit b979599

File tree

4 files changed

+181
-18
lines changed

4 files changed

+181
-18
lines changed

docs/functional.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ Operations
3232
bundle
3333
permute
3434
cleanup
35+
randsel
36+
multirandsel
3537
soft_quantize
3638
hard_quantize
3739

torchhd/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
unbind,
1414
bundle,
1515
permute,
16+
randsel,
1617
cosine_similarity,
1718
dot_similarity,
1819
)
@@ -33,6 +34,7 @@
3334
"unbind",
3435
"bundle",
3536
"permute",
37+
"randsel",
3638
"cosine_similarity",
3739
"dot_similarity",
3840
]

torchhd/functional.py

Lines changed: 107 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import math
22
import torch
3-
from torch import BoolTensor, LongTensor, Tensor
3+
from torch import BoolTensor, LongTensor, FloatTensor, Tensor
44
import torch.nn.functional as F
55
from collections import deque
66

@@ -482,12 +482,13 @@ def bind(input: Tensor, other: Tensor) -> Tensor:
482482
483483
Examples::
484484
485-
>>> x = functional.random_hv(2, 3)
486-
>>> x
487-
tensor([[ 1., -1., -1.],
488-
[ 1., 1., 1.]])
489-
>>> functional.bind(x[0], x[1])
490-
tensor([ 1., -1., -1.])
485+
>>> a, b = functional.random_hv(2, 10)
486+
>>> a
487+
tensor([-1., 1., -1., -1., 1., 1., -1., 1., 1., -1.])
488+
>>> b
489+
tensor([-1., -1., 1., 1., -1., -1., 1., 1., 1., 1.])
490+
>>> functional.bind(a, b)
491+
tensor([ 1., -1., -1., -1., -1., -1., -1., 1., 1., -1.])
491492
492493
"""
493494
dtype = input.dtype
@@ -571,12 +572,13 @@ def bundle(input: Tensor, other: Tensor, *, tie: BoolTensor = None) -> Tensor:
571572
572573
Examples::
573574
574-
>>> x = functional.random_hv(2, 3)
575-
>>> x
576-
tensor([[ 1., 1., 1.],
577-
[-1., 1., -1.]])
578-
>>> functional.bundle(x[0], x[1])
579-
tensor([0., 2., 0.])
575+
>>> a, b = functional.random_hv(2, 10)
576+
>>> a
577+
tensor([ 1., -1., 1., -1., -1., 1., -1., -1., 1., -1.])
578+
>>> b
579+
tensor([ 1., -1., -1., 1., 1., 1., -1., 1., -1., 1.])
580+
>>> functional.bundle(a, b)
581+
tensor([ 2., -2., 0., 0., 0., 2., -2., 0., 0., 0.])
580582
581583
"""
582584
dtype = input.dtype
@@ -891,6 +893,91 @@ def multiset(input: Tensor) -> Tensor:
891893

892894
return torch.sum(input, dim=dim, dtype=dtype)
893895

896+
def randsel(
897+
input: Tensor, other: Tensor, *, p: float = 0.5, generator: torch.Generator = None
898+
) -> Tensor:
899+
r"""Bundles two hypervectors by selecting random elements.
900+
901+
A bundling operation is used to aggregate information into a single hypervector.
902+
The resulting hypervector has elements selected at random from input or other.
903+
904+
.. math::
905+
906+
\oplus: \mathcal{H} \times \mathcal{H} \to \mathcal{H}
907+
908+
Aliased as ``torchhd.randsel``.
909+
910+
Args:
911+
input (Tensor): input hypervector
912+
other (Tensor): other input hypervector
913+
p (float, optional): probability of selecting elements from the input hypervector. Default: 0.5.
914+
generator (``torch.Generator``, optional): a pseudorandom number generator for sampling.
915+
916+
Shapes:
917+
- Input: :math:`(*)`
918+
- Other: :math:`(*)`
919+
- Output: :math:`(*)`
920+
921+
Examples::
922+
923+
>>> a, b = functional.random_hv(2, 10)
924+
>>> a
925+
tensor([ 1., -1., 1., -1., 1., -1., -1., -1., -1., 1.])
926+
>>> b
927+
tensor([ 1., -1., 1., -1., 1., 1., -1., 1., -1., 1.])
928+
>>> functional.randsel(a, b)
929+
tensor([ 1., -1., 1., -1., 1., 1., -1., 1., -1., 1.])
930+
931+
"""
932+
select = torch.empty_like(input, dtype=torch.bool)
933+
select.bernoulli_(1 - p, generator=generator)
934+
return input.where(select, other)
935+
936+
937+
def multirandsel(
938+
input: Tensor, *, p: FloatTensor = None, generator: torch.Generator = None
939+
) -> Tensor:
940+
r"""Bundling multiple hypervectors by sampling random elements.
941+
942+
Bundles all the input hypervectors together.
943+
The resulting hypervector has elements selected at random from the input tensor of hypervectors.
944+
945+
.. math::
946+
947+
\bigoplus_{i=0}^{n-1} V_i
948+
949+
Args:
950+
input (Tensor): input hypervector tensor
951+
p (FloatTensor, optional): probability of selecting elements from the input hypervector. Default: uniform.
952+
generator (``torch.Generator``, optional): a pseudorandom number generator for sampling.
953+
954+
Shapes:
955+
- Input: :math:`(*, n, d)`
956+
- Probability (p): :math:`(*, n)`
957+
- Output: :math:`(*, d)`
958+
959+
Examples::
960+
961+
>>> x = functional.random_hv(5, 10)
962+
>>> x
963+
tensor([[-1., 1., 1., 1., -1., 1., 1., -1., 1., -1.],
964+
[-1., 1., 1., 1., -1., 1., -1., -1., -1., -1.],
965+
[-1., -1., 1., -1., -1., 1., -1., 1., 1., -1.],
966+
[ 1., 1., -1., 1., -1., -1., 1., -1., 1., 1.],
967+
[ 1., -1., -1., 1., 1., 1., 1., -1., -1., -1.]])
968+
>>> functional.multirandsel(x)
969+
tensor([ 1., -1., -1., 1., -1., 1., 1., 1., 1., -1.])
970+
971+
"""
972+
d = input.size(-1)
973+
device = input.device
974+
975+
if p is None:
976+
p = torch.ones(input.shape[:-1], dtype=torch.float, device=device)
977+
978+
select = torch.multinomial(p, d, replacement=True, generator=generator)
979+
select.unsqueeze_(-2)
980+
return input.gather(-2, select).squeeze(-2)
894981

895982
multibundle = multiset
896983

@@ -917,13 +1004,15 @@ def multibind(input: Tensor) -> Tensor:
9171004
9181005
Examples::
9191006
920-
>>> x = functional.random_hv(3, 3)
1007+
>>> x = functional.random_hv(5, 10)
9211008
>>> x
922-
tensor([[ 1., 1., 1.],
923-
[-1., 1., 1.],
924-
[-1., 1., -1.]])
1009+
tensor([[ 1., -1., -1., -1., -1., 1., 1., 1., -1., 1.],
1010+
[ 1., 1., -1., -1., 1., 1., -1., -1., 1., -1.],
1011+
[-1., -1., 1., -1., -1., -1., -1., 1., -1., -1.],
1012+
[-1., 1., -1., 1., 1., -1., -1., -1., 1., 1.],
1013+
[-1., -1., 1., -1., 1., -1., 1., 1., -1., 1.]])
9251014
>>> functional.multibind(x)
926-
tensor([ 1., 1., -1.])
1015+
tensor([-1., -1., -1., 1., 1., -1., -1., 1., -1., 1.])
9271016
9281017
"""
9291018
dtype = input.dtype

torchhd/tests/test_operations.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,3 +256,73 @@ def test_device(self):
256256
hv = functional.random_hv(5, 100, device=device)
257257
res = functional.cleanup(hv[0], hv)
258258
assert res.device == device
259+
260+
261+
class TestRandsel:
262+
def test_value(self):
263+
generator = torch.Generator()
264+
generator.manual_seed(2147483644)
265+
266+
a, b = functional.random_hv(2, 1000, generator=generator)
267+
res = functional.randsel(a, b, p=0, generator=generator)
268+
assert torch.all(a == res)
269+
270+
a, b = functional.random_hv(2, 1000, generator=generator)
271+
res = functional.randsel(a, b, p=1, generator=generator)
272+
assert torch.all(b == res)
273+
274+
a, b = functional.random_hv(2, 1000, generator=generator)
275+
res = functional.randsel(a, b, generator=generator)
276+
assert torch.all((b == res) | (a == res))
277+
278+
@pytest.mark.parametrize("dtype", torch_dtypes)
279+
def test_dtype(self, dtype):
280+
a, b = torch.zeros(2, 1000, dtype=dtype)
281+
res = functional.randsel(a, b)
282+
assert res.dtype == dtype
283+
284+
def test_device(self):
285+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
286+
287+
a, b = functional.random_hv(2, 100, device=device)
288+
res = functional.randsel(a, b)
289+
290+
assert res.dtype == a.dtype
291+
assert res.dim() == 1
292+
assert res.size(0) == 100
293+
assert res.device == device
294+
295+
296+
class TestMultiRandsel:
297+
def test_value(self):
298+
generator = torch.Generator()
299+
generator.manual_seed(2147483644)
300+
301+
x = functional.random_hv(4, 1000, generator=generator)
302+
res = functional.multirandsel(x, p=torch.tensor([0.0,0.0,1.0,0.0]), generator=generator)
303+
assert torch.all(x[2] == res)
304+
305+
x = functional.random_hv(4, 1000, generator=generator)
306+
res = functional.multirandsel(x, p=torch.tensor([0.5,0.0,0.5,0.0]), generator=generator)
307+
assert torch.all((x[0] == res) | (x[2] == res))
308+
309+
x = functional.random_hv(4, 1000, generator=generator)
310+
res = functional.multirandsel(x, generator=generator)
311+
assert torch.all((x[0] == res) | (x[1] == res) | (x[2] == res) | (x[3] == res))
312+
313+
@pytest.mark.parametrize("dtype", torch_dtypes)
314+
def test_dtype(self, dtype):
315+
x = torch.zeros(4, 1000, dtype=dtype)
316+
res = functional.multirandsel(x)
317+
assert res.dtype == dtype
318+
319+
def test_device(self):
320+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
321+
322+
x = functional.random_hv(4, 100, device=device)
323+
res = functional.multirandsel(x)
324+
325+
assert res.dtype == x.dtype
326+
assert res.dim() == 1
327+
assert res.size(0) == 100
328+
assert res.device == device

0 commit comments

Comments
 (0)