5
5
from collections import defaultdict
6
6
from math import sqrt
7
7
from operator import attrgetter
8
+ from typing import TYPE_CHECKING , Callable
8
9
9
10
import cloudpickle
10
11
import numpy as np
25
26
with_pandas = False
26
27
27
28
28
- def _downdate (c , nans , depth ) :
29
+ def _downdate (c : np . ndarray , nans : list [ int ] , depth : int ) -> np . ndarray :
29
30
# This is algorithm 5 from the thesis of Pedro Gonnet.
30
31
b = coeff .b_def [depth ].copy ()
31
32
m = coeff .ns [depth ] - 1
@@ -45,7 +46,7 @@ def _downdate(c, nans, depth):
45
46
return c
46
47
47
48
48
- def _zero_nans (fx ) :
49
+ def _zero_nans (fx : np . ndarray ) -> list [ int ] :
49
50
"""Caution: this function modifies fx."""
50
51
nans = []
51
52
for i in range (len (fx )):
@@ -55,7 +56,7 @@ def _zero_nans(fx):
55
56
return nans
56
57
57
58
58
- def _calc_coeffs (fx , depth ) :
59
+ def _calc_coeffs (fx : np . ndarray , depth : int ) -> np . ndarray :
59
60
"""Caution: this function modifies fx."""
60
61
nans = _zero_nans (fx )
61
62
c_new = coeff .V_inv [depth ] @ fx
@@ -135,27 +136,32 @@ class _Interval:
135
136
"removed" ,
136
137
]
137
138
138
- def __init__ (self , a , b , depth , rdepth ) :
139
- self .children = []
140
- self .data = {}
139
+ def __init__ (self , a : int | float , b : int | float , depth : int , rdepth : int ) -> None :
140
+ self .children : list [ _Interval ] = []
141
+ self .data : dict [ float , float ] = {}
141
142
self .a = a
142
143
self .b = b
143
144
self .depth = depth
144
145
self .rdepth = rdepth
145
- self .done_leaves = set ()
146
- self .depth_complete = None
146
+ self .done_leaves : set [ _Interval ] = set ()
147
+ self .depth_complete : int | None = None
147
148
self .removed = False
149
+ if TYPE_CHECKING :
150
+ self .ndiv : int
151
+ self .parent : _Interval | None
152
+ self .err : float
153
+ self .c : np .ndarray
148
154
149
155
@classmethod
150
- def make_first (cls , a , b , depth = 2 ) :
156
+ def make_first (cls , a : int , b : int , depth : int = 2 ) -> _Interval :
151
157
ival = _Interval (a , b , depth , rdepth = 1 )
152
158
ival .ndiv = 0
153
159
ival .parent = None
154
160
ival .err = sys .float_info .max # needed because inf/2 == inf
155
161
return ival
156
162
157
163
@property
158
- def T (self ):
164
+ def T (self ) -> np . ndarray :
159
165
"""Get the correct shift matrix.
160
166
161
167
Should only be called on children of a split interval.
@@ -166,24 +172,24 @@ def T(self):
166
172
assert left != right
167
173
return coeff .T_left if left else coeff .T_right
168
174
169
- def refinement_complete (self , depth ) :
175
+ def refinement_complete (self , depth : int ) -> bool :
170
176
"""The interval has all the y-values to calculate the intergral."""
171
177
if len (self .data ) < coeff .ns [depth ]:
172
178
return False
173
179
return all (p in self .data for p in self .points (depth ))
174
180
175
- def points (self , depth = None ):
181
+ def points (self , depth : int | None = None ) -> np . ndarray :
176
182
if depth is None :
177
183
depth = self .depth
178
184
a = self .a
179
185
b = self .b
180
186
return (a + b ) / 2 + (b - a ) * coeff .xi [depth ] / 2
181
187
182
- def refine (self ):
188
+ def refine (self ) -> _Interval :
183
189
self .depth += 1
184
190
return self
185
191
186
- def split (self ):
192
+ def split (self ) -> list [ _Interval ] :
187
193
points = self .points ()
188
194
m = points [len (points ) // 2 ]
189
195
ivals = [
@@ -198,10 +204,10 @@ def split(self):
198
204
199
205
return ivals
200
206
201
- def calc_igral (self ):
207
+ def calc_igral (self ) -> None :
202
208
self .igral = (self .b - self .a ) * self .c [0 ] / sqrt (2 )
203
209
204
- def update_heuristic_err (self , value ) :
210
+ def update_heuristic_err (self , value : float ) -> None :
205
211
"""Sets the error of an interval using a heuristic (half the error of
206
212
the parent) when the actual error cannot be calculated due to its
207
213
parents not being finished yet. This error is propagated down to its
@@ -214,7 +220,7 @@ def update_heuristic_err(self, value):
214
220
continue
215
221
child .update_heuristic_err (value / 2 )
216
222
217
- def calc_err (self , c_old ) :
223
+ def calc_err (self , c_old : np . ndarray ) -> float :
218
224
c_new = self .c
219
225
c_diff = np .zeros (max (len (c_old ), len (c_new )))
220
226
c_diff [: len (c_old )] = c_old
@@ -226,9 +232,9 @@ def calc_err(self, c_old):
226
232
child .update_heuristic_err (self .err / 2 )
227
233
return c_diff
228
234
229
- def calc_ndiv (self ):
235
+ def calc_ndiv (self ) -> None :
230
236
div = self .parent .c00 and self .c00 / self .parent .c00 > 2
231
- self .ndiv += div
237
+ self .ndiv += int ( div )
232
238
233
239
if self .ndiv > coeff .ndiv_max and 2 * self .ndiv > self .rdepth :
234
240
raise DivergentIntegralError
@@ -237,15 +243,15 @@ def calc_ndiv(self):
237
243
for child in self .children :
238
244
child .update_ndiv_recursively ()
239
245
240
- def update_ndiv_recursively (self ):
246
+ def update_ndiv_recursively (self ) -> None :
241
247
self .ndiv += 1
242
248
if self .ndiv > coeff .ndiv_max and 2 * self .ndiv > self .rdepth :
243
249
raise DivergentIntegralError
244
250
245
251
for child in self .children :
246
252
child .update_ndiv_recursively ()
247
253
248
- def complete_process (self , depth ) :
254
+ def complete_process (self , depth : int ) -> tuple [ bool , bool ] | tuple [ bool , np . bool_ ] :
249
255
"""Calculate the integral contribution and error from this interval,
250
256
and update the done leaves of all ancestor intervals."""
251
257
assert self .depth_complete is None or self .depth_complete == depth - 1
@@ -322,7 +328,7 @@ def complete_process(self, depth):
322
328
323
329
return force_split , remove
324
330
325
- def __repr__ (self ):
331
+ def __repr__ (self ) -> str :
326
332
lst = [
327
333
f"(a, b)=({ self .a :.5f} , { self .b :.5f} )" ,
328
334
f"depth={ self .depth } " ,
@@ -334,7 +340,7 @@ def __repr__(self):
334
340
335
341
336
342
class IntegratorLearner (BaseLearner ):
337
- def __init__ (self , function , bounds , tol ) :
343
+ def __init__ (self , function : Callable , bounds : tuple [ int , int ], tol : float ) -> None :
338
344
"""
339
345
Parameters
340
346
----------
@@ -368,16 +374,18 @@ def __init__(self, function, bounds, tol):
368
374
plot : hv.Scatter
369
375
Plots all the points that are evaluated.
370
376
"""
371
- self .function = function
377
+ self .function = function # type: ignore
372
378
self .bounds = bounds
373
379
self .tol = tol
374
380
self .max_ivals = 1000
375
- self .priority_split = []
381
+ self .priority_split : list [ _Interval ] = []
376
382
self .data = {}
377
383
self .pending_points = set ()
378
- self ._stack = []
379
- self .x_mapping = defaultdict (lambda : SortedSet ([], key = attrgetter ("rdepth" )))
380
- self .ivals = set ()
384
+ self ._stack : list [float ] = []
385
+ self .x_mapping : dict [float , SortedSet ] = defaultdict (
386
+ lambda : SortedSet ([], key = attrgetter ("rdepth" ))
387
+ )
388
+ self .ivals : set [_Interval ] = set ()
381
389
ival = _Interval .make_first (* self .bounds )
382
390
self .add_ival (ival )
383
391
self .first_ival = ival
@@ -387,10 +395,10 @@ def new(self) -> IntegratorLearner:
387
395
return IntegratorLearner (self .function , self .bounds , self .tol )
388
396
389
397
@property
390
- def approximating_intervals (self ):
398
+ def approximating_intervals (self ) -> set [ _Interval ] :
391
399
return self .first_ival .done_leaves
392
400
393
- def tell (self , point , value ) :
401
+ def tell (self , point : float , value : float ) -> None :
394
402
if point not in self .x_mapping :
395
403
raise ValueError (f"Point { point } doesn't belong to any interval" )
396
404
self .data [point ] = value
@@ -426,7 +434,7 @@ def tell(self, point, value):
426
434
def tell_pending (self ):
427
435
pass
428
436
429
- def propagate_removed (self , ival ) :
437
+ def propagate_removed (self , ival : _Interval ) -> None :
430
438
def _propagate_removed_down (ival ):
431
439
ival .removed = True
432
440
self .ivals .discard (ival )
@@ -436,7 +444,7 @@ def _propagate_removed_down(ival):
436
444
437
445
_propagate_removed_down (ival )
438
446
439
- def add_ival (self , ival ) :
447
+ def add_ival (self , ival : _Interval ) -> None :
440
448
for x in ival .points ():
441
449
# Update the mappings
442
450
self .x_mapping [x ].add (ival )
@@ -447,15 +455,15 @@ def add_ival(self, ival):
447
455
self ._stack .append (x )
448
456
self .ivals .add (ival )
449
457
450
- def ask (self , n , tell_pending = True ):
458
+ def ask (self , n : int , tell_pending : bool = True ) -> tuple [ list [ float ], list [ float ]] :
451
459
"""Choose points for learners."""
452
460
if not tell_pending :
453
461
with restore (self ):
454
462
return self ._ask_and_tell_pending (n )
455
463
else :
456
464
return self ._ask_and_tell_pending (n )
457
465
458
- def _ask_and_tell_pending (self , n ) :
466
+ def _ask_and_tell_pending (self , n : int ) -> tuple [ list [ float ], list [ float ]] :
459
467
points , loss_improvements = self .pop_from_stack (n )
460
468
n_left = n - len (points )
461
469
while n_left > 0 :
@@ -471,7 +479,7 @@ def _ask_and_tell_pending(self, n):
471
479
472
480
return points , loss_improvements
473
481
474
- def pop_from_stack (self , n ) :
482
+ def pop_from_stack (self , n : int ) -> tuple [ list [ float ], list [ float ]] :
475
483
points = self ._stack [:n ]
476
484
self ._stack = self ._stack [n :]
477
485
loss_improvements = [
@@ -482,7 +490,7 @@ def pop_from_stack(self, n):
482
490
def remove_unfinished (self ):
483
491
pass
484
492
485
- def _fill_stack (self ):
493
+ def _fill_stack (self ) -> list [ float ] :
486
494
# XXX: to-do if all the ivals have err=inf, take the interval
487
495
# with the lowest rdepth and no children.
488
496
force_split = bool (self .priority_split )
@@ -518,16 +526,16 @@ def _fill_stack(self):
518
526
return self ._stack
519
527
520
528
@property
521
- def npoints (self ):
529
+ def npoints (self ) -> int :
522
530
"""Number of evaluated points."""
523
531
return len (self .data )
524
532
525
533
@property
526
- def igral (self ):
534
+ def igral (self ) -> float :
527
535
return sum (i .igral for i in self .approximating_intervals )
528
536
529
537
@property
530
- def err (self ):
538
+ def err (self ) -> float :
531
539
if self .approximating_intervals :
532
540
err = sum (i .err for i in self .approximating_intervals )
533
541
if err > sys .float_info .max :
0 commit comments