@@ -10,9 +10,20 @@ def __init__(self, val, parent):
10
10
def __str__ (self ):
11
11
return "Node(%s)" % self .val
12
12
13
+ def __eq__ (self , other ):
14
+ return (
15
+ self .__class__ == other .__class__ and
16
+ self .val == other .val
17
+ )
18
+
19
+ def __ne__ (self , other ):
20
+ return not self .__eq__ (other )
21
+
13
22
def get (self ):
14
23
return self .val
15
24
25
+ # Note: when a node is used in a BST, a node's val is immutable
26
+ # DO NOT change a node's value using n.set(newVal) or directly by n.val = newVal
16
27
def set (self , val ):
17
28
self .val = val
18
29
@@ -37,6 +48,31 @@ class BST:
37
48
def __init__ (self ):
38
49
self .root = None
39
50
51
+ def __eq__ (self , other ):
52
+ currentSelfNode = self .root
53
+ currentOtherNode = other .root
54
+ if currentSelfNode is not None and currentOtherNode is not None :
55
+ selfStack = [currentSelfNode ]
56
+ otherStack = [currentOtherNode ]
57
+ while selfStack and otherStack :
58
+ currentSelfNode = selfStack .pop ()
59
+ currentOtherNode = otherStack .pop ()
60
+ if currentSelfNode != currentOtherNode :
61
+ return False
62
+ else :
63
+ # Depth-first in-order traversal
64
+ if currentSelfNode .rightChild is not None : selfStack .append (currentSelfNode .rightChild )
65
+ if currentOtherNode .rightChild is not None : otherStack .append (currentOtherNode .rightChild )
66
+ if currentSelfNode .leftChild is not None : selfStack .append (currentSelfNode .leftChild )
67
+ if currentOtherNode .leftChild is not None : otherStack .append (currentOtherNode .leftChild )
68
+ return True
69
+ elif currentSelfNode != currentOtherNode :
70
+ return False
71
+ else :
72
+ # Both roots are None
73
+ return True
74
+
75
+ # TODO: a string represetation of a BST
40
76
# def __str__(self):
41
77
# ans = []
42
78
# currentNode = self.root
@@ -188,6 +224,66 @@ def sort(self, reverse=False):
188
224
else :
189
225
currentNode = currentNode .parent
190
226
227
+ def leftRotate (self , val ):
228
+ # Let x be node you are rotating, y be x.rightChild
229
+ # left-rotate x would place
230
+ # y.leftChild under x.rightChild
231
+ # parent of y = parent of x
232
+ # y.leftChild = x
233
+ # x y
234
+ # / \ / \
235
+ # A y => x C
236
+ # / \ / \
237
+ # B C A B
238
+ x = self .find (val )
239
+ if x is not None :
240
+ y = x .rightChild
241
+ if y is None : return # invalid operation
242
+ parentOfX = x .parent
243
+ # Change parent of x
244
+ if parentOfX is not None :
245
+ y .parent = parentOfX
246
+ if x .isLeftChild ():
247
+ parentOfX .leftChild = y
248
+ else :
249
+ parentOfX .rightChild = y
250
+
251
+ # Change right child of x
252
+ x .rightChild = y .leftChild
253
+ if y .leftChild is not None : y .leftChild .parent = x
254
+
255
+ # Change left child of y
256
+ y .leftChild = x
257
+ x .parent = y
258
+
259
+ def rightRotate (self , val ):
260
+ # Opposite of leftRotate
261
+ # y x
262
+ # / \ / \
263
+ # x C => A y
264
+ # / \ / \
265
+ # A B B C
266
+ #
267
+ y = self .find (val )
268
+ if y is not None :
269
+ x = y .leftChild
270
+ if x is None : return # invalid operation
271
+ parentOfY = y .parent
272
+ # Change parent of y
273
+ if parentOfY is not None :
274
+ x .parent = parentOfY
275
+ if y .isLeftChild ():
276
+ parentOfY .leftChild = x
277
+ else :
278
+ parentOfY .rightChild = x
279
+
280
+ # Change left child of y
281
+ y .leftChild = x .rightChild
282
+ if x .rightChild is not None : x .rightChild .parent = y
283
+
284
+ # Change rightChild of x
285
+ x .rightChild = y
286
+ y .parent = x
191
287
192
288
def test ():
193
289
data = [10 ,4 ,5 ,2 ,3 ,8 ,9 ,9 ]
@@ -251,5 +347,44 @@ def test():
251
347
assert (t .getPredecessor (11 ).val == 9 )
252
348
assert (t .getPredecessor (15 ) == None )
253
349
350
+ # Test leftRotate
351
+ t .leftRotate (4 )
352
+ assert (t .find (11 ).leftChild .val == 5 )
353
+ assert (t .find (5 ).parent .val == 11 )
354
+ assert (t .find (4 ).parent .val == 5 )
355
+ assert (t .find (5 ).leftChild .val == 4 )
356
+ assert (t .find (4 ).leftChild .val == 2 )
357
+ assert (t .find (2 ).parent .val == 4 )
358
+ assert (t .find (4 ).rightChild == None )
359
+ assert (t .find (5 ).rightChild .val == 8 )
360
+ assert (t .find (8 ).parent .val == 5 )
361
+
362
+ # Test rightRotate
363
+ t .rightRotate (5 )
364
+ assert (t .find (11 ).leftChild .val == 4 )
365
+ assert (t .find (4 ).parent .val == 11 )
366
+ assert (t .find (4 ).leftChild .val == 2 )
367
+ assert (t .find (2 ).parent .val == 4 )
368
+ assert (t .find (4 ).rightChild .val == 5 )
369
+ assert (t .find (5 ).parent .val == 4 )
370
+ assert (t .find (5 ).leftChild == None )
371
+ assert (t .find (5 ).rightChild .val == 8 )
372
+ assert (t .find (8 ).parent .val == 5 )
373
+
374
+ # Test that leftRotate and rightRotate are inverse of each other
375
+ import copy
376
+ newT = copy .deepcopy (t )
377
+ assert (newT == t )
378
+ newT .leftRotate (4 )
379
+ newT .rightRotate (5 )
380
+ assert (t == newT )
381
+ newT .leftRotate (5 )
382
+ newT .rightRotate (8 )
383
+ assert (t == newT )
384
+
385
+ # Test that invalid operations do not go through
386
+ newT .leftRotate (11 )
387
+ assert (t == newT )
388
+
254
389
if __name__ == "__main__" :
255
390
test ()
0 commit comments