
Add basic data structures #4

Merged
merged 15 commits on May 3, 2022
WIP: Add data structures
mikeheddes committed Apr 25, 2022
commit 51709cd5826f179868fd9fb183b748b9b129fc2c
3 changes: 2 additions & 1 deletion docs/_templates/class.rst
@@ -6,4 +6,5 @@
{{ name | underline}}

.. autoclass:: {{ name }}
    :members:
    :members:
    :special-members:
1 change: 1 addition & 0 deletions docs/index.rst
@@ -9,6 +9,7 @@ Welcome to hdc's documentation!

    functional
    embeddings
    structures
    datasets
    metrics

12 changes: 12 additions & 0 deletions docs/structures.rst
@@ -0,0 +1,12 @@
Structures
==================

.. currentmodule:: hdc.structures

.. autosummary::
    :toctree: generated/
    :template: class.rst

    Memory
    Set
    Sequence
2 changes: 2 additions & 0 deletions hdc/__init__.py
@@ -1,10 +1,12 @@
from . import functional
from . import structures
from . import embeddings
from . import metrics
from . import datasets

__all__ = [
"functional",
"structures",
"embeddings",
"metrics",
"datasets",
98 changes: 98 additions & 0 deletions hdc/structures.py
@@ -0,0 +1,98 @@
from typing import Any, List, Optional, Tuple
import torch
from . import functional


class Memory:
    """Associative memory"""

    def __init__(self, threshold=0.5):
        self.threshold = threshold
        self.keys: List[torch.Tensor] = []
        self.values: List[Any] = []

    def __len__(self) -> int:
        """Returns the number of items in memory"""
        return len(self.values)

    def add(self, key: torch.Tensor, value: Any) -> None:
        """Adds one (key, value) pair to memory"""
        self.keys.append(key)
        self.values.append(value)

    def _get_index(self, key: torch.Tensor) -> int:
        key_stack = torch.stack(self.keys, dim=0)
        sim = functional.cosine_similarity(key, key_stack)
        value, index = torch.max(sim, 0)

        if value.item() < self.threshold:
            raise IndexError("no stored key is similar enough to the query key")

        return int(index)

    def __getitem__(self, key: torch.Tensor) -> Tuple[torch.Tensor, Any]:
        """Get the (key, value) pair with an approximate key"""
        index = self._get_index(key)
        return self.keys[index], self.values[index]

    def __setitem__(self, key: torch.Tensor, value: Any) -> None:
        """Set the value of a (key, value) pair with an approximate key"""
        index = self._get_index(key)
        self.values[index] = value

    def __delitem__(self, key: torch.Tensor) -> None:
        """Delete the (key, value) pair with an approximate key"""
        index = self._get_index(key)
        del self.keys[index]
        del self.values[index]
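
# A minimal usage sketch for Memory (illustrative only; plain torch.randn
# tensors stand in for hypervectors). Lookups tolerate noisy keys because the
# stored key with the highest cosine similarity is returned.
def _example_memory_usage():
    dimensions = 10000
    keys = torch.randn(2, dimensions)  # two random hypervectors used as keys

    memory = Memory(threshold=0.5)
    memory.add(keys[0], "first")
    memory.add(keys[1], "second")

    # Query with a slightly corrupted copy of the first key.
    noisy_key = keys[0] + 0.1 * torch.randn(dimensions)
    stored_key, value = memory[noisy_key]  # value is expected to be "first"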


class Set:
    """Set of hypervectors, represented as a single bundled vector."""

    def __init__(self, dimensions, threshold=0.5, device=None, dtype=None):
        self.cardinality = 0
        self.threshold = threshold
        dtype = dtype if dtype is not None else torch.get_default_dtype()
        self.value = torch.zeros(dimensions, dtype=dtype, device=device)

    def add(self, input: torch.Tensor) -> None:
        if input in self:
            return

        self.value = functional.bundle(self.value, input)
        self.cardinality += 1

    def remove(self, input: torch.Tensor) -> None:
        if input not in self:
            return

        self.value = functional.bundle(self.value, -input)
        self.cardinality -= 1

    def __contains__(self, input: torch.Tensor) -> bool:
        sim = functional.cosine_similarity(input, self.value.unsqueeze(0))
        return sim.item() > self.threshold

    def __len__(self) -> int:
        return self.cardinality
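
# A minimal usage sketch for Set (illustrative only; random bipolar vectors
# stand in for hypervectors). Membership is approximate: an item counts as
# contained when its cosine similarity with the bundled set vector exceeds
# the threshold.
def _example_set_usage():
    dimensions = 10000
    a = torch.randn(dimensions).sign()
    b = torch.randn(dimensions).sign()

    hd_set = Set(dimensions)
    hd_set.add(a)
    hd_set.add(b)
    # a in hd_set -> True, len(hd_set) -> 2

    hd_set.remove(a)
    # a in hd_set -> False (with high probability), len(hd_set) -> 1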


class Sequence:
    """Sequence of hypervectors, represented as a single bundled vector."""

    def __init__(self, dimensions, threshold=0.5, device=None, dtype=None):
        self.length = 0
        self.threshold = threshold
        dtype = dtype if dtype is not None else torch.get_default_dtype()
        self.value = torch.zeros(dimensions, dtype=dtype, device=device)

    def append(self, input: torch.Tensor) -> None:
        rotated_input = functional.permute(input, shifts=self.length)
        self.value = functional.bundle(self.value, rotated_input)
        self.length += 1

    def pop(self, index: Optional[int] = None) -> Optional[torch.Tensor]:
        raise NotImplementedError()

    def __getitem__(self, index: int) -> torch.Tensor:
        rotated_value = functional.permute(self.value, shifts=-index)
        return rotated_value

    def __len__(self) -> int:
        return self.length
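
# A minimal usage sketch for Sequence (illustrative only; random bipolar
# vectors stand in for hypervectors). Each appended item is permuted by its
# position before bundling, so indexing recovers a noisy version of the item
# that is still close to the original in cosine similarity.
def _example_sequence_usage():
    dimensions = 10000
    first = torch.randn(dimensions).sign()
    second = torch.randn(dimensions).sign()

    seq = Sequence(dimensions)
    seq.append(first)
    seq.append(second)

    approx_first = seq[0]  # noisy reconstruction, similar to `first`
    # len(seq) -> 2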