Skip to content

Commit ed16736

Browse files
committed
first commit
0 parents  commit ed16736

File tree

3 files changed

+320
-0
lines changed

3 files changed

+320
-0
lines changed

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# python-data-structures

bst.py

+255
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,255 @@
1+
class Node:
2+
def __init__(self, val, parent):
3+
self.val = val
4+
self.leftChild = None
5+
self.rightChild = None
6+
self.parent = parent
7+
self.hasTraversedLeft = False
8+
self.hasTraversedRight = False
9+
10+
def __str__(self):
11+
return "Node(%s)" % self.val
12+
13+
def get(self):
14+
return self.val
15+
16+
def set(self, val):
17+
self.val = val
18+
19+
def getChildren(self):
20+
children = []
21+
if (self.leftChild != None):
22+
children.append(self.leftChild)
23+
if (self.rightChild != None):
24+
children.append(self.rightChild)
25+
return children
26+
27+
def isLeftChild(self):
28+
if self.parent is None: return False
29+
return (self.val <= self.parent.val)
30+
31+
def isRightChild(self):
32+
if self.parent is None: return False
33+
return (self.val > self.parent.val)
34+
35+
# Duplicates are kept as left child of currentNode
36+
class BST:
37+
def __init__(self):
38+
self.root = None
39+
40+
# def __str__(self):
41+
# ans = []
42+
# currentNode = self.root
43+
# while currentNode
44+
45+
def setRoot(self, val):
46+
self.root = Node(val, None)
47+
48+
def insert(self, val):
49+
if(self.root is None):
50+
self.setRoot(val)
51+
else:
52+
self.insertNode(self.root, val)
53+
54+
def insertNode(self, currentNode, val):
55+
if (val <= currentNode.val):
56+
if (currentNode.leftChild):
57+
self.insertNode(currentNode.leftChild, val)
58+
else:
59+
currentNode.leftChild = Node(val, currentNode)
60+
elif (val > currentNode.val):
61+
if (currentNode.rightChild):
62+
self.insertNode(currentNode.rightChild, val)
63+
else:
64+
currentNode.rightChild = Node(val, currentNode)
65+
66+
def remove(self, val, startNode=None):
67+
target = self.find(val, startNode)
68+
if target is not None:
69+
if target is self.root:
70+
if target.leftChild is None and target.rightChild is None:
71+
self.root = None
72+
elif target.leftChild is None and target.rightChild is not None:
73+
self.root = target.rightChild
74+
elif target.leftChild is not None and target.rightChild is None:
75+
self.root = target.leftChild
76+
elif target.leftChild is not None and target.rightChild is not None:
77+
rightMin = self.min(target.rightChild)
78+
target.val = rightMin
79+
self.remove(rightMin, target.rightChild)
80+
else:
81+
return False
82+
return True
83+
else:
84+
if target.leftChild is None and target.rightChild is None:
85+
if target.parent.val > target.val: target.parent.leftChild = None
86+
elif target.parent.val == target.val: target.parent.rightChild = None
87+
elif target.parent.val < target.val: target.parent.rightChild = None
88+
elif target.leftChild is None and target.rightChild is not None:
89+
if target.parent.val > target.val: target.parent.leftChild = target.rightChild
90+
elif target.parent.val == target.val: target.parent.rightChild = target.rightChild
91+
elif target.parent.val < target.val: target.parent.rightChild = target.rightChild
92+
target.rightChild.parent = target.parent
93+
elif target.leftChild is not None and target.rightChild is None:
94+
if target.parent.val > target.val: target.parent.leftChild = target.leftChild
95+
elif target.parent.val == target.val: target.parent.rightChild = target.leftChild
96+
elif target.parent.val < target.val: target.parent.rightChild = target.leftChild
97+
target.leftChild.parent = target.parent
98+
elif target.rightChild is not None and target.rightChild is not None:
99+
rightMin = self.min(target.rightChild)
100+
target.val = rightMin
101+
self.remove(rightMin, target.rightChild)
102+
else:
103+
# Funny edge case
104+
return False
105+
return True
106+
else:
107+
return False
108+
109+
110+
def find(self, val, startNode=None):
111+
# Returns node if found, else None
112+
return self.findNode(self.root if startNode is None else startNode, val)
113+
114+
def getPredecessor(self, val):
115+
n = self.find(val)
116+
if n is not None:
117+
# Case 1: Node has a left subtree
118+
if n.leftChild is not None:
119+
return self.max(n.leftChild, returnVal=False)
120+
else:
121+
# Case 2: Node has no left subtree, it is the left child of its parent
122+
predecessor = n
123+
while predecessor.isLeftChild():
124+
predecessor = predecessor.parent
125+
# Case 3: Node has no left subtree, it is the right child of its parent
126+
return predecessor.parent
127+
else:
128+
return None
129+
130+
# Return smallest Node greater than val
131+
def getSuccessor(self, val):
132+
n = self.find(val)
133+
if n is not None:
134+
# Case 1: Node has a right subtree
135+
if n.rightChild is not None:
136+
return self.min(n.rightChild, returnVal=False)
137+
else:
138+
# Case 2: Node has no right subtree, it is the right child of its parent
139+
successor = n
140+
while successor.isRightChild():
141+
successor = successor.parent
142+
# Case 3: Node has no right subtree, it is left child of its parent
143+
return successor.parent
144+
else:
145+
return None
146+
147+
def findNode(self, currentNode, val):
148+
if(currentNode is None):
149+
return None
150+
elif(val == currentNode.val):
151+
return currentNode
152+
elif(val < currentNode.val):
153+
return self.findNode(currentNode.leftChild, val)
154+
else:
155+
return self.findNode(currentNode.rightChild, val)
156+
157+
def min(self, node=None, returnVal=True):
158+
if self.root is None: return None
159+
currentNode = self.root if node is None else node
160+
while currentNode.leftChild is not None:
161+
currentNode = currentNode.leftChild
162+
return currentNode.val if returnVal else currentNode
163+
164+
def max(self, node=None, returnVal=True):
165+
if self.root is None: return None
166+
currentNode = self.root if node is None else node
167+
while currentNode.rightChild is not None:
168+
currentNode = currentNode.rightChild
169+
return currentNode.val if returnVal else currentNode
170+
171+
# Iterative in-order traversal of nodes
172+
def sort(self, reverse=False):
173+
# Returns tree in ascending order
174+
import copy
175+
currentNode = copy.deepcopy(self.root)
176+
while True:
177+
if currentNode.hasTraversedLeft and currentNode.hasTraversedRight and currentNode.parent is None:
178+
break
179+
else:
180+
if not currentNode.hasTraversedLeft:
181+
currentNode.hasTraversedLeft = True
182+
if currentNode.leftChild is not None: currentNode = currentNode.leftChild
183+
else:
184+
if not currentNode.hasTraversedRight:
185+
yield currentNode.val
186+
currentNode.hasTraversedRight = True
187+
if currentNode.rightChild is not None: currentNode = currentNode.rightChild
188+
else:
189+
currentNode = currentNode.parent
190+
191+
192+
def test():
193+
data = [10,4,5,2,3,8,9,9]
194+
t = BST()
195+
for i in data:
196+
t.insert(i)
197+
198+
# Test sort
199+
sorted = []
200+
for i in t.sort():
201+
sorted.append(i)
202+
assert(sorted == [2,3,4,5,8,9,9,10])
203+
204+
# Test min max
205+
assert(t.min() == 2)
206+
assert(t.max() == 10)
207+
t.insert(11)
208+
assert(t.max() == 11)
209+
t.insert(-10)
210+
assert(t.min() == -10)
211+
212+
# Test find
213+
assert(t.find(10).val == 10)
214+
215+
# Test remove
216+
assert(t.remove(9) == True)
217+
assert(t.remove(10) == True)
218+
assert(t.remove(6) == False)
219+
sorted = []
220+
for i in t.sort():
221+
sorted.append(i)
222+
assert(sorted == [-10, 2, 3, 4, 5, 8, 9, 11])
223+
224+
# Current state of tree
225+
# 11
226+
# /
227+
# 4
228+
# / \
229+
# 2 5
230+
# / \ \
231+
#-10 3 8
232+
# \
233+
# 9
234+
235+
# Test getSuccessor
236+
assert(t.getSuccessor(4).val == 5)
237+
assert(t.getSuccessor(2).val == 3)
238+
assert(t.getSuccessor(3).val == 4)
239+
assert(t.getSuccessor(9).val == 11)
240+
assert(t.getSuccessor(-10).val == 2)
241+
assert(t.getSuccessor(15) == None)
242+
assert(t.getSuccessor(11) == None)
243+
244+
# Test getPredecessor
245+
assert(t.getPredecessor(4).val == 3)
246+
assert(t.getPredecessor(2).val == -10)
247+
assert(t.getPredecessor(3).val == 2)
248+
assert(t.getPredecessor(5).val == 4)
249+
assert(t.getPredecessor(9).val == 8)
250+
assert(t.getPredecessor(-10) == None)
251+
assert(t.getPredecessor(11).val == 9)
252+
assert(t.getPredecessor(15) == None)
253+
254+
if __name__ == "__main__":
255+
test()

fenwick_tree.py

+64
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# A 1-indexed Bit Indexed Tree or Fenwick Tree
2+
class FenwickTree:
3+
def __init__(self, a):
4+
for (idx,v) in enumerate(a):
5+
nextIdx = (idx+1) + lowestSetBit(idx+1) - 1
6+
if nextIdx < len(a):
7+
a[nextIdx] += v
8+
self.tree = a
9+
10+
def __str__(self):
11+
return str(self.tree)
12+
13+
# Find the sum of elements from start of array to end
14+
def prefixSum(self, end):
15+
ans = 0
16+
while end > 0:
17+
ans += self.tree[end-1]
18+
end -= lowestSetBit(end)
19+
return ans
20+
21+
# Find the sum of elements from start to end inclusive
22+
def range(self, start, end):
23+
if start < 1:
24+
raise Exception("Starting index must be at least 1.")
25+
if end < start:
26+
raise Exception("End index must be greater than or equal to start")
27+
return self.prefixSum(end) - self.prefixSum(start-1)
28+
29+
# idx is based on a 1-based array
30+
def update(self, idx, value):
31+
if idx < 1:
32+
raise Exception("Invalid index, must be at least 1")
33+
delta = value - self.tree[idx-1]
34+
while idx <= len(self.tree):
35+
self.tree[idx-1] += delta
36+
idx += lowestSetBit(idx)
37+
38+
def lowestSetBit(intType):
39+
return (intType & -intType)
40+
41+
def test():
42+
import copy
43+
a = [1,2,3,4,5,6,7,8,9,10]
44+
tree = FenwickTree(copy.deepcopy(a))
45+
assert(tree.tree == [1,3,3,10,5,11,7,36,9,19])
46+
# print(tree)
47+
48+
assert(tree.prefixSum(1) == 1)
49+
assert(tree.prefixSum(10) == 55)
50+
assert(tree.prefixSum(0) == 0)
51+
assert(tree.prefixSum(-1) == 0)
52+
53+
assert(tree.range(1,4) == 10)
54+
assert(tree.range(2,5) == 14)
55+
assert(tree.range(5, 10) == 45)
56+
57+
tree.update(1, 10)
58+
assert(tree.tree == [10,12,3,19,5,11,7,45,9,19])
59+
assert(tree.range(1,4) == 19)
60+
assert(tree.range(2,5) == 14)
61+
assert(tree.range(5, 10) == 45)
62+
63+
if __name__ == "__main__":
64+
test()

0 commit comments

Comments
 (0)