-
Notifications
You must be signed in to change notification settings - Fork 14
/
Copy pathimplementation.py
200 lines (160 loc) · 6 KB
/
implementation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
"""Module with Trie implementations."""
from typing import Dict, List, Optional, Set, Tuple

import pydot
from typing_extensions import Self

from data_structures.stack.implementation import ArrayBasedStack
class TrieNode:
    """Class that represents a node in the Trie.

    Attributes:
        alphabet (List[str]): The alphabet given at construction; it is only
            handed on to the children this node creates.
        symbol (str): The symbol this node was created for (the Trie root
            uses "^").
        final (bool): Whether a stored word ends at this node.
        edges (Dict[str, TrieNode]): The children of this node, keyed by symbol.
        cnt (int): Occurrence counter for the word ending at this node.
    """

    def __init__(self, alphabet: List[str], symbol: str) -> None:
        self.alphabet = alphabet
        self.symbol = symbol
        self.final = False
        self.edges: Dict[str, TrieNode] = {}
        self.cnt = 0

    def add_edge(self, symbol: str) -> "TrieNode":
        """Adds an edge between the node and a new node
        with the symbol supplied.

        Any child already stored under `symbol` is replaced, together with
        its whole sub-trie.

        Args:
            symbol (str): The symbol of the new node.

        Returns:
            TrieNode: The new node (always a plain TrieNode sharing this
            node's alphabet).
        """
        child = TrieNode(self.alphabet, symbol)
        self.edges[symbol] = child
        return child

    def get_node(self, symbol: str) -> Optional["TrieNode"]:
        """Returns the adjacent node with the symbol supplied.

        Args:
            symbol (str): The symbol to look for the node.

        Returns:
            Optional[TrieNode]: The adjacent node, or None if no child is
            stored under `symbol`.
        """
        return self.edges.get(symbol)
class Trie:
    """Class that represents a Trie that counts how often each word is stored.

    Every word is a path of TrieNode objects hanging from `self.root` (whose
    symbol is "^"). The node reached by the last symbol of a word is marked
    `final`, and its `cnt` holds how many occurrences of that word are
    currently stored.
    """

    def __init__(self, alphabet: List[str]) -> None:
        self.alphabet = alphabet
        self.root = TrieNode(alphabet, "^")

    def _find_node(self, word: str) -> Optional[TrieNode]:
        """Follows the path spelled by `word` starting at the root.

        Args:
            word (str): The symbols to follow.

        Returns:
            Optional[TrieNode]: The node reached after the last symbol (the
            root itself for an empty word), or None if the path breaks.
        """
        cur_node = self.root
        for symbol in word:
            cur_node = cur_node.get_node(symbol)
            if cur_node is None:
                return None
        return cur_node

    def insert(self, word: str, times: int = 1) -> None:
        """Inserts a word in the Trie as many times as
        the `times` argument.

        Args:
            word (str): The word to be inserted.
            times (int, optional): How many occurrences to insert. Defaults to 1.
                Non-positive values are ignored, so no word is ever stored
                with a zero or negative count.
        """
        if times <= 0:
            return
        cur_node = self.root
        for symbol in word:
            next_node = cur_node.get_node(symbol)
            if next_node is None:
                next_node = cur_node.add_edge(symbol)
            cur_node = next_node
        cur_node.cnt += times
        cur_node.final = True

    def remove(self, word: str, times: int = 1) -> None:
        """Removes a number of occurrences of a word.

        The stored count never drops below zero. When it reaches zero the
        word is unmarked, and if its node has no children the now-unused
        branch is pruned. Words that are not stored, and non-positive
        `times`, leave the Trie unchanged.

        Args:
            word (str): The word to remove.
            times (int, optional): How many occurrences to remove. Defaults to 1.
        """
        if times <= 0:
            # A negative value would otherwise *increase* the count below.
            return
        trace = ArrayBasedStack()
        cur_node = self.root
        for symbol in word:
            next_node = cur_node.get_node(symbol)
            if next_node is None:
                return
            trace.push([cur_node, symbol])
            cur_node = next_node
        if not cur_node.final:
            return
        cur_node.cnt = max(0, cur_node.cnt - times)
        if cur_node.cnt == 0:
            cur_node.final = False
            # A node with children is still a prefix of other words; keep it.
            if len(cur_node.edges) == 0:
                self._clean_links(trace)

    def _clean_links(self, trace: ArrayBasedStack) -> None:
        """Removes the edges that exclusively lead to a word that was removed.

        Detaches nodes from the deepest one upward. It stops as soon as a
        parent still has other children or is itself the end of a stored
        word, because that parent and everything above it are still needed.

        Args:
            trace (ArrayBasedStack): The [parent node, symbol] pairs along the
                path of the deleted word, deepest pair on top.
        """
        while not trace.is_empty():
            parent_node, symbol = trace.pop()
            parent_node.edges.pop(symbol)
            # Without the `final` check, removing "abc" would also erase "a".
            if parent_node.edges or parent_node.final:
                break

    def remove_all(self, word: str) -> None:
        """Removes all occurrences of a word.

        Args:
            word (str): The word to remove.
        """
        self.remove(word, self.count_insertions(word))

    def search(self, word: str) -> bool:
        """Checks if the word supplied is present in the trie.

        Args:
            word (str): The word to look for.

        Returns:
            bool: True if the word is found, False otherwise.
        """
        return self.count_insertions(word) > 0

    def search_prefix(self, word: str) -> bool:
        """Checks if the word supplied is a prefix
        of some word in the Trie.

        Args:
            word (str): The word to check for prefix.

        Returns:
            bool: True if the path spelled by `word` exists in the Trie, i.e.
            `word` is a prefix of (or equal to) some stored word. The empty
            string always yields True.
        """
        return self._find_node(word) is not None

    def count_insertions(self, word: str) -> int:
        """Counts the number of times the word supplied
        has been inserted in the Trie.

        Args:
            word (str): The word to count insertions.

        Returns:
            int: The number of times `word` is currently stored (0 if absent).
        """
        node = self._find_node(word)
        return 0 if node is None else node.cnt

    def get_visual_representation(self) -> pydot.Dot:
        """Gets a visual representation of the trie seen
        as a graph.

        Word-ending nodes are drawn as double circles labelled
        "<symbol>: <count>". Every other node is a circle labelled with its
        symbol.

        Returns:
            pydot.Dot: A dot object describing the trie as a graph.
        """
        edges = self._get_edges(self.root)
        nodes = self._get_nodes(edges)
        # Include the root explicitly so an empty trie still renders "^".
        nodes.add(self.root)
        graph = pydot.Dot("trie", graph_type="digraph")
        for node in nodes:
            graph.add_node(
                pydot.Node(
                    name=self._node_name(node),
                    shape="doublecircle" if node.final else "circle",
                    label=f"{node.symbol}: {node.cnt}" if node.final else node.symbol,
                )
            )
        for parent, child in edges:
            graph.add_edge(pydot.Edge(self._node_name(parent), self._node_name(child)))
        return graph

    @staticmethod
    def _node_name(node: TrieNode) -> str:
        """Builds a plain alphanumeric DOT identifier that is unique per node.

        The default str(node) is a "<...>" repr, which the DOT grammar reads
        as an HTML-string ID rather than an ordinary name.
        """
        return f"n{id(node)}"

    def _get_edges(self, node: TrieNode) -> Set[Tuple[TrieNode, TrieNode]]:
        """Collects every (parent, child) edge of the sub-trie rooted at `node`.

        Walks the stored children directly, so symbols that were inserted
        without being listed in `self.alphabet` are included too. It uses an
        explicit stack, so long words do not hit Python's recursion limit.
        """
        edges: Set[Tuple[TrieNode, TrieNode]] = set()
        pending = ArrayBasedStack()
        pending.push(node)
        while not pending.is_empty():
            parent = pending.pop()
            for child in parent.edges.values():
                edges.add((parent, child))
                pending.push(child)
        return edges

    def _get_nodes(self, edges: Set[Tuple[TrieNode, TrieNode]]) -> Set[TrieNode]:
        """Returns every node that appears as an endpoint of `edges`."""
        return {endpoint for edge in edges for endpoint in edge}