Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ Encodings
sequence
distinct_sequence
hash_table
cross_product
ngrams


Expand Down
40 changes: 40 additions & 0 deletions torchhd/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
"dot_similarity",
"multiset",
"multibind",
"cross_product",
"sequence",
"distinct_sequence",
"ngrams",
Expand Down Expand Up @@ -629,6 +630,45 @@ def multibind(input: Tensor, *, dim=-2, keepdim=False, dtype=None, out=None) ->
return torch.prod(input, dim=dim, keepdim=keepdim, dtype=dtype, out=out)


def cross_product(input: Tensor, other: Tensor) -> Tensor:
r"""Cross product between two sets of hypervectors.

First creates a multiset from both tensors ``input`` (:math:`A`) and ``other`` (:math:`B`).
Then binds those together to generate all cross products, i.e., :math:`A_1 * B_1 + A_1 * B_2 + \dots + A_1 * B_m + \dots + A_n * B_m`.

.. math::

\big( \bigoplus_{i=0}^{n-1} A_i \big) \otimes \big( \bigoplus_{i=0}^{m-1} B_i \big)

Args:
input (Tensor): first set of input hypervectors
other (Tensor): second set of input hypervectors

Shapes:
- Input: :math:`(*, n, d)`
- Other: :math:`(*, m, d)`
- Output: :math:`(*, d)`

Examples::

>>> a = functional.random_hv(2, 3)
>>> a
tensor([[ 1., 1., -1.],
[ 1., -1., 1.]])
>>> b = functional.random_hv(5, 3)
>>> b
tensor([[ 1., -1., 1.],
[-1., -1., -1.],
[-1., -1., -1.],
[ 1., 1., -1.],
[ 1., -1., -1.]])
>>> functional.cross_product(a, b)
tensor([2., -0., -0.])

"""
return bind(multiset(input), multiset(other))


def ngrams(input: Tensor, n: int = 3) -> Tensor:
r"""Creates a hypervector with the :math:`n`-gram statistics of the input.

Expand Down