
Commit eba0620

mikeheddes and Didanny authored
Add sinusoid embedding and fix projection embedding (#88)
* Fix projection embedding and implement cosine-based
* WIP
* Add sinusoid encoding and fix projection encoding

Co-authored-by: Didanny <daa50@mail.aub.edu>
1 parent dcb9929 commit eba0620

4 files changed: +150 -9 lines changed


docs/embeddings.rst

Lines changed: 2 additions & 1 deletion
@@ -13,4 +13,5 @@ torchhd.embeddings
     Random
     Level
     Circular
-    Projection
+    Projection
+    Sinusoid

sandbox.py

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+import pytest
+import torch
+
+from torchhd import functional
+from torchhd import embeddings
+
+# torch.float32,
+# torch.float64,
+# torch.complex64,
+# torch.complex128,
+# torch.float16,
+# torch.bfloat16,
+# torch.uint8,
+# torch.int8,
+# torch.int16,
+# torch.int32,
+# torch.int64,
+# torch.bool,
+
+# from .utils import (
+#     torch_dtypes,
+#     torch_complex_dtypes,
+#     supported_dtype,
+# )
+
+for i in range(5, 20):
+    emb = embeddings.Identity(i, 3)
+    idx = torch.LongTensor([0, 1, 4])
+    res = emb(idx)
+
+    print("{0},{1}".format(res.size(dim=0),res.size(dim=1)))

torchhd/embeddings.py

Lines changed: 76 additions & 8 deletions
@@ -11,6 +11,7 @@
     "Level",
     "Circular",
     "Projection",
+    "Sinusoid",
 ]
 
 
@@ -230,7 +231,7 @@ class Projection(nn.Module):
     r"""Embedding using a random projection matrix.
 
     Implemented based on `A Theoretical Perspective on Hyperdimensional Computing <https://arxiv.org/abs/2010.07426>`_.
-    :math:`\Phi x` where :math:`\Phi \in \mathbb{R}^{d \times m}` is a matrix whose rows are uniformly sampled at random from the surface of an :math:`m`-dimensional unit sphere.
+    It computes :math:`x \Phi^{\mathsf{T}}` where :math:`\Phi \in \mathbb{R}^{d \times m}` is a matrix whose rows are uniformly sampled at random from the surface of an :math:`m`-dimensional unit sphere.
     This encoding ensures that similarities in the input space are preserved in the hyperspace.
 
     Args:
@@ -242,11 +243,16 @@ class Projection(nn.Module):
 
     Examples::
 
-        >>> emb = embeddings.Projection(5, 3)
-        >>> x = torch.rand(2, 5)
-        >>> emb(x)
-        tensor([[ 0.2747, -0.8804, -0.6810],
-                [ 0.5610, -0.9227,  0.1671]])
+        >>> embed = embeddings.Projection(6, 5)
+        >>> x = torch.randn(3, 6)
+        >>> x
+        tensor([[ 0.4119, -0.4284,  1.8022,  0.3715, -1.4563, -0.2842],
+                [-0.3772, -1.2664, -1.5173,  1.3317,  0.4707, -1.3362],
+                [-1.8142,  0.0274, -1.0989,  0.8193,  0.7619,  0.9181]])
+        >>> embed(x).sign()
+        tensor([[-1.,  1.,  1.,  1.,  1.],
+                [ 1.,  1.,  1.,  1.,  1.],
+                [ 1., -1., -1., -1., -1.]])
 
     """
 
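For context on what the updated Projection encoding computes, a minimal plain-PyTorch sketch follows; the variable names and the random seed are illustrative and not part of this commit. It mirrors the new reset_parameters (standard-normal rows, then row-normalized, matching the unit-sphere sampling described in the docstring) and the linear forward pass, and checks informally that input similarities carry over to the hyperspace:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
in_features, out_features = 6, 5

# Projection matrix: standard-normal rows, normalized so each row lies on the unit sphere.
weight = F.normalize(torch.randn(out_features, in_features), dim=1)

x = torch.randn(3, in_features)

# The encoding itself is just x @ weight.T, i.e. F.linear(x, weight).
hv = F.linear(x, weight)

# Angles between inputs are approximately preserved by the projection; with only
# 5 output dimensions the match is rough, and it tightens as out_features grows.
print(F.cosine_similarity(x[0], x[1], dim=0))
print(F.cosine_similarity(hv[0], hv[1], dim=0))

# Bipolar hypervectors, as in the docstring example, come from taking the sign.
print(hv.sign())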

@@ -270,8 +276,70 @@ def __init__(
         self.reset_parameters()
 
     def reset_parameters(self) -> None:
-        nn.init.uniform_(self.weight, -1, 1)
-        self.weight.data[:] = F.normalize(self.weight.data)
+        nn.init.normal_(self.weight, 0, 1)
+        self.weight.data.copy_(F.normalize(self.weight.data))
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         return F.linear(input, self.weight)
+
+
+class Sinusoid(nn.Module):
+    r"""Embedding using a nonlinear random projection.
+
+    Implemented based on `Scalable Edge-Based Hyperdimensional Learning System with Brain-Like Neural Adaptation <https://dl.acm.org/doi/abs/10.1145/3458817.3480958>`_.
+    It computes :math:`\cos(x \Phi^{\mathsf{T}} + b) \odot \sin(x \Phi^{\mathsf{T}})` where :math:`\Phi \in \mathbb{R}^{d \times m}` is a matrix whose elements are sampled at random from a standard normal distribution and :math:`b \in \mathbb{R}^{d}` is a vector whose elements are sampled uniformly at random between 0 and :math:`2\pi`.
+
+    Args:
+        in_features (int): the dimensionality of the input feature vector.
+        out_features (int): the dimensionality of the hypervectors.
+        requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: ``False``.
+        dtype (``torch.dtype``, optional): the desired data type of returned tensor. Default: if ``None``, uses a global default (see ``torch.set_default_tensor_type()``).
+        device (``torch.device``, optional): the desired device of returned tensor. Default: if ``None``, uses the current device for the default tensor type (see torch.set_default_tensor_type()). ``device`` will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types.
+
+    Examples::
+
+        >>> embed = embeddings.Sinusoid(6, 5)
+        >>> x = torch.randn(3, 6)
+        >>> x
+        tensor([[ 0.5043,  0.3161, -0.0938,  0.6134, -0.1280,  0.3647],
+                [-0.1907,  1.6468, -0.3242,  0.8614,  0.3332, -0.2055],
+                [-0.8662, -1.3861, -0.1577,  0.1321, -0.1157, -2.8928]])
+        >>> embed(x)
+        tensor([[-0.0555,  0.2292, -0.1833,  0.0301, -0.2416],
+                [-0.0725,  0.7042, -0.5644,  0.2235,  0.3603],
+                [-0.9021,  0.8899, -0.9802,  0.3565,  0.2367]])
+
+    """
+
+    __constants__ = ["in_features", "out_features"]
+    in_features: int
+    out_features: int
+    weight: torch.Tensor
+    bias: torch.Tensor
+
+    def __init__(
+        self, in_features, out_features, requires_grad=False, device=None, dtype=None
+    ):
+        factory_kwargs = {"device": device, "dtype": dtype}
+        super(Sinusoid, self).__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+
+        self.weight = nn.parameter.Parameter(
+            torch.empty((out_features, in_features), **factory_kwargs),
+            requires_grad=requires_grad,
+        )
+
+        self.bias = nn.parameter.Parameter(
+            torch.empty((1, out_features), **factory_kwargs),
+            requires_grad=requires_grad,
+        )
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        nn.init.normal_(self.weight, 0, 1)
+        nn.init.uniform_(self.bias, 0, 2*math.pi)
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        projected = F.linear(input, self.weight)
+        return torch.cos(projected + self.bias) * torch.sin(projected)
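The new Sinusoid forward pass is compact enough to restate standalone. Below is a rough plain-PyTorch sketch of the cos(xΦᵀ + b) ⊙ sin(xΦᵀ) encoding from the docstring; the variable names and the seed are illustrative and not part of the commit. Note that, unlike Projection, the weight rows are not normalized here; only the bias is constrained to [0, 2π), matching reset_parameters above.

import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
in_features, out_features = 6, 5

# Mirrors Sinusoid.reset_parameters: standard-normal projection and a uniform phase in [0, 2*pi).
weight = torch.randn(out_features, in_features)
bias = torch.empty(1, out_features).uniform_(0, 2 * math.pi)

x = torch.randn(3, in_features)

# Mirrors Sinusoid.forward: project once, then combine a shifted cosine with a sine.
projected = F.linear(x, weight)
hv = torch.cos(projected + bias) * torch.sin(projected)
print(hv.shape)  # torch.Size([3, 5])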

torchhd/tests/test_embeddings.py

Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
+import pytest
+import torch
+
+from torchhd import functional
+from torchhd import embeddings
+
+from .utils import (
+    torch_dtypes,
+    torch_complex_dtypes,
+    supported_dtype,
+)
+
+# class TestIdentity:
+#     def test_num_embeddings(self):
+#         for i in range(1, 10):
+#             emb = embeddings.Identity(i, 3)
+#             idx = torch.LongTensor([0, 1, 4])
+#             res = emb(idx)
+
+#             assert res.size != i
+#             assert True
+
+#     def test_embedding_dim(self):
+#         assert True
+
+#     def test_value(self):
+#         assert True
+
+# class TestRandom:
+#     @pytest.mark.parametrize("dtype", torch_dtypes)
+#     def test_num_embeddings(self, dtype):
+#         assert True
+
+#     @pytest.mark.parametrize("dtype", torch_dtypes)
+#     def test_embedding_dim(self, dtype):
+#         assert True
+
+#     @pytest.mark.parametrize("dtype", torch_dtypes)
+#     def test_value(self, dtype):
+#         assert True
+
