Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ Encodings
hash_table
cross_product
ngrams
graph


Utilities
Expand Down
63 changes: 51 additions & 12 deletions torchhd/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
"bind_sequence",
"ngrams",
"hash_table",
"graph",
"map_range",
"value_to_index",
"index_to_value",
Expand All @@ -41,7 +42,7 @@ def identity_hv(
device=None,
requires_grad=False,
) -> Tensor:
"""Creates a set of identity hypervector.
"""Creates a set of identity hypervectors.

When bound with a random-hypervector :math:`x`, the result is :math:`x`.

Expand Down Expand Up @@ -174,24 +175,20 @@ def random_hv(
if dtype == torch.uint8:
raise ValueError("Unsigned integer hypervectors are not supported.")

size = (num_embeddings, embedding_dim)
if dtype in {torch.complex64, torch.complex128}:
dtype = torch.float if dtype == torch.complex64 else torch.double

angle = torch.empty(num_embeddings, embedding_dim, dtype=dtype, device=device)
angle = torch.empty(size, dtype=dtype, device=device)
angle.uniform_(-math.pi, math.pi)
magnitude = torch.ones(
num_embeddings, embedding_dim, dtype=dtype, device=device
)
magnitude = torch.ones(size, dtype=dtype, device=device)

result = torch.polar(magnitude, angle)
result.requires_grad = requires_grad
return result

select = torch.empty(
(
num_embeddings,
embedding_dim,
),
size,
dtype=torch.bool,
).bernoulli_(1.0 - sparsity, generator=generator)

Expand Down Expand Up @@ -1031,7 +1028,7 @@ def hash_table(keys: Tensor, values: Tensor) -> Tensor:

.. math::

\bigoplus_{i = 0}^{m - 1} K_i \otimes V_i
\bigoplus_{i = 0}^{n - 1} K_i \otimes V_i

Args:
keys (Tensor): The keys hypervectors, must be the same shape as values.
Expand Down Expand Up @@ -1066,7 +1063,7 @@ def bundle_sequence(input: Tensor) -> Tensor:

.. math::

\bigoplus_{i=0}^{m-1} \Pi^{m - i - 1}(V_i)
\bigoplus_{i=0}^{n-1} \Pi^{n - i - 1}(V_i)

Args:
input (Tensor): The hypervector values.
Expand Down Expand Up @@ -1105,7 +1102,7 @@ def bind_sequence(input: Tensor) -> Tensor:

.. math::

\bigotimes_{i=0}^{m-1} \Pi^{m - i - 1}(V_i)
\bigotimes_{i=0}^{n-1} \Pi^{n - i - 1}(V_i)

Args:
input (Tensor): The hypervector values.
Expand Down Expand Up @@ -1141,6 +1138,48 @@ def bind_sequence(input: Tensor) -> Tensor:
return multibind(permuted)


def graph(input: Tensor, *, directed=False) -> Tensor:
r"""Graph from node hypervector pairs.

If ``directed=False`` this computes:

.. math::

\bigoplus_{i = 0}^{n - 1} V_{0,i} \otimes V_{1,i}

If ``directed=True`` this computes:

.. math::

\bigoplus_{i = 0}^{n - 1} V_{0,i} \otimes \Pi(V_{1,i})

Args:
input (Tensor): tensor containing pairs of node hypervectors that share an edge.
directed (bool, optional): specify if the graph is directed or not. Default: ``False``.

Shapes:
- Input: :math:`(*, 2, n, d)`
- Output: :math:`(*, d)`

Examples::
>>> edges = torch.tensor([[0, 0, 1, 2], [1, 2, 2, 3]])
>>> node_embedding = embeddings.Random(4, 10000)
>>> edges_hv = node_embedding(edges)
>>> graph = functional.graph(edges_hv)
>>> neighbors = unbind(graph, node_embedding.weight[0])
>>> cosine_similarity(neighbors, node_embedding.weight)
tensor([0.0006, 0.5017, 0.4997, 0.0048])

"""
to_nodes = input[..., 0, :, :]
from_nodes = input[..., 1, :, :]

if directed:
from_nodes = permute(from_nodes)

return multiset(bind(to_nodes, from_nodes))


def map_range(
input: Tensor,
in_min: float,
Expand Down
20 changes: 20 additions & 0 deletions torchhd/structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -911,6 +911,26 @@ def clear(self) -> None:
"""
self.value.fill_(0.0)

@classmethod
def from_edges(cls, input: Tensor, directed=False):
"""Creates a graph from a tensor

See: :func:`~torchhd.functional.graph`.

Args:
input (Tensor): tensor containing pairs of node hypervectors that share an edge.
directed (bool, optional): specify if the graph is directed or not. Default: ``False``.

Examples::
>>> edges = torch.tensor([[0, 0, 1, 2], [1, 2, 2, 3]])
>>> node_embedding = embeddings.Random(4, 10000)
>>> edges_hv = node_embedding(edges)
>>> graph = structures.Graph.from_edges(edges_hv)

"""
value = functional.graph(input, directed=directed)
return cls(value, directed=directed)


class Tree:
"""Hypervector-based tree data structure.
Expand Down
22 changes: 22 additions & 0 deletions torchhd/tests/structures/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,3 +179,25 @@ def test_clear(self):
assert torch.equal(
G.value, torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
)

def test_from_edges(self):
generator = torch.Generator()
generator.manual_seed(seed)

hv = functional.random_hv(4, 8, generator=generator)
edges = torch.empty(2, 3, 8)
edges[0, 0] = hv[0]
edges[1, 0] = hv[1]
edges[0, 1] = hv[0]
edges[1, 1] = hv[2]
edges[0, 2] = hv[1]
edges[1, 2] = hv[2]

G = structures.Graph.from_edges(edges)
neighbors = G.node_neighbors(hv[0])
neighbor_similarity = functional.cosine_similarity(neighbors, hv)

assert neighbor_similarity[0] < torch.tensor(0.5)
assert neighbor_similarity[1] > torch.tensor(0.5)
assert neighbor_similarity[2] > torch.tensor(0.5)
assert neighbor_similarity[3] < torch.tensor(0.5)
41 changes: 41 additions & 0 deletions torchhd/tests/test_encodings.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,3 +320,44 @@ def test_device(self):
hv = functional.random_hv(11, 10000, device=device)
res = functional.bind_sequence(hv)
assert res.device == device


class TestGraph:
def test_value(self):
hv = torch.zeros(2, 4, 1000)
res = functional.graph(hv)
assert torch.all(res == 0).item()

g = torch.tensor(
[
[[1, -1, -1, 1], [-1, -1, 1, 1], [-1, 1, 1, 1]],
[[-1, -1, 1, 1], [-1, 1, 1, 1], [1, -1, -1, 1]],
]
)
res = functional.graph(g)
assert torch.all(res == torch.tensor([-1, -1, -1, 3])).item()
assert res.dtype == g.dtype

res = functional.graph(g, directed=True)
assert torch.all(res == torch.tensor([-1, 3, 1, 1])).item()
assert res.dtype == g.dtype

@pytest.mark.parametrize("dtype", torch_dtypes)
def test_dtype(self, dtype):
hv = torch.zeros(5, 2, 23, 1000, dtype=dtype)

if dtype == torch.uint8:
with pytest.raises(ValueError):
functional.graph(hv)

return

res = functional.graph(hv)
assert res.dtype == dtype

def test_device(self):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

hv = torch.zeros(5, 2, 23, 1000, device=device)
res = functional.graph(hv)
assert res.device == device
16 changes: 8 additions & 8 deletions torchhd/tests/test_similarities.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
class TestDotSimilarity:
@pytest.mark.parametrize("dtype", torch_dtypes)
def test_shape(self, dtype):
if not supported_dtype(dtype):
if not supported_dtype(dtype) or dtype == torch.half:
return

generator = torch.Generator()
Expand Down Expand Up @@ -43,7 +43,7 @@ def test_shape(self, dtype):

@pytest.mark.parametrize("dtype", torch_dtypes)
def test_value(self, dtype):
if not supported_dtype(dtype):
if not supported_dtype(dtype) or dtype == torch.half:
return

generator = torch.Generator()
Expand Down Expand Up @@ -113,7 +113,7 @@ def test_value(self, dtype):

@pytest.mark.parametrize("dtype", torch_dtypes)
def test_dtype(self, dtype):
if not supported_dtype(dtype):
if not supported_dtype(dtype) or dtype == torch.half:
return

generator = torch.Generator()
Expand All @@ -134,7 +134,7 @@ def test_dtype(self, dtype):

@pytest.mark.parametrize("dtype", torch_dtypes)
def test_device(self, dtype):
if not supported_dtype(dtype):
if not supported_dtype(dtype) or dtype == torch.half:
return

generator = torch.Generator()
Expand All @@ -153,7 +153,7 @@ def test_device(self, dtype):
class TestCosSimilarity:
@pytest.mark.parametrize("dtype", torch_dtypes)
def test_shape(self, dtype):
if not supported_dtype(dtype):
if not supported_dtype(dtype) or dtype == torch.half:
return

generator = torch.Generator()
Expand Down Expand Up @@ -181,7 +181,7 @@ def test_shape(self, dtype):

@pytest.mark.parametrize("dtype", torch_dtypes)
def test_value(self, dtype):
if not supported_dtype(dtype):
if not supported_dtype(dtype) or dtype == torch.half:
return

generator = torch.Generator()
Expand Down Expand Up @@ -250,7 +250,7 @@ def test_value(self, dtype):

@pytest.mark.parametrize("dtype", torch_dtypes)
def test_dtype(self, dtype):
if not supported_dtype(dtype):
if not supported_dtype(dtype) or dtype == torch.half:
return

generator = torch.Generator()
Expand All @@ -264,7 +264,7 @@ def test_dtype(self, dtype):

@pytest.mark.parametrize("dtype", torch_dtypes)
def test_device(self, dtype):
if not supported_dtype(dtype):
if not supported_dtype(dtype) or dtype == torch.half:
return

generator = torch.Generator()
Expand Down