Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ Operations
bundle
permute
cleanup
randsel
multirandsel
soft_quantize
hard_quantize

Expand Down
2 changes: 2 additions & 0 deletions torchhd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
unbind,
bundle,
permute,
randsel,
cosine_similarity,
dot_similarity,
)
Expand All @@ -33,6 +34,7 @@
"unbind",
"bundle",
"permute",
"randsel",
"cosine_similarity",
"dot_similarity",
]
125 changes: 107 additions & 18 deletions torchhd/functional.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import math
import torch
from torch import BoolTensor, LongTensor, Tensor
from torch import BoolTensor, LongTensor, FloatTensor, Tensor
import torch.nn.functional as F
from collections import deque

Expand Down Expand Up @@ -482,12 +482,13 @@ def bind(input: Tensor, other: Tensor) -> Tensor:

Examples::

>>> x = functional.random_hv(2, 3)
>>> x
tensor([[ 1., -1., -1.],
[ 1., 1., 1.]])
>>> functional.bind(x[0], x[1])
tensor([ 1., -1., -1.])
>>> a, b = functional.random_hv(2, 10)
>>> a
tensor([-1., 1., -1., -1., 1., 1., -1., 1., 1., -1.])
>>> b
tensor([-1., -1., 1., 1., -1., -1., 1., 1., 1., 1.])
>>> functional.bind(a, b)
tensor([ 1., -1., -1., -1., -1., -1., -1., 1., 1., -1.])

"""
dtype = input.dtype
Expand Down Expand Up @@ -571,12 +572,13 @@ def bundle(input: Tensor, other: Tensor, *, tie: BoolTensor = None) -> Tensor:

Examples::

>>> x = functional.random_hv(2, 3)
>>> x
tensor([[ 1., 1., 1.],
[-1., 1., -1.]])
>>> functional.bundle(x[0], x[1])
tensor([0., 2., 0.])
>>> a, b = functional.random_hv(2, 10)
>>> a
tensor([ 1., -1., 1., -1., -1., 1., -1., -1., 1., -1.])
>>> b
tensor([ 1., -1., -1., 1., 1., 1., -1., 1., -1., 1.])
>>> functional.bundle(a, b)
tensor([ 2., -2., 0., 0., 0., 2., -2., 0., 0., 0.])

"""
dtype = input.dtype
Expand Down Expand Up @@ -891,6 +893,91 @@ def multiset(input: Tensor) -> Tensor:

return torch.sum(input, dim=dim, dtype=dtype)

def randsel(
input: Tensor, other: Tensor, *, p: float = 0.5, generator: torch.Generator = None
) -> Tensor:
r"""Bundles two hypervectors by selecting random elements.

A bundling operation is used to aggregate information into a single hypervector.
The resulting hypervector has elements selected at random from input or other.

.. math::

\oplus: \mathcal{H} \times \mathcal{H} \to \mathcal{H}

Aliased as ``torchhd.randsel``.

Args:
input (Tensor): input hypervector
other (Tensor): other input hypervector
p (float, optional): probability of selecting elements from the input hypervector. Default: 0.5.
generator (``torch.Generator``, optional): a pseudorandom number generator for sampling.

Shapes:
- Input: :math:`(*)`
- Other: :math:`(*)`
- Output: :math:`(*)`

Examples::

>>> a, b = functional.random_hv(2, 10)
>>> a
tensor([ 1., -1., 1., -1., 1., -1., -1., -1., -1., 1.])
>>> b
tensor([ 1., -1., 1., -1., 1., 1., -1., 1., -1., 1.])
>>> functional.randsel(a, b)
tensor([ 1., -1., 1., -1., 1., 1., -1., 1., -1., 1.])

"""
select = torch.empty_like(input, dtype=torch.bool)
select.bernoulli_(1 - p, generator=generator)
return input.where(select, other)


def multirandsel(
input: Tensor, *, p: FloatTensor = None, generator: torch.Generator = None
) -> Tensor:
r"""Bundling multiple hypervectors by sampling random elements.

Bundles all the input hypervectors together.
The resulting hypervector has elements selected at random from the input tensor of hypervectors.

.. math::

\bigoplus_{i=0}^{n-1} V_i

Args:
input (Tensor): input hypervector tensor
p (FloatTensor, optional): probability of selecting elements from the input hypervector. Default: uniform.
generator (``torch.Generator``, optional): a pseudorandom number generator for sampling.

Shapes:
- Input: :math:`(*, n, d)`
- Probability (p): :math:`(*, n)`
- Output: :math:`(*, d)`

Examples::

>>> x = functional.random_hv(5, 10)
>>> x
tensor([[-1., 1., 1., 1., -1., 1., 1., -1., 1., -1.],
[-1., 1., 1., 1., -1., 1., -1., -1., -1., -1.],
[-1., -1., 1., -1., -1., 1., -1., 1., 1., -1.],
[ 1., 1., -1., 1., -1., -1., 1., -1., 1., 1.],
[ 1., -1., -1., 1., 1., 1., 1., -1., -1., -1.]])
>>> functional.multirandsel(x)
tensor([ 1., -1., -1., 1., -1., 1., 1., 1., 1., -1.])

"""
d = input.size(-1)
device = input.device

if p is None:
p = torch.ones(input.shape[:-1], dtype=torch.float, device=device)

select = torch.multinomial(p, d, replacement=True, generator=generator)
select.unsqueeze_(-2)
return input.gather(-2, select).squeeze(-2)

multibundle = multiset

Expand All @@ -917,13 +1004,15 @@ def multibind(input: Tensor) -> Tensor:

Examples::

>>> x = functional.random_hv(3, 3)
>>> x = functional.random_hv(5, 10)
>>> x
tensor([[ 1., 1., 1.],
[-1., 1., 1.],
[-1., 1., -1.]])
tensor([[ 1., -1., -1., -1., -1., 1., 1., 1., -1., 1.],
[ 1., 1., -1., -1., 1., 1., -1., -1., 1., -1.],
[-1., -1., 1., -1., -1., -1., -1., 1., -1., -1.],
[-1., 1., -1., 1., 1., -1., -1., -1., 1., 1.],
[-1., -1., 1., -1., 1., -1., 1., 1., -1., 1.]])
>>> functional.multibind(x)
tensor([ 1., 1., -1.])
tensor([-1., -1., -1., 1., 1., -1., -1., 1., -1., 1.])

"""
dtype = input.dtype
Expand Down
70 changes: 70 additions & 0 deletions torchhd/tests/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,3 +256,73 @@ def test_device(self):
hv = functional.random_hv(5, 100, device=device)
res = functional.cleanup(hv[0], hv)
assert res.device == device


class TestRandsel:
def test_value(self):
generator = torch.Generator()
generator.manual_seed(2147483644)

a, b = functional.random_hv(2, 1000, generator=generator)
res = functional.randsel(a, b, p=0, generator=generator)
assert torch.all(a == res)

a, b = functional.random_hv(2, 1000, generator=generator)
res = functional.randsel(a, b, p=1, generator=generator)
assert torch.all(b == res)

a, b = functional.random_hv(2, 1000, generator=generator)
res = functional.randsel(a, b, generator=generator)
assert torch.all((b == res) | (a == res))

@pytest.mark.parametrize("dtype", torch_dtypes)
def test_dtype(self, dtype):
a, b = torch.zeros(2, 1000, dtype=dtype)
res = functional.randsel(a, b)
assert res.dtype == dtype

def test_device(self):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

a, b = functional.random_hv(2, 100, device=device)
res = functional.randsel(a, b)

assert res.dtype == a.dtype
assert res.dim() == 1
assert res.size(0) == 100
assert res.device == device


class TestMultiRandsel:
def test_value(self):
generator = torch.Generator()
generator.manual_seed(2147483644)

x = functional.random_hv(4, 1000, generator=generator)
res = functional.multirandsel(x, p=torch.tensor([0.0,0.0,1.0,0.0]), generator=generator)
assert torch.all(x[2] == res)

x = functional.random_hv(4, 1000, generator=generator)
res = functional.multirandsel(x, p=torch.tensor([0.5,0.0,0.5,0.0]), generator=generator)
assert torch.all((x[0] == res) | (x[2] == res))

x = functional.random_hv(4, 1000, generator=generator)
res = functional.multirandsel(x, generator=generator)
assert torch.all((x[0] == res) | (x[1] == res) | (x[2] == res) | (x[3] == res))

@pytest.mark.parametrize("dtype", torch_dtypes)
def test_dtype(self, dtype):
x = torch.zeros(4, 1000, dtype=dtype)
res = functional.multirandsel(x)
assert res.dtype == dtype

def test_device(self):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

x = functional.random_hv(4, 100, device=device)
res = functional.multirandsel(x)

assert res.dtype == x.dtype
assert res.dim() == 1
assert res.size(0) == 100
assert res.device == device