Skip to content

Add binary hypervector support #71

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jun 1, 2022
147 changes: 90 additions & 57 deletions torchhd/functional.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import math
import torch
from torch import LongTensor, Tensor
from torch import BoolTensor, LongTensor, Tensor
import torch.nn.functional as F

from collections import deque
Expand Down Expand Up @@ -64,14 +64,21 @@ def identity_hv(
if dtype is None:
dtype = torch.get_default_dtype()

if dtype in {torch.bool, torch.complex64, torch.complex128}:
raise NotImplementedError(
"Boolean, and Complex hypervectors are not supported yet."
)
if dtype in {torch.complex64, torch.complex128}:
raise NotImplementedError("Complex hypervectors are not supported yet.")

if dtype == torch.uint8:
raise ValueError("Unsigned integer hypervectors are not supported.")

if dtype == torch.bool:
return torch.zeros(
num_embeddings,
embedding_dim,
dtype=dtype,
device=device,
requires_grad=requires_grad,
)

return torch.ones(
num_embeddings,
embedding_dim,
Expand Down Expand Up @@ -122,10 +129,8 @@ def random_hv(
if dtype is None:
dtype = torch.get_default_dtype()

if dtype in {torch.bool, torch.complex64, torch.complex128}:
raise NotImplementedError(
"Boolean, and Complex hypervectors are not supported yet."
)
if dtype in {torch.complex64, torch.complex128}:
raise NotImplementedError("Complex hypervectors are not supported yet.")

if dtype == torch.uint8:
raise ValueError("Unsigned integer hypervectors are not supported.")
Expand All @@ -137,6 +142,11 @@ def random_hv(
),
dtype=torch.bool,
).bernoulli_(1.0 - sparsity, generator=generator)

if dtype == torch.bool:
select.requires_grad = requires_grad
return select

result = torch.where(select, -1, +1).to(dtype=dtype, device=device)
result.requires_grad = requires_grad
return result
Expand Down Expand Up @@ -183,10 +193,8 @@ def level_hv(
if dtype is None:
dtype = torch.get_default_dtype()

if dtype in {torch.bool, torch.complex64, torch.complex128}:
raise NotImplementedError(
"Boolean, and Complex hypervectors are not supported yet."
)
if dtype in {torch.complex64, torch.complex128}:
raise NotImplementedError("Complex hypervectors are not supported yet.")

if dtype == torch.uint8:
raise ValueError("Unsigned integer hypervectors are not supported.")
Expand All @@ -200,6 +208,8 @@ def level_hv(

# convert the normalized "randomness" variable r into the number of orthogonal vector sets, "span"
levels_per_span = (1 - randomness) * (num_embeddings - 1) + randomness * 1
# must be at least one to deal with the case that num_embeddings is less than 2
levels_per_span = max(levels_per_span, 1)
span = (num_embeddings - 1) / levels_per_span
# generate the set of orthogonal vectors within the level vector set
span_hv = random_hv(
Expand Down Expand Up @@ -287,10 +297,8 @@ def circular_hv(
if dtype is None:
dtype = torch.get_default_dtype()

if dtype in {torch.bool, torch.complex64, torch.complex128}:
raise NotImplementedError(
"Boolean, and Complex hypervectors are not supported yet."
)
if dtype in {torch.complex64, torch.complex128}:
raise NotImplementedError("Complex hypervectors are not supported yet.")

if dtype == torch.uint8:
raise ValueError("Unsigned integer hypervectors are not supported.")
Expand Down Expand Up @@ -354,15 +362,15 @@ def circular_hv(

temp_hv = torch.where(threshold_v[span_idx] < t, span_start_hv, span_end_hv)

mutation_history.append(temp_hv * mutation_hv)
mutation_history.append(bind(temp_hv, mutation_hv))
mutation_hv = temp_hv

if i % 2 == 0:
hv[i // 2] = mutation_hv

for i in range(num_embeddings + 1, num_embeddings * 2 - 1):
mut = mutation_history.popleft()
mutation_hv *= mut
mutation_hv = bind(mutation_hv, mut)

if i % 2 == 0:
hv[i // 2] = mutation_hv
Expand All @@ -371,7 +379,7 @@ def circular_hv(
return hv


def bind(input: Tensor, other: Tensor, *, out=None) -> Tensor:
def bind(input: Tensor, other: Tensor) -> Tensor:
r"""Binds two hypervectors which produces a hypervector dissimilar to both.

Binding is used to associate information, for instance, to assign values to variables.
Expand All @@ -385,7 +393,6 @@ def bind(input: Tensor, other: Tensor, *, out=None) -> Tensor:
Args:
input (Tensor): input hypervector
other (Tensor): other input hypervector
out (Tensor, optional): the output tensor.

Shapes:
- Input: :math:`(*)`
Expand All @@ -402,18 +409,21 @@ def bind(input: Tensor, other: Tensor, *, out=None) -> Tensor:
tensor([ 1., -1., -1.])

"""
if input.dtype in {torch.bool, torch.complex64, torch.complex128}:
raise NotImplementedError(
"Boolean, and Complex hypervectors are not supported yet."
)
dtype = input.dtype

if input.dtype == torch.uint8:
if torch.is_complex(input):
raise NotImplementedError("Complex hypervectors are not supported yet.")

if dtype == torch.uint8:
raise ValueError("Unsigned integer hypervectors are not supported.")

return torch.mul(input, other, out=out)
if dtype == torch.bool:
return torch.logical_xor(input, other)

return torch.mul(input, other)

def bundle(input: Tensor, other: Tensor, *, out=None) -> Tensor:

def bundle(input: Tensor, other: Tensor, *, tie: BoolTensor = None) -> Tensor:
r"""Bundles two hypervectors which produces a hypervector maximally similar to both.

The bundling operation is used to aggregate information into a single hypervector.
Expand All @@ -427,7 +437,7 @@ def bundle(input: Tensor, other: Tensor, *, out=None) -> Tensor:
Args:
input (Tensor): input hypervector
other (Tensor): other input hypervector
out (Tensor, optional): the output tensor.
tie (BoolTensor, optional): specifies how to break a tie while bundling boolean hypervectors. Default: ``None``, in which case a bit is set only if both ``input`` and ``other`` are ``True``.

Shapes:
- Input: :math:`(*)`
Expand All @@ -444,15 +454,21 @@ def bundle(input: Tensor, other: Tensor, *, out=None) -> Tensor:
tensor([0., 2., 0.])

"""
if input.dtype in {torch.bool, torch.complex64, torch.complex128}:
raise NotImplementedError(
"Boolean, and Complex hypervectors are not supported yet."
)
dtype = input.dtype

if input.dtype == torch.uint8:
if torch.is_complex(input):
raise NotImplementedError("Complex hypervectors are not supported yet.")

if dtype == torch.uint8:
raise ValueError("Unsigned integer hypervectors are not supported.")

return torch.add(input, other, out=out)
if dtype == torch.bool:
if tie is not None:
return torch.where(input == other, input, tie)
else:
return torch.logical_and(input, other)

return torch.add(input, other)


def permute(input: Tensor, *, shifts=1, dims=-1) -> Tensor:
Expand Down Expand Up @@ -484,15 +500,19 @@ def permute(input: Tensor, *, shifts=1, dims=-1) -> Tensor:
tensor([ -1., 1., -1.])

"""
dtype = input.dtype

if dtype == torch.uint8:
raise ValueError("Unsigned integer hypervectors are not supported.")

return torch.roll(input, shifts=shifts, dims=dims)


def soft_quantize(input: Tensor, *, out=None):
def soft_quantize(input: Tensor):
"""Applies the hyperbolic tanh function to all elements of the input tensor.

Args:
input (Tensor): input tensor.
out (Tensor, optional): output tensor. Defaults to None.

Shapes:
- Input: :math:`(*)`
Expand All @@ -508,15 +528,14 @@ def soft_quantize(input: Tensor, *, out=None):
tensor([0.0000, 0.9640, 0.0000])

"""
return torch.tanh(input, out=out)
return torch.tanh(input)


def hard_quantize(input: Tensor, *, out=None):
def hard_quantize(input: Tensor):
"""Applies binary quantization to all elements of the input tensor.

Args:
input (Tensor): input tensor
out (Tensor, optional): output tensor. Defaults to None.

Shapes:
- Input: :math:`(*)`
Expand All @@ -537,13 +556,7 @@ def hard_quantize(input: Tensor, *, out=None):
positive = torch.tensor(1.0, dtype=input.dtype, device=input.device)
negative = torch.tensor(-1.0, dtype=input.dtype, device=input.device)

if out != None:
out[:] = torch.where(input > 0, positive, negative)
result = out
else:
result = torch.where(input > 0, positive, negative)

return result
return torch.where(input > 0, positive, negative)


def cosine_similarity(input: Tensor, others: Tensor) -> Tensor:
Expand Down Expand Up @@ -650,16 +663,21 @@ def multiset(input: Tensor) -> Tensor:
tensor([-1., 3., 1.])

"""
dim = -2
dtype = input.dtype

if input.dtype in {torch.bool, torch.complex64, torch.complex128}:
raise NotImplementedError(
"Boolean, and Complex hypervectors are not supported yet."
)
if dtype in {torch.complex64, torch.complex128}:
raise NotImplementedError("Complex hypervectors are not supported yet.")

if input.dtype == torch.uint8:
if dtype == torch.uint8:
raise ValueError("Unsigned integer hypervectors are not supported.")

return torch.sum(input, dim=-2, dtype=input.dtype)
if dtype == torch.bool:
count = torch.sum(input, dim=dim, dtype=torch.long)
threshold = input.size(dim) // 2
return torch.greater(count, threshold)

return torch.sum(input, dim=dim, dtype=dtype)


multibundle = multiset
Expand All @@ -681,6 +699,10 @@ def multibind(input: Tensor) -> Tensor:
- Input: :math:`(*, n, d)`
- Output: :math:`(*, d)`

.. note::

This method is not supported for ``torch.float16`` and ``torch.bfloat16`` input data types on a CPU device.

Examples::

>>> x = functional.random_hv(3, 3)
Expand All @@ -692,14 +714,21 @@ def multibind(input: Tensor) -> Tensor:
tensor([ 1., 1., -1.])

"""
if input.dtype in {torch.bool, torch.complex64, torch.complex128}:
raise NotImplementedError(
"Boolean, and Complex hypervectors are not supported yet."
)
if input.dtype in {torch.complex64, torch.complex128}:
raise NotImplementedError("Complex hypervectors are not supported yet.")

if input.dtype == torch.uint8:
raise ValueError("Unsigned integer hypervectors are not supported.")

if input.dtype == torch.bool:
hvs = torch.unbind(input, -2)
result = hvs[0]

for i in range(1, len(hvs)):
result = torch.logical_xor(result, hvs[i])

return result

return torch.prod(input, dim=-2, dtype=input.dtype)


Expand Down Expand Up @@ -870,6 +899,10 @@ def bind_sequence(input: Tensor) -> Tensor:
- Input: :math:`(*, n, d)`
- Output: :math:`(*, d)`

.. note::

This method is not supported for ``torch.float16`` and ``torch.bfloat16`` input data types on a CPU device.

Examples::

>>> x = functional.random_hv(5, 3)
Expand Down
Empty file.
Loading