1
+ class Node :
2
+ def __init__ (self , val , parent ):
3
+ self .val = val
4
+ self .leftChild = None
5
+ self .rightChild = None
6
+ self .parent = parent
7
+ self .hasTraversedLeft = False
8
+ self .hasTraversedRight = False
9
+
10
+ def __str__ (self ):
11
+ return "Node(%s)" % self .val
12
+
13
+ def get (self ):
14
+ return self .val
15
+
16
+ def set (self , val ):
17
+ self .val = val
18
+
19
+ def getChildren (self ):
20
+ children = []
21
+ if (self .leftChild != None ):
22
+ children .append (self .leftChild )
23
+ if (self .rightChild != None ):
24
+ children .append (self .rightChild )
25
+ return children
26
+
27
+ def isLeftChild (self ):
28
+ if self .parent is None : return False
29
+ return (self .val <= self .parent .val )
30
+
31
+ def isRightChild (self ):
32
+ if self .parent is None : return False
33
+ return (self .val > self .parent .val )
34
+
35
+ # Duplicates are kept as left child of currentNode
36
+ class BST :
37
+ def __init__ (self ):
38
+ self .root = None
39
+
40
+ # def __str__(self):
41
+ # ans = []
42
+ # currentNode = self.root
43
+ # while currentNode
44
+
45
+ def setRoot (self , val ):
46
+ self .root = Node (val , None )
47
+
48
+ def insert (self , val ):
49
+ if (self .root is None ):
50
+ self .setRoot (val )
51
+ else :
52
+ self .insertNode (self .root , val )
53
+
54
+ def insertNode (self , currentNode , val ):
55
+ if (val <= currentNode .val ):
56
+ if (currentNode .leftChild ):
57
+ self .insertNode (currentNode .leftChild , val )
58
+ else :
59
+ currentNode .leftChild = Node (val , currentNode )
60
+ elif (val > currentNode .val ):
61
+ if (currentNode .rightChild ):
62
+ self .insertNode (currentNode .rightChild , val )
63
+ else :
64
+ currentNode .rightChild = Node (val , currentNode )
65
+
66
+ def remove (self , val , startNode = None ):
67
+ target = self .find (val , startNode )
68
+ if target is not None :
69
+ if target is self .root :
70
+ if target .leftChild is None and target .rightChild is None :
71
+ self .root = None
72
+ elif target .leftChild is None and target .rightChild is not None :
73
+ self .root = target .rightChild
74
+ elif target .leftChild is not None and target .rightChild is None :
75
+ self .root = target .leftChild
76
+ elif target .leftChild is not None and target .rightChild is not None :
77
+ rightMin = self .min (target .rightChild )
78
+ target .val = rightMin
79
+ self .remove (rightMin , target .rightChild )
80
+ else :
81
+ return False
82
+ return True
83
+ else :
84
+ if target .leftChild is None and target .rightChild is None :
85
+ if target .parent .val > target .val : target .parent .leftChild = None
86
+ elif target .parent .val == target .val : target .parent .rightChild = None
87
+ elif target .parent .val < target .val : target .parent .rightChild = None
88
+ elif target .leftChild is None and target .rightChild is not None :
89
+ if target .parent .val > target .val : target .parent .leftChild = target .rightChild
90
+ elif target .parent .val == target .val : target .parent .rightChild = target .rightChild
91
+ elif target .parent .val < target .val : target .parent .rightChild = target .rightChild
92
+ target .rightChild .parent = target .parent
93
+ elif target .leftChild is not None and target .rightChild is None :
94
+ if target .parent .val > target .val : target .parent .leftChild = target .leftChild
95
+ elif target .parent .val == target .val : target .parent .rightChild = target .leftChild
96
+ elif target .parent .val < target .val : target .parent .rightChild = target .leftChild
97
+ target .leftChild .parent = target .parent
98
+ elif target .rightChild is not None and target .rightChild is not None :
99
+ rightMin = self .min (target .rightChild )
100
+ target .val = rightMin
101
+ self .remove (rightMin , target .rightChild )
102
+ else :
103
+ # Funny edge case
104
+ return False
105
+ return True
106
+ else :
107
+ return False
108
+
109
+
110
+ def find (self , val , startNode = None ):
111
+ # Returns node if found, else None
112
+ return self .findNode (self .root if startNode is None else startNode , val )
113
+
114
+ def getPredecessor (self , val ):
115
+ n = self .find (val )
116
+ if n is not None :
117
+ # Case 1: Node has a left subtree
118
+ if n .leftChild is not None :
119
+ return self .max (n .leftChild , returnVal = False )
120
+ else :
121
+ # Case 2: Node has no left subtree, it is the left child of its parent
122
+ predecessor = n
123
+ while predecessor .isLeftChild ():
124
+ predecessor = predecessor .parent
125
+ # Case 3: Node has no left subtree, it is the right child of its parent
126
+ return predecessor .parent
127
+ else :
128
+ return None
129
+
130
+ # Return smallest Node greater than val
131
+ def getSuccessor (self , val ):
132
+ n = self .find (val )
133
+ if n is not None :
134
+ # Case 1: Node has a right subtree
135
+ if n .rightChild is not None :
136
+ return self .min (n .rightChild , returnVal = False )
137
+ else :
138
+ # Case 2: Node has no right subtree, it is the right child of its parent
139
+ successor = n
140
+ while successor .isRightChild ():
141
+ successor = successor .parent
142
+ # Case 3: Node has no right subtree, it is left child of its parent
143
+ return successor .parent
144
+ else :
145
+ return None
146
+
147
+ def findNode (self , currentNode , val ):
148
+ if (currentNode is None ):
149
+ return None
150
+ elif (val == currentNode .val ):
151
+ return currentNode
152
+ elif (val < currentNode .val ):
153
+ return self .findNode (currentNode .leftChild , val )
154
+ else :
155
+ return self .findNode (currentNode .rightChild , val )
156
+
157
+ def min (self , node = None , returnVal = True ):
158
+ if self .root is None : return None
159
+ currentNode = self .root if node is None else node
160
+ while currentNode .leftChild is not None :
161
+ currentNode = currentNode .leftChild
162
+ return currentNode .val if returnVal else currentNode
163
+
164
+ def max (self , node = None , returnVal = True ):
165
+ if self .root is None : return None
166
+ currentNode = self .root if node is None else node
167
+ while currentNode .rightChild is not None :
168
+ currentNode = currentNode .rightChild
169
+ return currentNode .val if returnVal else currentNode
170
+
171
+ # Iterative in-order traversal of nodes
172
+ def sort (self , reverse = False ):
173
+ # Returns tree in ascending order
174
+ import copy
175
+ currentNode = copy .deepcopy (self .root )
176
+ while True :
177
+ if currentNode .hasTraversedLeft and currentNode .hasTraversedRight and currentNode .parent is None :
178
+ break
179
+ else :
180
+ if not currentNode .hasTraversedLeft :
181
+ currentNode .hasTraversedLeft = True
182
+ if currentNode .leftChild is not None : currentNode = currentNode .leftChild
183
+ else :
184
+ if not currentNode .hasTraversedRight :
185
+ yield currentNode .val
186
+ currentNode .hasTraversedRight = True
187
+ if currentNode .rightChild is not None : currentNode = currentNode .rightChild
188
+ else :
189
+ currentNode = currentNode .parent
190
+
191
+
192
+ def test ():
193
+ data = [10 ,4 ,5 ,2 ,3 ,8 ,9 ,9 ]
194
+ t = BST ()
195
+ for i in data :
196
+ t .insert (i )
197
+
198
+ # Test sort
199
+ sorted = []
200
+ for i in t .sort ():
201
+ sorted .append (i )
202
+ assert (sorted == [2 ,3 ,4 ,5 ,8 ,9 ,9 ,10 ])
203
+
204
+ # Test min max
205
+ assert (t .min () == 2 )
206
+ assert (t .max () == 10 )
207
+ t .insert (11 )
208
+ assert (t .max () == 11 )
209
+ t .insert (- 10 )
210
+ assert (t .min () == - 10 )
211
+
212
+ # Test find
213
+ assert (t .find (10 ).val == 10 )
214
+
215
+ # Test remove
216
+ assert (t .remove (9 ) == True )
217
+ assert (t .remove (10 ) == True )
218
+ assert (t .remove (6 ) == False )
219
+ sorted = []
220
+ for i in t .sort ():
221
+ sorted .append (i )
222
+ assert (sorted == [- 10 , 2 , 3 , 4 , 5 , 8 , 9 , 11 ])
223
+
224
+ # Current state of tree
225
+ # 11
226
+ # /
227
+ # 4
228
+ # / \
229
+ # 2 5
230
+ # / \ \
231
+ #-10 3 8
232
+ # \
233
+ # 9
234
+
235
+ # Test getSuccessor
236
+ assert (t .getSuccessor (4 ).val == 5 )
237
+ assert (t .getSuccessor (2 ).val == 3 )
238
+ assert (t .getSuccessor (3 ).val == 4 )
239
+ assert (t .getSuccessor (9 ).val == 11 )
240
+ assert (t .getSuccessor (- 10 ).val == 2 )
241
+ assert (t .getSuccessor (15 ) == None )
242
+ assert (t .getSuccessor (11 ) == None )
243
+
244
+ # Test getPredecessor
245
+ assert (t .getPredecessor (4 ).val == 3 )
246
+ assert (t .getPredecessor (2 ).val == - 10 )
247
+ assert (t .getPredecessor (3 ).val == 2 )
248
+ assert (t .getPredecessor (5 ).val == 4 )
249
+ assert (t .getPredecessor (9 ).val == 8 )
250
+ assert (t .getPredecessor (- 10 ) == None )
251
+ assert (t .getPredecessor (11 ).val == 9 )
252
+ assert (t .getPredecessor (15 ) == None )
253
+
254
+ if __name__ == "__main__" :
255
+ test ()
0 commit comments