44from pprint import pformat
55from typing import Iterable , Generic , TypeVar
66
7- from . comparable import Comparable
8- from . utils import (
7+ from comparable import Comparable
8+ from utils import (
99 is_sorted ,
1010 has_duplicates ,
1111)
12- from . exceptions import (
12+ from exceptions import (
1313 IntervalMapUnequalLength ,
1414 IntervalMapMustBeSorted ,
1515 IntervalMapNoDuplicates ,
@@ -66,6 +66,17 @@ def __init__(
6666
6767 self ._lpoints = list (copy .deepcopy (interval_left_points ))
6868 self ._vals = [copy .deepcopy (default_val )] + list (copy .deepcopy (vals ))
69+
70+ for i in range (len (self ._lpoints ) - 1 , - 1 , - 1 ):
71+ if self ._vals [i + 1 ] == self ._vals [i ]:
72+ self .__delete_by_index (i )
73+
74+ def __delete_by_index (self , ind : int ) -> bool :
75+ if ind < len (self ._lpoints ):
76+ del self ._lpoints [ind ]
77+ del self ._vals [ind + 1 ]
78+ return True
79+ return False
6980
7081 def __getitem__ (self , key : ComparableKey ) -> AnyValueType :
7182 return self .get (key )
@@ -101,7 +112,10 @@ def set(self, key: ComparableKey, val: AnyValueType) -> None:
101112 return
102113
103114 if self ._lpoints [ind ] == key :
104- self ._vals [ind + 1 ] = val
115+ if self ._vals [ind ] != val :
116+ self ._vals [ind + 1 ] = val
117+ else :
118+ self .__delete_by_index (ind )
105119 elif self ._vals [ind ] != val :
106120 self ._lpoints .insert (ind , key )
107121 self ._vals .insert (ind + 1 , val )
@@ -124,8 +138,15 @@ def unset(self, key: ComparableKey) -> bool:
124138 if ind >= len (self ._lpoints ):
125139 return False
126140 elif self ._lpoints [ind ] == key :
127- del self ._lpoints [ind ]
128- del self ._vals [ind + 1 ]
141+ self .__delete_by_index (ind )
142+
143+ if (
144+ ind < len (self ._vals ) - 1
145+ and
146+ self ._vals [ind ] == self ._vals [ind + 1 ]
147+ ):
148+ self .__delete_by_index (ind )
149+
129150 return True
130151 return False
131152
@@ -149,15 +170,15 @@ def slice_add(
149170 end_ind = bisect .bisect (self ._lpoints , end )
150171 val = self ._vals [end_ind ]
151172
152- start_ind = bisect .bisect (self ._lpoints , start )
153- self .set (start , self ._vals [start_ind ] + summand )
173+ start_ind = bisect .bisect_left (self ._lpoints , start )
174+ self .set (start , self ._vals [start_ind + 1 ] + summand )
154175
155176 if end is not None :
156177 end_ind = bisect .bisect_left (self ._lpoints , end )
157178 else :
158179 end_ind = len (self ._vals )
159180
160- for ind in range (start_ind + 2 , end_ind + 1 ):
181+ for ind in range (start_ind + 1 , end_ind + 1 ):
161182 self ._vals [ind ] += summand
162183
163184 if end is not None :
@@ -216,10 +237,11 @@ def __neg__(self) -> IntervalMap:
216237
217238 def __str__ (self ) -> str :
218239 return (
219- pformat (self .to_dict ())
240+ pformat (self .to_dict (), sort_dicts = False )
220241 .replace ('(' , '[' )
221242 .replace ('[None' , '(-inf' )
222243 .replace ('None]' , '+inf)' )
244+ .replace ('None)' , '+inf)' )
223245 )
224246
225247 def to_dict (
@@ -238,7 +260,10 @@ def __next__(
238260 ) -> tuple [tuple [ComparableKey , ComparableKey ], AnyValueType ]:
239261 if self .__iter == - 1 :
240262 result = (
241- (None , self ._lpoints [0 ]),
263+ (
264+ None ,
265+ self ._lpoints [0 ] if self .__len > - 1 else None
266+ ),
242267 self ._vals [0 ]
243268 )
244269 elif self .__iter == self .__len :
0 commit comments