Skip to content

Commit 4217d89

Browse files
mikeheddes and pverges
authored
Add basic data structures (#4)
* WIP: Add data structures * histogram, graph and ngram * Ngram, graph and sequence functions * Tree * Pull request small fixes * Multiset * Finite state automata * Fixing classmethod * Update sturctures docs * Fix pathspec error * Format Python code * Fix functional import * Format Python code Co-authored-by: verges <pverges8@gmail.com> Co-authored-by: formatting <mikeheddes@users.noreply.github.com>
1 parent 5b287b4 commit 4217d89

File tree

8 files changed

+237
-4
lines changed

8 files changed

+237
-4
lines changed

.github/workflows/format.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ jobs:
3333
git config --global user.name 'formatting'
3434
git config --global user.email 'mikeheddes@users.noreply.github.com'
3535
git remote set-url origin https://x-access-token:${{ secrets.GITHUB_TOKEN }}@github.com/$GITHUB_REPOSITORY
36+
git fetch
3637
git checkout $GITHUB_HEAD_REF
3738
git commit -am "Format Python code"
3839
git push

dev-requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,5 @@ numpy
66
sphinx
77
sphinx-rtd-theme
88
flake8
9-
pytest
9+
pytest
10+
black

docs/_templates/class.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,5 @@
66
{{ name | underline}}
77

88
.. autoclass:: {{ name }}
9-
:members:
9+
:members:
10+
:special-members:

docs/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ Welcome to the Torchhd documentation!
99

1010
functional
1111
embeddings
12+
structures
1213
datasets
1314

1415

docs/structures.rst

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
Structures
2+
==================
3+
4+
.. currentmodule:: torchhd.structures
5+
6+
.. autosummary::
7+
:toctree: generated/
8+
:template: class.rst
9+
10+
Memory
11+
Multiset
12+
Sequence
13+
Graph
14+
Tree
15+
FiniteStateAutomata

torchhd/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import torchhd.functional as functional
22
import torchhd.embeddings as embeddings
3+
import torchhd.structures as structures
34
import torchhd.datasets as datasets
45

56
__all__ = [
67
"functional",
78
"embeddings",
9+
"structures",
810
"datasets",
911
]

torchhd/functional.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
"circular_hv",
1313
"bind",
1414
"bundle",
15-
"batch_bundle",
15+
"multiset",
1616
"permute",
1717
"hard_quantize",
1818
"soft_quantize",
@@ -333,7 +333,7 @@ def bundle(input: torch.Tensor, other: torch.Tensor, *, out=None) -> torch.Tenso
333333
return torch.add(input, other, out=out)
334334

335335

336-
def batch_bundle(
336+
def multiset(
337337
input: torch.Tensor,
338338
*,
339339
dim=-2,
@@ -449,6 +449,21 @@ def hamming_similarity(input: torch.Tensor, others: torch.Tensor) -> torch.Tenso
449449
return torch.sum(input == others, dim=-1, dtype=input.dtype)
450450

451451

452+
def ngrams(input: torch.Tensor, n=3):
    """Return the n-gram hypervector of a sequence of hypervectors.

    Each gram binds together ``n`` consecutive, position-permuted
    hypervectors; all grams are then bundled into one multiset vector.

    Args:
        input: batch of hypervector sequences; grams are taken along
            dimension 1 (assumes a (batch, seq, dim) layout — TODO confirm
            with callers).
        n: number of consecutive hypervectors per gram.

    Returns:
        torch.Tensor: the bundled (multiset) n-gram hypervector.
    """
    n_gram = None
    for i in range(0, n):
        # Drop the trailing (n - i - 1) positions so the n slices align,
        # then permute by the position within the gram.
        if i == (n - 1):
            last_sample = None
        else:
            last_sample = -(n - i - 1)
        sample = permute(input[:, i:last_sample], shifts=n - i - 1)
        # Bug fix: the original tested `n is None` (never true since n
        # defaults to 3) and overwrote n_gram with bind(n, sample) each
        # iteration, binding the *int* n instead of accumulating samples.
        if n_gram is None:
            n_gram = sample
        else:
            n_gram = bind(n_gram, sample)

    return multiset(n_gram)
465+
466+
452467
def map_range(
453468
input: torch.Tensor,
454469
in_min: float,

torchhd/structures.py

Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
from typing import Any, List, Optional, Tuple
2+
import torch
3+
4+
import torchhd.functional as functional
5+
6+
7+
class Memory:
    """Associative memory of (key hypervector, value) pairs.

    Items are retrieved with an *approximate* key: the stored key with
    the highest cosine similarity wins, provided it clears ``threshold``.
    """

    def __init__(self, threshold=0.5):
        # Minimum cosine similarity for a lookup key to count as a match.
        self.threshold = threshold
        self.keys: List[torch.Tensor] = []
        self.values: List[Any] = []

    def __len__(self) -> int:
        """Return the number of stored (key, value) pairs."""
        return len(self.values)

    def add(self, key: torch.Tensor, value: Any) -> None:
        """Store one (key, value) pair."""
        self.keys.append(key)
        self.values.append(value)

    def _get_index(self, key: torch.Tensor) -> int:
        # Rank every stored key by cosine similarity against the query.
        stacked_keys = torch.stack(self.keys, dim=0)
        similarities = functional.cosine_similarity(key, stacked_keys)
        best_score, best_index = torch.max(similarities, 0)

        # No stored key is similar enough to count as a hit.
        if best_score.item() < self.threshold:
            raise IndexError()

        return best_index

    def __getitem__(self, key: torch.Tensor) -> Tuple[torch.Tensor, Any]:
        """Return the stored (key, value) pair matching an approximate key."""
        match = self._get_index(key)
        return self.keys[match], self.values[match]

    def __setitem__(self, key: torch.Tensor, value: Any) -> None:
        """Replace the value of the pair matching an approximate key."""
        match = self._get_index(key)
        self.values[match] = value

    def __delitem__(self, key: torch.Tensor) -> None:
        """Delete the pair matching an approximate key."""
        match = self._get_index(key)
        del self.keys[match]
        del self.values[match]
51+
class Multiset:
    """Hypervector multiset: elements are bundled into a single vector.

    Args:
        dimensions: width of the hypervectors.
        threshold: minimum cosine similarity for membership tests.
        device, dtype: placement of the initial zero value hypervector.
    """

    def __init__(self, dimensions, threshold=0.5, device=None, dtype=None):
        self.threshold = threshold
        # Number of elements added minus removed.
        self.cardinality = 0
        dtype = dtype if dtype is not None else torch.get_default_dtype()
        self.value = torch.zeros(dimensions, dtype=dtype, device=device)

    def add(self, input: torch.Tensor) -> None:
        """Bundle one element hypervector into the multiset."""
        self.value = functional.bundle(self.value, input)
        self.cardinality += 1

    def remove(self, input: torch.Tensor) -> None:
        """Remove one element, but only if it tests as a member."""
        if input not in self:
            return
        # Bundling the negation cancels one copy of the element.
        self.value = functional.bundle(self.value, -input)
        self.cardinality -= 1

    def __contains__(self, input: torch.Tensor):
        """Approximate membership test against the bundled value."""
        # Bug fix: the original read `self.values`, an attribute that does
        # not exist (the field is `self.value`), raising AttributeError on
        # every membership test (and thus on every remove()).
        sim = functional.cosine_similarity(input, self.value.unsqueeze(0))
        return sim.item() > self.threshold

    def __len__(self) -> int:
        """Return the multiset cardinality."""
        return self.cardinality

    @classmethod
    def from_ngrams(cls, input: torch.Tensor, n=3, threshold=0.5):
        """Build a Multiset from the n-grams of a sequence of hypervectors."""
        instance = cls(input.size(-1), threshold, input.device, input.dtype)
        instance.value = functional.ngrams(input, n)
        return instance

    @classmethod
    def from_tensors(cls, input: torch.Tensor, dim=-2, threshold=0.5):
        """Build a Multiset by bundling the hypervectors along ``dim``."""
        instance = cls(input.size(-1), threshold, input.device, input.dtype)
        instance.value = functional.multiset(input=input, dim=dim)
        return instance
88+
class Sequence:
    """Hypervector sequence: elements are position-encoded by permutation
    and superimposed by bundling. The most recently appended element sits
    at shift 0, the oldest at shift len(self) - 1.
    """

    def __init__(self, dimensions, threshold=0.5, device=None, dtype=None):
        self.length = 0
        self.threshold = threshold
        dtype = dtype if dtype is not None else torch.get_default_dtype()
        self.value = torch.zeros(dimensions, dtype=dtype, device=device)

    def append(self, input: torch.Tensor) -> None:
        """Append input at the right end of the sequence."""
        rotated_value = functional.permute(self.value, shifts=1)
        self.value = functional.bundle(input, rotated_value)
        # Bug fix: the original never updated length on append, so len()
        # (and the shifts derived from it) went out of sync with the
        # decrements done by pop()/popleft().
        self.length += 1

    def appendleft(self, input: torch.Tensor) -> None:
        """Append input at the left end of the sequence."""
        # The left end is the deepest position: shift by the current length.
        rotated_input = functional.permute(input, shifts=len(self))
        self.value = functional.bundle(self.value, rotated_input)
        # Bug fix: same missing length update as append().
        self.length += 1

    def pop(self, input: torch.Tensor) -> None:
        """Remove input from the right end; the caller supplies the element
        being removed so it can be unbundled.

        Annotation fix: the original declared Optional[torch.Tensor] but
        never returned a value.
        """
        self.value = functional.bundle(self.value, -input)
        self.value = functional.permute(self.value, shifts=-1)
        self.length -= 1

    def popleft(self, input: torch.Tensor) -> None:
        """Remove input from the left end of the sequence."""
        # Bug fix: with length now tracked on append, the oldest element
        # sits at shift len(self) - 1, not len(self) + 1.
        rotated_input = functional.permute(input, shifts=len(self) - 1)
        self.value = functional.bundle(self.value, -rotated_input)
        self.length -= 1

    def __getitem__(self, index: int) -> torch.Tensor:
        """Return the (noisy) element hypervector at position index."""
        return functional.permute(self.value, shifts=-index)

    def __len__(self) -> int:
        """Return the number of elements in the sequence."""
        return self.length
121+
class Graph:
    """Graph as a single hypervector: each edge is bound from its two node
    hypervectors and bundled into the graph value."""

    def __init__(
        self, dimensions, threshold=0.5, directed=False, device=None, dtype=None
    ):
        self.length = 0
        self.threshold = threshold
        self.dtype = dtype if dtype is not None else torch.get_default_dtype()
        # Consistency fix: use the resolved self.dtype rather than the raw
        # (possibly None) dtype argument, which left self.dtype unused.
        self.value = torch.zeros(dimensions, dtype=self.dtype, device=device)
        self.directed = directed

    def _encode_edge(self, node1: torch.Tensor, node2: torch.Tensor) -> torch.Tensor:
        # Encode one edge. Bug fix: the original branches were swapped —
        # bind is commutative, so bind(a, b) == bind(b, a) cannot carry a
        # direction; the permutation of node2 is what makes the edge
        # asymmetric and is therefore required in the DIRECTED case.
        if self.directed:
            return functional.bind(node1, functional.permute(node2))
        return functional.bind(node1, node2)

    def add_edge(self, node1: torch.Tensor, node2: torch.Tensor):
        """Add the edge (node1, node2) to the graph."""
        self.value = functional.bundle(self.value, self._encode_edge(node1, node2))

    def edge_exists(self, node1: torch.Tensor, node2: torch.Tensor):
        """Approximately test whether the edge (node1, node2) was added."""
        return self._encode_edge(node1, node2) in self

    def node_neighbours(self, input: torch.Tensor):
        """Unbind input from the graph value, yielding a noisy superposition
        of input's neighbour hypervectors.

        NOTE(review): for a directed graph the result is permuted relative
        to the stored node hypervectors; confirm whether callers are
        expected to un-permute it.
        """
        return functional.bind(self.value, input)

    def __contains__(self, input: torch.Tensor):
        """Approximate membership test for an encoded edge hypervector."""
        sim = functional.cosine_similarity(input, self.value.unsqueeze(0))
        return sim.item() > self.threshold
153+
class Tree:
    """Binary tree as a hypervector: each leaf value is bound with one role
    hypervector per step of its root-to-leaf path and bundled in."""

    def __init__(self, dimensions, device=None, dtype=None):
        self.dtype = dtype if dtype is not None else torch.get_default_dtype()
        # Consistency fix: use the resolved self.dtype rather than the raw
        # (possibly None) dtype argument, which left self.dtype unused.
        self.value = torch.zeros(dimensions, dtype=self.dtype, device=device)
        # Role hypervectors for the left/right branch directions.
        # NOTE(review): random_hv is not given device/dtype, so l_r may not
        # match self.value's placement — confirm against functional.random_hv.
        self.l_r = functional.random_hv(2, dimensions)

    def add_leaf(self, value, path):
        """Add a leaf value at the position described by ``path``, an
        iterable of "l"/"r" directions (anything other than "l" is treated
        as a right branch).

        NOTE(review): bind is commutative and no per-depth permutation is
        applied, so paths with the same multiset of directions (e.g. "lr"
        and "rl") encode identically — confirm whether that is intended.
        """
        for direction in path:
            role = self.left if direction == "l" else self.right
            value = functional.bind(value, role)
        self.value = functional.bundle(self.value, value)

    @property
    def left(self):
        """Role hypervector for a left branch."""
        return self.l_r[0]

    @property
    def right(self):
        """Role hypervector for a right branch."""
        return self.l_r[1]
176+
class FiniteStateAutomata:
    """Finite state automaton as a single hypervector of bundled
    (token, initial_state, final_state) transitions."""

    def __init__(self, dimensions, device=None, dtype=None):
        self.dtype = dtype if dtype is not None else torch.get_default_dtype()
        # Consistency fix: use the resolved self.dtype rather than the raw
        # (possibly None) dtype argument, which left self.dtype unused.
        self.value = torch.zeros(dimensions, dtype=self.dtype, device=device)

    def add_transition(
        self,
        token: torch.Tensor,
        initial_state: torch.Tensor,
        final_state: torch.Tensor,
    ):
        """Add the transition initial_state --token--> final_state."""
        # Permuting final_state breaks the symmetry of bind so the
        # direction of the transition is preserved.
        transition_edge = functional.bind(
            initial_state, functional.permute(final_state)
        )
        transition = functional.bind(token, transition_edge)
        self.value = functional.bundle(self.value, transition)

    def change_state(self, token: torch.Tensor, current_state: torch.Tensor):
        """Returns the next state + some noise."""
        # Unbind the current state and token, then undo the permutation
        # applied in add_transition to recover the (noisy) final state.
        next_state = functional.bind(self.value, current_state)
        next_state = functional.bind(next_state, token)
        return functional.permute(next_state, shifts=-1)

0 commit comments

Comments
 (0)