Skip to content

Add tests for the models module #129

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions torchhd/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class Centroid(nn.Module):
out_features (int): Size of the output, typically the number of classes.
device (``torch.device``, optional): the desired device of the weights. Default: if ``None``, uses the current device for the default tensor type (see ``torch.set_default_tensor_type()``). ``device`` will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types.
dtype (``torch.dtype``, optional): the desired data type of the weights. Default: if ``None``, uses ``torch.get_default_dtype()``.
requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: ``False``.

Shape:
- Input: :math:`(*, d)` where :math:`*` means any number of
Expand All @@ -76,7 +77,12 @@ class Centroid(nn.Module):
weight: Tensor

def __init__(
self, in_features: int, out_features: int, device=None, dtype=None
self,
in_features: int,
out_features: int,
device=None,
dtype=None,
requires_grad=False,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super(Centroid, self).__init__()
Expand All @@ -85,7 +91,7 @@ def __init__(
self.out_features = out_features

weight = torch.empty((out_features, in_features), **factory_kwargs)
self.weight = Parameter(weight)
self.weight = Parameter(weight, requires_grad=requires_grad)
self.reset_parameters()

def reset_parameters(self) -> None:
Expand Down Expand Up @@ -161,6 +167,7 @@ class IntRVFL(nn.Module):
kappa (int, optional): Parameter of the clipping function limiting the range of values; used as the part of transforming input data.
device (``torch.device``, optional): the desired device of the weights. Default: if ``None``, uses the current device for the default tensor type (see ``torch.set_default_tensor_type()``). ``device`` will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types.
dtype (``torch.dtype``, optional): the desired data type of the weights. Default: if ``None``, uses ``torch.get_default_dtype()``.
requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: ``False``.

Shape:
- Input: :math:`(*, d)` where :math:`*` means any number of
Expand Down Expand Up @@ -189,6 +196,7 @@ def __init__(
kappa: Optional[int] = None,
device=None,
dtype=None,
requires_grad=False,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super(IntRVFL, self).__init__()
Expand All @@ -202,8 +210,12 @@ def __init__(
in_features, self.dimensions, **factory_kwargs
)

weight = torch.zeros((out_features, dimensions), **factory_kwargs)
self.weight = Parameter(weight)
weight = torch.empty((out_features, dimensions), **factory_kwargs)
self.weight = Parameter(weight, requires_grad=requires_grad)
self.reset_parameters()

def reset_parameters(self) -> None:
init.zeros_(self.weight)

def encode(self, x):
encodings = self.encoding(x)
Expand Down
111 changes: 111 additions & 0 deletions torchhd/tests/test_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
#
# MIT License
#
# Copyright (c) 2023 Mike Heddes, Igor Nunes, Pere Vergés, Denis Kleyko, and Danny Abraham
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
import pytest
import torch
import torch.nn.functional as F
import torchhd
from torchhd import models
from torchhd import MAPTensor

from .utils import (
torch_dtypes,
vsa_tensors,
supported_dtype,
)


class TestCentroid:
@pytest.mark.parametrize("dtype", torch_dtypes)
def test_initialization(self, dtype):
if dtype not in MAPTensor.supported_dtypes:
return

model = models.Centroid(1245, 12, dtype=dtype)
assert torch.allclose(model.weight, torch.zeros(12, 1245, dtype=dtype))
assert model.weight.dtype == dtype

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = models.Centroid(1245, 12, dtype=dtype, device=device)
assert torch.allclose(model.weight, torch.zeros(12, 1245, dtype=dtype))
assert model.weight.dtype == dtype
assert model.weight.device == device

def test_add(self):
samples = torch.randn(4, 12)
targets = torch.tensor([0, 1, 2, 2])

model = models.Centroid(12, 3)
model.add(samples, targets)

c = samples[:-1].clone()
c[-1] += samples[-1]

assert torch.allclose(model(samples), torchhd.cos(samples, c))
assert torch.allclose(model(samples, dot=True), torchhd.dot(samples, c))

model.normalize()
print(model(samples, dot=True))
print(torchhd.cos(samples, c))
assert torch.allclose(
model(samples, dot=True), torchhd.dot(samples, F.normalize(c))
)

def test_add_online(self):
samples = torch.randn(10, 12)
targets = torch.randint(0, 3, (10,))

model = models.Centroid(12, 3)
model.add_online(samples, targets)

logits = model(samples)
assert logits.shape == (10, 3)


class TestIntRVFL:
@pytest.mark.parametrize("dtype", torch_dtypes)
def test_initialization(self, dtype):
if dtype not in MAPTensor.supported_dtypes:
return

model = models.IntRVFL(5, 1245, 12, dtype=dtype)
assert torch.allclose(model.weight, torch.zeros(12, 1245, dtype=dtype))
assert model.weight.dtype == dtype

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = models.IntRVFL(5, 1245, 12, dtype=dtype, device=device)
assert torch.allclose(model.weight, torch.zeros(12, 1245, dtype=dtype))
assert model.weight.dtype == dtype
assert model.weight.device == device

def test_fit_ridge_regression(self):
samples = torch.randn(10, 12)
targets = torch.randint(0, 3, (10,))

model = models.IntRVFL(12, 1245, 3)
model.fit_ridge_regression(samples, targets)

logits = model(samples)
assert logits.shape == (10, 3)