
Commit 4821674

denkle, github-actions[bot], and mikeheddes authored
Fractional power encoding (#142)
* Basic start for fractional power encoding
* [github-action] formatting fixes
* Address some of the initial design issues
* [github-action] formatting fixes
* Revised design and moved to embeddings
* [github-action] formatting fixes
* Added support for the HRR model and tutorial-like examples for several kernels
* [github-action] formatting fixes
* Refactor FractionalPower embedding
* [github-action] formatting fixes
* Changed range for histogram in the notebook
* Update docs
* Match dtype and device of embedding
* [github-action] formatting fixes
* Verify if dtype is supported

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Mike Heddes <mikeheddes@gmail.com>
1 parent 6dcfe9c commit 4821674

File tree

4 files changed: +1306, -2 lines


.gitignore

Lines changed: 3 additions & 0 deletions
@@ -134,3 +134,6 @@ dmypy.json
 
 # Pyre type checker
 .pyre/
+
+# MacOS
+.DS_Store

docs/embeddings.rst

Lines changed: 2 additions & 1 deletion
@@ -17,4 +17,5 @@ torchhd.embeddings
     Circular
     Projection
     Sinusoid
-    Density
+    Density
+    FractionalPower

examples/fractional_power_encoding_kernels.ipynb

Lines changed: 1105 additions & 0 deletions
Large diffs are not rendered by default.
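The notebook itself is not rendered above. As a rough stand-alone sketch in the same spirit (not taken from the notebook; it assumes only the FractionalPower API added in this commit), the empirical FPE kernel can be compared against the analytic sinc kernel that the default "sinc" preset targets:

import torch
from torchhd import embeddings

torch.manual_seed(0)

# 1-D inputs, 10,000-dimensional FHRR hypervectors, default "sinc" preset
embed = embeddings.FractionalPower(1, 10000, "sinc", 1.0, "FHRR")

# Encode a grid of scalar values
x = torch.linspace(-10.0, 10.0, steps=201).view(-1, 1)
hv = embed(x)

# Empirical kernel: normalized real inner product with the encoding of zero
zero = embed(torch.zeros(1, 1))
k_hat = torch.real(hv @ zero.conj().T).squeeze() / embed.out_features

# Analytic kernel for the "sinc" preset: sin(pi * x) / (pi * x)
k_true = torch.special.sinc(x.squeeze())

print((k_hat - k_true).abs().max())  # small, and shrinks as out_features grows

With phases drawn from Uniform(-pi, pi), the expected inner product between the encodings of x and y is sin(pi(x - y)) / (pi(x - y)), which is what the comparison above checks.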

torchhd/embeddings.py

Lines changed: 196 additions & 1 deletion
@@ -22,7 +22,7 @@
 # SOFTWARE.
 #
 import math
-from typing import Type, Union, Optional
+from typing import Type, Union, Optional, Literal, Callable
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -32,8 +32,11 @@
 import torchhd.functional as functional
 from torchhd.tensors.base import VSATensor
 from torchhd.tensors.map import MAPTensor
+from torchhd.tensors.fhrr import FHRRTensor
+from torchhd.tensors.hrr import HRRTensor
 from torchhd.types import VSAOptions
 
+
 __all__ = [
     "Empty",
     "Identity",
@@ -44,6 +47,7 @@
     "Projection",
     "Sinusoid",
     "Density",
+    "FractionalPower",
 ]
 
 
@@ -966,3 +970,194 @@ def forward(self, input: Tensor) -> Tensor:
         output = functional.bind(self.key.weight, self.density_encoding(input))
         # Perform the superposition operation on the bound key-value pairs
         return functional.multibundle(output)
+
+
+class FractionalPower(nn.Module):
+    """Class for the fractional power encoding (FPE) method, which forms hypervectors for given values, kernel shape, bandwidth, and dimensionality. Implements similarity-preserving hypervectors that approximate a desired kernel shape, as described in `Computing on Functions Using Randomized Vector Representations <https://arxiv.org/abs/2109.03429>`_.
+
+    Args:
+        in_features (int): the dimensionality of the input feature vector.
+        out_features (int): the dimensionality of the hypervectors.
+        distribution (str, optional): hyperparameter defining the shape of the kernel by specifying the probability distribution used to sample the base hypervector(s). Default: ``"sinc"``.
+        bandwidth (float, optional): positive hyperparameter defining the width of the similarity kernel. Lower values lead to broader kernels while larger values lead to narrower kernels. Default: ``1.0``.
+        vsa (``VSAOptions``, optional): specifies the hypervector type to be instantiated. Default: ``"FHRR"``.
+        dtype (``torch.dtype``, optional): the desired data type of the returned tensor. Default: if ``None``, depends on VSATensor.
+        device (``torch.device``, optional): the desired device of the returned tensor. Default: if ``None``, uses the current device for the default tensor type (see torch.set_default_tensor_type()). ``device`` will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types.
+        requires_grad (bool, optional): if autograd should record operations on the returned tensor. Default: ``False``.
+
+    Examples::
+
+        >>> embed = embeddings.FractionalPower(1, 6, "sinc", 1.0, "FHRR")
+        >>> embed(torch.arange(1, 4, 1.).view(-1, 1))
+        FHRRTensor([[-0.7181-0.6959j, -0.5269+0.8499j, -0.0848+0.9964j,  0.9720-0.2348j,
+                      0.6358+0.7718j,  0.4352+0.9003j],
+                    [ 0.0314+0.9995j, -0.4447-0.8957j, -0.9856-0.1689j,  0.8897-0.4565j,
+                     -0.1915+0.9815j, -0.6212+0.7836j],
+                    [ 0.6730-0.7396j,  0.9956+0.0940j,  0.2519-0.9678j,  0.7576-0.6527j,
+                     -0.8793+0.4762j, -0.9759-0.2183j]])
+
+    """
+
+    # The collection of distributions for basic predefined kernels
+    predefined_kernels = {
+        "sinc": torch.distributions.Uniform(-math.pi, math.pi),
+        "gaussian": torch.distributions.Normal(0.0, 1.0),
+    }
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        distribution: Union[
+            torch.distributions.Distribution, Literal["sinc", "gaussian"]
+        ] = "sinc",
+        bandwidth: float = 1.0,
+        vsa: Literal["HRR", "FHRR"] = "FHRR",
+        device=None,
+        dtype=None,
+        requires_grad: bool = False,
+    ) -> None:
+        factory_kwargs = {"device": device, "dtype": dtype}
+        super(FractionalPower, self).__init__()
+
+        self.in_features = in_features  # data dimensions
+        self.out_features = out_features  # hypervector dimensions
+        self.bandwidth = bandwidth
+        self.requires_grad = requires_grad
+
+        if vsa not in {"HRR", "FHRR"}:
+            raise ValueError(
+                f"FractionalPower embedding only supports HRR and FHRR but provided: {vsa}"
+            )
+
+        self.vsa_tensor = functional.get_vsa_tensor_class(vsa)
+
+        if dtype not in self.vsa_tensor.supported_dtypes:
+            raise ValueError(f"dtype {dtype} not supported by {vsa}")
+
+        # If the distribution is a string, use the presets in predefined_kernels
+        if isinstance(distribution, str):
+            try:
+                self.distribution = self.predefined_kernels[distribution]
+            except KeyError:
+                available_names = ", ".join(list(self.predefined_kernels.keys()))
+                raise ValueError(
+                    f"{distribution} kernel is not supported, use one of: {available_names}, or provide a custom torch distribution."
+                )
+        else:
+            self.distribution = distribution
+
+        # Initialize the encoding's parameters
+        self.weight = nn.Parameter(
+            torch.empty(self.out_features, self.in_features, **factory_kwargs),
+            requires_grad,
+        )
+        self.reset_parameters()
+
+    # Sample the angles using the provided distribution
+    def reset_parameters(self) -> None:
+        """Generate the angles of the base hypervector(s) to be used for encoding the data."""
+
+        sample_shape = self.distribution.sample().shape
+
+        # Check the HD/VSA model type
+        if self.vsa_tensor == FHRRTensor:
+            # Generate the angles for the base hypervector(s) that determine the shape of the FPE kernel.
+            # If the distribution is univariate, the base hypervectors are independent,
+            # so it is safe to generate self.in_features * self.out_features independent samples.
+            if sample_shape == ():
+                # Draw angles for the base hypervector(s). Note that the data dimensions here
+                # are independent, but this does not always have to be the case.
+                phases = self.distribution.sample((self.out_features, self.in_features))
+                phases = phases.to(self.weight)
+                self.weight.data.copy_(phases)
+
+            # If the base hypervectors are correlated, the dimensionality of the distribution
+            # should match that of the data.
+            elif sample_shape == (self.in_features,):
+                phases = self.distribution.sample((self.out_features,))
+                phases = phases.to(self.weight)
+                self.weight.data.copy_(phases)
+
+            # Raise an error due to the ambiguity of the situation
+            else:
+                raise ValueError(
+                    f"The provided distribution has shape {sample_shape} while the input data expects shape () or ({self.in_features},) so there is a mismatch."
+                )
+
+        elif self.vsa_tensor == HRRTensor:
+            # Fewer angles are needed because the spectrum must be made symmetric below
+            dimensions_real = int((self.out_features - 1) / 2)
+
+            # Generate the angles for the base hypervector(s) that determine the shape of the FPE kernel.
+            # If the distribution is univariate, the base hypervectors are independent,
+            # so it is safe to generate independent samples.
+            if sample_shape == ():
+                phases = self.distribution.sample((dimensions_real, self.in_features))
+
+            # If the base hypervectors are correlated, the dimensionality of the distribution
+            # should match that of the data.
+            elif sample_shape == (self.in_features,):
+                phases = self.distribution.sample((dimensions_real,))
+
+            # Raise an error due to the ambiguity of the situation
+            else:
+                raise ValueError(
+                    f"The provided distribution has shape {sample_shape} while the input data expects shape () or ({self.in_features},) so there is a mismatch."
+                )
+
+            # Make the generated angles odd-symmetric so that they form the spectrum of a real-valued vector
+            phases = torch.cat(
+                (
+                    phases,
+                    torch.zeros(1, self.in_features),
+                    -torch.flip(phases, dims=[0]),
+                ),
+                dim=0,
+            )
+            if self.out_features % 2 == 0:
+                phases = torch.cat((torch.zeros(1, self.in_features), phases), dim=0)
+
+            phases = phases.to(self.weight)
+            # Store the generated angles in the module's parameters
+            self.weight.data.copy_(phases)
+
+    def basis(self):
+        """Return the values of the base hypervector(s)"""
+
+        # Use the angles in self.weight to obtain the values of the base hypervector(s)
+        if self.vsa_tensor == FHRRTensor:
+            hvs = torch.complex(self.weight.cos(), self.weight.sin()).T
+            hvs = hvs.as_subclass(FHRRTensor)
+
+        elif self.vsa_tensor == HRRTensor:
+            complex_hv = torch.complex(self.weight.cos(), self.weight.sin()).T
+            hvs = torch.real(
+                torch.fft.ifft(torch.fft.ifftshift(complex_hv, dim=1), dim=1)
+            )
+            hvs = hvs.as_subclass(HRRTensor)
+
+        return hvs
+
+    def forward(self, input: Tensor) -> Tensor:
+        """Creates a fractional power encoding (FPE) for the given values.
+
+        Args:
+            input (Tensor): values for which FPE hypervectors should be generated. Either a vector or a batch of vectors.
+
+        Shapes:
+            - Input: :math:`(*, f)` where f is in_features and * is an optional batch dimension.
+            - Output: :math:`(*, d)` where d is out_features and * is an optional batch dimension.
+
+        """
+
+        # Perform FPE of the desired values using the base hypervector(s).
+        # This simultaneously computes the angles for the given values and their sum,
+        # which is equivalent to binding.
+        if self.vsa_tensor == FHRRTensor:
+            phases = F.linear(self.bandwidth * input, self.weight)
+            hv = torch.complex(phases.cos(), phases.sin())
+            hv = hv.as_subclass(FHRRTensor)
+
+        elif self.vsa_tensor == HRRTensor:
+            phases = F.linear(self.bandwidth * input, self.weight)
+            hv = torch.complex(phases.cos(), phases.sin())
+            hv = torch.real(torch.fft.ifft(torch.fft.ifftshift(hv, dim=1), dim=1))
+            hv = hv.as_subclass(HRRTensor)
+
+        return hv
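A usage sketch for the custom-distribution path (not part of the commit): any torch.distributions object can be passed as distribution. By Bochner's theorem, a Cauchy spectral density corresponds to an exponential (Laplacian) kernel exp(-|x - y|):

import torch
from torchhd import embeddings

torch.manual_seed(0)

# Phases sampled from a Cauchy distribution give an exponential similarity kernel
cauchy = torch.distributions.Cauchy(0.0, 1.0)
embed = embeddings.FractionalPower(1, 10000, cauchy, bandwidth=1.0, vsa="FHRR")

x = torch.tensor([[0.0], [0.5], [2.0]])
hv = embed(x)

# Pairwise normalized similarities approximate exp(-|x_i - x_j|)
sim = torch.real(hv @ hv.conj().T) / embed.out_features
print(sim)
print(torch.exp(-torch.cdist(x, x, p=1)))  # analytic kernel for comparison

The bandwidth argument rescales the inputs before the phase summation, so bandwidth=2.0 would approximate exp(-2|x - y|) instead.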
