
Fix dtype validation error in the Fractional Power Encoding #148

Merged · 4 commits · Jul 8, 2023
11 changes: 11 additions & 0 deletions .github/pull_request_template.md
@@ -0,0 +1,11 @@
<!--- Provide a short summary of your changes in the Title above -->

## Description
<!--- Describe your changes in detail -->
<!-- Link the issue (if any) that will be resolved by the changes -->



## Checklist
- [ ] I added/updated documentation for the changes.
- [ ] I have thoroughly tested the changes.
12 changes: 9 additions & 3 deletions torchhd/embeddings.py
@@ -32,7 +32,7 @@
import torchhd.functional as functional
from torchhd.tensors.base import VSATensor
from torchhd.tensors.map import MAPTensor
from torchhd.tensors.fhrr import FHRRTensor
from torchhd.tensors.fhrr import FHRRTensor, type_conversion as fhrr_type_conversion
from torchhd.tensors.hrr import HRRTensor
from torchhd.types import VSAOptions

@@ -1017,7 +1017,6 @@ def __init__(
dtype=None,
requires_grad: bool = False,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super(FractionalPower, self).__init__()

self.in_features = in_features # data dimensions
@@ -1032,9 +1031,16 @@

self.vsa_tensor = functional.get_vsa_tensor_class(vsa)

if dtype not in self.vsa_tensor.supported_dtypes:
# If a specific dtype is specified, make sure it is supported by the VSA model
if dtype != None and dtype not in self.vsa_tensor.supported_dtypes:
raise ValueError(f"dtype {dtype} not supported by {vsa}")

# The internal weights/phases are stored as floats even if the output is a complex tensor
if dtype != None and vsa == "FHRR":
dtype = fhrr_type_conversion[dtype]

factory_kwargs = {"device": device, "dtype": dtype}

# If the distribution is a string use the presets in predefined_kernels
if isinstance(distribution, str):
try:
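For context, here is a minimal sketch of the validation flow the hunk above introduces. The mapping inside `fhrr_type_conversion` is an assumption inferred from the diff and the tests (complex output dtype → float phase dtype); it is not copied from the library source.

```python
import torch

# Assumed contents of torchhd.tensors.fhrr.type_conversion: the patch converts a
# requested complex FHRR output dtype into the float dtype used for the stored phases.
fhrr_type_conversion = {
    torch.complex64: torch.float32,
    torch.complex128: torch.float64,
}

def resolve_dtype(vsa: str, dtype, supported_dtypes):
    # Mirror of the patched check: None passes through untouched (the case the old
    # code rejected), unsupported dtypes raise, and FHRR dtypes map to their float
    # counterpart for the internal phase weights.
    if dtype is not None and dtype not in supported_dtypes:
        raise ValueError(f"dtype {dtype} not supported by {vsa}")
    if dtype is not None and vsa == "FHRR":
        dtype = fhrr_type_conversion[dtype]
    return dtype

fhrr_supported = {torch.complex64, torch.complex128}
assert resolve_dtype("FHRR", None, fhrr_supported) is None  # no longer rejected
assert resolve_dtype("FHRR", torch.complex64, fhrr_supported) == torch.float32
```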
110 changes: 110 additions & 0 deletions torchhd/tests/test_embeddings.py
@@ -23,11 +23,13 @@
#
import pytest
import torch
import math

import torchhd
from torchhd import functional
from torchhd import embeddings
from torchhd.tensors.hrr import HRRTensor
from torchhd.tensors.fhrr import type_conversion as fhrr_type_conversion


from .utils import (
@@ -540,3 +542,111 @@ def test_value(self, vsa):
)
> 0.99
)


class TestFractionalPower:
@pytest.mark.parametrize("vsa", vsa_tensors)
def test_default_dtype(self, vsa):
dimensions = 1000
embedding = 10

if vsa not in {"HRR", "FHRR"}:
with pytest.raises(ValueError):
embeddings.FractionalPower(embedding, dimensions, vsa=vsa)

return

emb = embeddings.FractionalPower(embedding, dimensions, vsa=vsa)
x = torch.randn(2, embedding)
y = emb(x)
assert y.shape == (2, dimensions)

if vsa == "HRR":
assert y.dtype == torch.float32
elif vsa == "FHRR":
assert y.dtype == torch.complex64
else:
return

@pytest.mark.parametrize("dtype", torch_dtypes)
def test_dtype(self, dtype):
dimensions = 1456
embedding = 2

if dtype not in {torch.float32, torch.float64}:
with pytest.raises(ValueError):
embeddings.FractionalPower(
embedding, dimensions, vsa="HRR", dtype=dtype
)
else:
emb = embeddings.FractionalPower(
embedding, dimensions, vsa="HRR", dtype=dtype
)

x = torch.randn(13, embedding, dtype=dtype)
y = emb(x)
assert y.shape == (13, dimensions)
assert y.dtype == dtype

if dtype not in {torch.complex64, torch.complex128}:
with pytest.raises(ValueError):
embeddings.FractionalPower(
embedding, dimensions, vsa="FHRR", dtype=dtype
)
else:
emb = embeddings.FractionalPower(
embedding, dimensions, vsa="FHRR", dtype=dtype
)

x = torch.randn(13, embedding, dtype=fhrr_type_conversion[dtype])
y = emb(x)
assert y.shape == (13, dimensions)
assert y.dtype == dtype

def test_device(self):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

emb = embeddings.FractionalPower(35, 1000, "gaussian", device=device)

x = torchhd.random(5, 35, device=device)
y = emb(x)
assert y.shape == (5, 1000)
assert y.device.type == device.type

def test_custom_dist_iid(self):
kernel_shape = torch.distributions.Normal(0, 1)
band = 3.0

emb = embeddings.FractionalPower(3, 1000, kernel_shape, band)
x = torch.randn(1, 3)
y = emb(x)
assert y.shape == (1, 1000)

def test_custom_dist_2d(self):
# Phase distribution for periodic Sinc kernel
class HexDisc(torch.distributions.Categorical):
def __init__(self):
super().__init__(torch.ones(6))
self.r = 1
self.side = self.r * math.sqrt(3) / 2
self.phases = torch.tensor(
[
[-self.r, 0.0],
[-self.r / 2, self.side],
[self.r / 2, self.side],
[self.r, 0.0],
[self.r / 2, -self.side],
[-self.r / 2, -self.side],
]
)

def sample(self, sample_shape=torch.Size()):
return self.phases[super().sample(sample_shape), :]

kernel_shape = HexDisc()
band = 3.0

emb = embeddings.FractionalPower(2, 1000, kernel_shape, band)
x = torch.randn(5, 2)
y = emb(x)
assert y.shape == (5, 1000)
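As a quick usage check of the fixed behaviour, the sketch below follows the call pattern of the tests above; the expected output dtype assumes FHRR's default of `torch.complex64`, as asserted in `test_default_dtype`.

```python
import torch
from torchhd import embeddings

# Omitting dtype previously tripped the supported-dtype check because None is not a
# supported dtype; with this patch the default is accepted and FHRR output stays complex.
emb = embeddings.FractionalPower(6, 1024, vsa="FHRR")
y = emb(torch.randn(4, 6))
print(y.shape, y.dtype)  # expected: torch.Size([4, 1024]) torch.complex64
```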