Skip to content
147 changes: 90 additions & 57 deletions torchhd/functional.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import math
import torch
from torch import LongTensor, Tensor
from torch import BoolTensor, LongTensor, Tensor
import torch.nn.functional as F

from collections import deque
Expand Down Expand Up @@ -64,14 +64,21 @@ def identity_hv(
if dtype is None:
dtype = torch.get_default_dtype()

if dtype in {torch.bool, torch.complex64, torch.complex128}:
raise NotImplementedError(
"Boolean, and Complex hypervectors are not supported yet."
)
if dtype in {torch.complex64, torch.complex128}:
raise NotImplementedError("Complex hypervectors are not supported yet.")

if dtype == torch.uint8:
raise ValueError("Unsigned integer hypervectors are not supported.")

if dtype == torch.bool:
return torch.zeros(
num_embeddings,
embedding_dim,
dtype=dtype,
device=device,
requires_grad=requires_grad,
)

return torch.ones(
num_embeddings,
embedding_dim,
Expand Down Expand Up @@ -122,10 +129,8 @@ def random_hv(
if dtype is None:
dtype = torch.get_default_dtype()

if dtype in {torch.bool, torch.complex64, torch.complex128}:
raise NotImplementedError(
"Boolean, and Complex hypervectors are not supported yet."
)
if dtype in {torch.complex64, torch.complex128}:
raise NotImplementedError("Complex hypervectors are not supported yet.")

if dtype == torch.uint8:
raise ValueError("Unsigned integer hypervectors are not supported.")
Expand All @@ -137,6 +142,11 @@ def random_hv(
),
dtype=torch.bool,
).bernoulli_(1.0 - sparsity, generator=generator)

if dtype == torch.bool:
select.requires_grad = requires_grad
return select

result = torch.where(select, -1, +1).to(dtype=dtype, device=device)
result.requires_grad = requires_grad
return result
Expand Down Expand Up @@ -183,10 +193,8 @@ def level_hv(
if dtype is None:
dtype = torch.get_default_dtype()

if dtype in {torch.bool, torch.complex64, torch.complex128}:
raise NotImplementedError(
"Boolean, and Complex hypervectors are not supported yet."
)
if dtype in {torch.complex64, torch.complex128}:
raise NotImplementedError("Complex hypervectors are not supported yet.")

if dtype == torch.uint8:
raise ValueError("Unsigned integer hypervectors are not supported.")
Expand All @@ -200,6 +208,8 @@ def level_hv(

# convert from normalized "randomness" variable r to number of orthogonal vectors sets "span"
levels_per_span = (1 - randomness) * (num_embeddings - 1) + randomness * 1
# must be at least one to deal with the case that num_embeddings is less than 2
levels_per_span = max(levels_per_span, 1)
span = (num_embeddings - 1) / levels_per_span
# generate the set of orthogonal vectors within the level vector set
span_hv = random_hv(
Expand Down Expand Up @@ -287,10 +297,8 @@ def circular_hv(
if dtype is None:
dtype = torch.get_default_dtype()

if dtype in {torch.bool, torch.complex64, torch.complex128}:
raise NotImplementedError(
"Boolean, and Complex hypervectors are not supported yet."
)
if dtype in {torch.complex64, torch.complex128}:
raise NotImplementedError("Complex hypervectors are not supported yet.")

if dtype == torch.uint8:
raise ValueError("Unsigned integer hypervectors are not supported.")
Expand Down Expand Up @@ -354,15 +362,15 @@ def circular_hv(

temp_hv = torch.where(threshold_v[span_idx] < t, span_start_hv, span_end_hv)

mutation_history.append(temp_hv * mutation_hv)
mutation_history.append(bind(temp_hv, mutation_hv))
mutation_hv = temp_hv

if i % 2 == 0:
hv[i // 2] = mutation_hv

for i in range(num_embeddings + 1, num_embeddings * 2 - 1):
mut = mutation_history.popleft()
mutation_hv *= mut
mutation_hv = bind(mutation_hv, mut)

if i % 2 == 0:
hv[i // 2] = mutation_hv
Expand All @@ -371,7 +379,7 @@ def circular_hv(
return hv


def bind(input: Tensor, other: Tensor, *, out=None) -> Tensor:
def bind(input: Tensor, other: Tensor) -> Tensor:
r"""Binds two hypervectors which produces a hypervector dissimilar to both.

Binding is used to associate information, for instance, to assign values to variables.
Expand All @@ -385,7 +393,6 @@ def bind(input: Tensor, other: Tensor, *, out=None) -> Tensor:
Args:
input (Tensor): input hypervector
other (Tensor): other input hypervector
out (Tensor, optional): the output tensor.

Shapes:
- Input: :math:`(*)`
Expand All @@ -402,18 +409,21 @@ def bind(input: Tensor, other: Tensor, *, out=None) -> Tensor:
tensor([ 1., -1., -1.])

"""
if input.dtype in {torch.bool, torch.complex64, torch.complex128}:
raise NotImplementedError(
"Boolean, and Complex hypervectors are not supported yet."
)
dtype = input.dtype

if input.dtype == torch.uint8:
if torch.is_complex(input):
raise NotImplementedError("Complex hypervectors are not supported yet.")

if dtype == torch.uint8:
raise ValueError("Unsigned integer hypervectors are not supported.")

return torch.mul(input, other, out=out)
if dtype == torch.bool:
return torch.logical_xor(input, other)

return torch.mul(input, other)

def bundle(input: Tensor, other: Tensor, *, out=None) -> Tensor:

def bundle(input: Tensor, other: Tensor, *, tie: BoolTensor = None) -> Tensor:
r"""Bundles two hypervectors which produces a hypervector maximally similar to both.

The bundling operation is used to aggregate information into a single hypervector.
Expand All @@ -427,7 +437,7 @@ def bundle(input: Tensor, other: Tensor, *, out=None) -> Tensor:
Args:
input (Tensor): input hypervector
other (Tensor): other input hypervector
out (Tensor, optional): the output tensor.
tie (BoolTensor, optional): specifies how to break a tie while bundling boolean hypervectors. Default: only set bit if both ``input`` and ``other`` are ``True``.

Shapes:
- Input: :math:`(*)`
Expand All @@ -444,15 +454,21 @@ def bundle(input: Tensor, other: Tensor, *, out=None) -> Tensor:
tensor([0., 2., 0.])

"""
if input.dtype in {torch.bool, torch.complex64, torch.complex128}:
raise NotImplementedError(
"Boolean, and Complex hypervectors are not supported yet."
)
dtype = input.dtype

if input.dtype == torch.uint8:
if torch.is_complex(input):
raise NotImplementedError("Complex hypervectors are not supported yet.")

if dtype == torch.uint8:
raise ValueError("Unsigned integer hypervectors are not supported.")

return torch.add(input, other, out=out)
if dtype == torch.bool:
if tie is not None:
return torch.where(input == other, input, tie)
else:
return torch.logical_and(input, other)

return torch.add(input, other)


def permute(input: Tensor, *, shifts=1, dims=-1) -> Tensor:
Expand Down Expand Up @@ -484,15 +500,19 @@ def permute(input: Tensor, *, shifts=1, dims=-1) -> Tensor:
tensor([ -1., 1., -1.])

"""
dtype = input.dtype

if dtype == torch.uint8:
raise ValueError("Unsigned integer hypervectors are not supported.")

return torch.roll(input, shifts=shifts, dims=dims)


def soft_quantize(input: Tensor, *, out=None):
def soft_quantize(input: Tensor):
"""Applies the hyperbolic tanh function to all elements of the input tensor.

Args:
input (Tensor): input tensor.
out (Tensor, optional): output tensor. Defaults to None.

Shapes:
- Input: :math:`(*)`
Expand All @@ -508,15 +528,14 @@ def soft_quantize(input: Tensor, *, out=None):
tensor([0.0000, 0.9640, 0.0000])

"""
return torch.tanh(input, out=out)
return torch.tanh(input)


def hard_quantize(input: Tensor, *, out=None):
def hard_quantize(input: Tensor):
"""Applies binary quantization to all elements of the input tensor.

Args:
input (Tensor): input tensor
out (Tensor, optional): output tensor. Defaults to None.

Shapes:
- Input: :math:`(*)`
Expand All @@ -537,13 +556,7 @@ def hard_quantize(input: Tensor, *, out=None):
positive = torch.tensor(1.0, dtype=input.dtype, device=input.device)
negative = torch.tensor(-1.0, dtype=input.dtype, device=input.device)

if out != None:
out[:] = torch.where(input > 0, positive, negative)
result = out
else:
result = torch.where(input > 0, positive, negative)

return result
return torch.where(input > 0, positive, negative)


def cosine_similarity(input: Tensor, others: Tensor) -> Tensor:
Expand Down Expand Up @@ -650,16 +663,21 @@ def multiset(input: Tensor) -> Tensor:
tensor([-1., 3., 1.])

"""
dim = -2
dtype = input.dtype

if input.dtype in {torch.bool, torch.complex64, torch.complex128}:
raise NotImplementedError(
"Boolean, and Complex hypervectors are not supported yet."
)
if dtype in {torch.complex64, torch.complex128}:
raise NotImplementedError("Complex hypervectors are not supported yet.")

if input.dtype == torch.uint8:
if dtype == torch.uint8:
raise ValueError("Unsigned integer hypervectors are not supported.")

return torch.sum(input, dim=-2, dtype=input.dtype)
if dtype == torch.bool:
count = torch.sum(input, dim=dim, dtype=torch.long)
threshold = input.size(dim) // 2
return torch.greater(count, threshold)

return torch.sum(input, dim=dim, dtype=dtype)


multibundle = multiset
Expand All @@ -681,6 +699,10 @@ def multibind(input: Tensor) -> Tensor:
- Input: :math:`(*, n, d)`
- Output: :math:`(*, d)`

.. note::

This method is not supported for ``torch.float16`` and ``torch.bfloat16`` input data types on a CPU device.

Examples::

>>> x = functional.random_hv(3, 3)
Expand All @@ -692,14 +714,21 @@ def multibind(input: Tensor) -> Tensor:
tensor([ 1., 1., -1.])

"""
if input.dtype in {torch.bool, torch.complex64, torch.complex128}:
raise NotImplementedError(
"Boolean, and Complex hypervectors are not supported yet."
)
if input.dtype in {torch.complex64, torch.complex128}:
raise NotImplementedError("Complex hypervectors are not supported yet.")

if input.dtype == torch.uint8:
raise ValueError("Unsigned integer hypervectors are not supported.")

if input.dtype == torch.bool:
hvs = torch.unbind(input, -2)
result = hvs[0]

for i in range(1, len(hvs)):
result = torch.logical_xor(result, hvs[i])

return result

return torch.prod(input, dim=-2, dtype=input.dtype)


Expand Down Expand Up @@ -870,6 +899,10 @@ def bind_sequence(input: Tensor) -> Tensor:
- Input: :math:`(*, n, d)`
- Output: :math:`(*, d)`

.. note::

This method is not supported for ``torch.float16`` and ``torch.bfloat16`` input data types on a CPU device.

Examples::

>>> x = functional.random_hv(5, 3)
Expand Down
Empty file.
Loading