Skip to content

Commit c58fc06

Browse files
authored
Add cross product function (#46)
1 parent b1aacc3 commit c58fc06

File tree

2 files changed

+41
-0
lines changed

2 files changed

+41
-0
lines changed

docs/functional.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ Encodings
5858
sequence
5959
distinct_sequence
6060
hash_table
61+
cross_product
6162
ngrams
6263

6364

torchhd/functional.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
"dot_similarity",
2323
"multiset",
2424
"multibind",
25+
"cross_product",
2526
"sequence",
2627
"distinct_sequence",
2728
"ngrams",
@@ -629,6 +630,45 @@ def multibind(input: Tensor, *, dim=-2, keepdim=False, dtype=None, out=None) ->
629630
return torch.prod(input, dim=dim, keepdim=keepdim, dtype=dtype, out=out)
630631

631632

633+
def cross_product(input: Tensor, other: Tensor) -> Tensor:
634+
r"""Cross product between two sets of hypervectors.
635+
636+
First creates a multiset from both tensors ``input`` (:math:`A`) and ``other`` (:math:`B`).
637+
Then binds those together to generate all cross products, i.e., :math:`A_1 * B_1 + A_1 * B_2 + \dots + A_1 * B_m + \dots + A_n * B_m`.
638+
639+
.. math::
640+
641+
\big( \bigoplus_{i=0}^{n-1} A_i \big) \otimes \big( \bigoplus_{i=0}^{m-1} B_i \big)
642+
643+
Args:
644+
input (Tensor): first set of input hypervectors
645+
other (Tensor): second set of input hypervectors
646+
647+
Shapes:
648+
- Input: :math:`(*, n, d)`
649+
- Other: :math:`(*, m, d)`
650+
- Output: :math:`(*, d)`
651+
652+
Examples::
653+
654+
>>> a = functional.random_hv(2, 3)
655+
>>> a
656+
tensor([[ 1., 1., -1.],
657+
[ 1., -1., 1.]])
658+
>>> b = functional.random_hv(5, 3)
659+
>>> b
660+
tensor([[ 1., -1., 1.],
661+
[-1., -1., -1.],
662+
[-1., -1., -1.],
663+
[ 1., 1., -1.],
664+
[ 1., -1., -1.]])
665+
>>> functional.cross_product(a, b)
666+
tensor([2., -0., -0.])
667+
668+
"""
669+
return bind(multiset(input), multiset(other))
670+
671+
632672
def ngrams(input: Tensor, n: int = 3) -> Tensor:
633673
r"""Creates a hypervector with the :math:`n`-gram statistics of the input.
634674

0 commit comments

Comments
 (0)