Skip to content

Commit 96b4995

Browse files
committed
Type hint IntegratorLearner (#372)
* Add type-hints to adaptive/learner/integrator_learner.py * Add type-hints to adaptive/tests/algorithm_4.py * Add type-hints to adaptive/learner/integrator_coeffs.py
1 parent 812beb8 commit 96b4995

File tree

3 files changed

+93
-71
lines changed

3 files changed

+93
-71
lines changed

adaptive/learner/integrator_coeffs.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# Based on an adaptive quadrature algorithm by Pedro Gonnet
2+
from __future__ import annotations
23

34
from collections import defaultdict
45
from fractions import Fraction
@@ -8,7 +9,7 @@
89
import scipy.linalg
910

1011

11-
def legendre(n):
12+
def legendre(n: int) -> list[list[Fraction]]:
1213
"""Return the first n Legendre polynomials.
1314
1415
The polynomials have *standard* normalization, i.e.
@@ -29,7 +30,7 @@ def legendre(n):
2930
return result
3031

3132

32-
def newton(n):
33+
def newton(n: int) -> np.ndarray:
3334
"""Compute the monomial coefficients of the Newton polynomial over the
3435
nodes of the n-point Clenshaw-Curtis quadrature rule.
3536
"""
@@ -86,7 +87,7 @@ def newton(n):
8687
return cf
8788

8889

89-
def scalar_product(a, b):
90+
def scalar_product(a: list[Fraction], b: list[Fraction]) -> Fraction:
9091
"""Compute the polynomial scalar product int_-1^1 dx a(x) b(x).
9192
9293
The args must be sequences of polynomial coefficients. This
@@ -107,7 +108,7 @@ def scalar_product(a, b):
107108
return 2 * sum(c[i] / (i + 1) for i in range(0, lc, 2))
108109

109110

110-
def calc_bdef(ns):
111+
def calc_bdef(ns: tuple[int, int, int, int]) -> list[np.ndarray]:
111112
"""Calculate the decompositions of Newton polynomials (over the nodes
112113
of the n-point Clenshaw-Curtis quadrature rule) in terms of
113114
Legandre polynomials.
@@ -133,7 +134,7 @@ def calc_bdef(ns):
133134
return result
134135

135136

136-
def calc_V(x, n):
137+
def calc_V(x: np.ndarray, n: int) -> np.ndarray:
137138
V = [np.ones(x.shape), x.copy()]
138139
for i in range(2, n):
139140
V.append((2 * i - 1) / i * x * V[-1] - (i - 1) / i * V[-2])

adaptive/learner/integrator_learner.py

Lines changed: 47 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from collections import defaultdict
66
from math import sqrt
77
from operator import attrgetter
8+
from typing import TYPE_CHECKING, Callable
89

910
import cloudpickle
1011
import numpy as np
@@ -25,7 +26,7 @@
2526
with_pandas = False
2627

2728

28-
def _downdate(c, nans, depth):
29+
def _downdate(c: np.ndarray, nans: list[int], depth: int) -> np.ndarray:
2930
# This is algorithm 5 from the thesis of Pedro Gonnet.
3031
b = coeff.b_def[depth].copy()
3132
m = coeff.ns[depth] - 1
@@ -45,7 +46,7 @@ def _downdate(c, nans, depth):
4546
return c
4647

4748

48-
def _zero_nans(fx):
49+
def _zero_nans(fx: np.ndarray) -> list[int]:
4950
"""Caution: this function modifies fx."""
5051
nans = []
5152
for i in range(len(fx)):
@@ -55,7 +56,7 @@ def _zero_nans(fx):
5556
return nans
5657

5758

58-
def _calc_coeffs(fx, depth):
59+
def _calc_coeffs(fx: np.ndarray, depth: int) -> np.ndarray:
5960
"""Caution: this function modifies fx."""
6061
nans = _zero_nans(fx)
6162
c_new = coeff.V_inv[depth] @ fx
@@ -135,27 +136,32 @@ class _Interval:
135136
"removed",
136137
]
137138

138-
def __init__(self, a, b, depth, rdepth):
139-
self.children = []
140-
self.data = {}
139+
def __init__(self, a: int | float, b: int | float, depth: int, rdepth: int) -> None:
140+
self.children: list[_Interval] = []
141+
self.data: dict[float, float] = {}
141142
self.a = a
142143
self.b = b
143144
self.depth = depth
144145
self.rdepth = rdepth
145-
self.done_leaves = set()
146-
self.depth_complete = None
146+
self.done_leaves: set[_Interval] = set()
147+
self.depth_complete: int | None = None
147148
self.removed = False
149+
if TYPE_CHECKING:
150+
self.ndiv: int
151+
self.parent: _Interval | None
152+
self.err: float
153+
self.c: np.ndarray
148154

149155
@classmethod
150-
def make_first(cls, a, b, depth=2):
156+
def make_first(cls, a: int, b: int, depth: int = 2) -> _Interval:
151157
ival = _Interval(a, b, depth, rdepth=1)
152158
ival.ndiv = 0
153159
ival.parent = None
154160
ival.err = sys.float_info.max # needed because inf/2 == inf
155161
return ival
156162

157163
@property
158-
def T(self):
164+
def T(self) -> np.ndarray:
159165
"""Get the correct shift matrix.
160166
161167
Should only be called on children of a split interval.
@@ -166,24 +172,24 @@ def T(self):
166172
assert left != right
167173
return coeff.T_left if left else coeff.T_right
168174

169-
def refinement_complete(self, depth):
175+
def refinement_complete(self, depth: int) -> bool:
170176
"""The interval has all the y-values to calculate the intergral."""
171177
if len(self.data) < coeff.ns[depth]:
172178
return False
173179
return all(p in self.data for p in self.points(depth))
174180

175-
def points(self, depth=None):
181+
def points(self, depth: int | None = None) -> np.ndarray:
176182
if depth is None:
177183
depth = self.depth
178184
a = self.a
179185
b = self.b
180186
return (a + b) / 2 + (b - a) * coeff.xi[depth] / 2
181187

182-
def refine(self):
188+
def refine(self) -> _Interval:
183189
self.depth += 1
184190
return self
185191

186-
def split(self):
192+
def split(self) -> list[_Interval]:
187193
points = self.points()
188194
m = points[len(points) // 2]
189195
ivals = [
@@ -198,10 +204,10 @@ def split(self):
198204

199205
return ivals
200206

201-
def calc_igral(self):
207+
def calc_igral(self) -> None:
202208
self.igral = (self.b - self.a) * self.c[0] / sqrt(2)
203209

204-
def update_heuristic_err(self, value):
210+
def update_heuristic_err(self, value: float) -> None:
205211
"""Sets the error of an interval using a heuristic (half the error of
206212
the parent) when the actual error cannot be calculated due to its
207213
parents not being finished yet. This error is propagated down to its
@@ -214,7 +220,7 @@ def update_heuristic_err(self, value):
214220
continue
215221
child.update_heuristic_err(value / 2)
216222

217-
def calc_err(self, c_old):
223+
def calc_err(self, c_old: np.ndarray) -> float:
218224
c_new = self.c
219225
c_diff = np.zeros(max(len(c_old), len(c_new)))
220226
c_diff[: len(c_old)] = c_old
@@ -226,9 +232,9 @@ def calc_err(self, c_old):
226232
child.update_heuristic_err(self.err / 2)
227233
return c_diff
228234

229-
def calc_ndiv(self):
235+
def calc_ndiv(self) -> None:
230236
div = self.parent.c00 and self.c00 / self.parent.c00 > 2
231-
self.ndiv += div
237+
self.ndiv += int(div)
232238

233239
if self.ndiv > coeff.ndiv_max and 2 * self.ndiv > self.rdepth:
234240
raise DivergentIntegralError
@@ -237,15 +243,15 @@ def calc_ndiv(self):
237243
for child in self.children:
238244
child.update_ndiv_recursively()
239245

240-
def update_ndiv_recursively(self):
246+
def update_ndiv_recursively(self) -> None:
241247
self.ndiv += 1
242248
if self.ndiv > coeff.ndiv_max and 2 * self.ndiv > self.rdepth:
243249
raise DivergentIntegralError
244250

245251
for child in self.children:
246252
child.update_ndiv_recursively()
247253

248-
def complete_process(self, depth):
254+
def complete_process(self, depth: int) -> tuple[bool, bool] | tuple[bool, np.bool_]:
249255
"""Calculate the integral contribution and error from this interval,
250256
and update the done leaves of all ancestor intervals."""
251257
assert self.depth_complete is None or self.depth_complete == depth - 1
@@ -322,7 +328,7 @@ def complete_process(self, depth):
322328

323329
return force_split, remove
324330

325-
def __repr__(self):
331+
def __repr__(self) -> str:
326332
lst = [
327333
f"(a, b)=({self.a:.5f}, {self.b:.5f})",
328334
f"depth={self.depth}",
@@ -334,7 +340,7 @@ def __repr__(self):
334340

335341

336342
class IntegratorLearner(BaseLearner):
337-
def __init__(self, function, bounds, tol):
343+
def __init__(self, function: Callable, bounds: tuple[int, int], tol: float) -> None:
338344
"""
339345
Parameters
340346
----------
@@ -368,16 +374,18 @@ def __init__(self, function, bounds, tol):
368374
plot : hv.Scatter
369375
Plots all the points that are evaluated.
370376
"""
371-
self.function = function
377+
self.function = function # type: ignore
372378
self.bounds = bounds
373379
self.tol = tol
374380
self.max_ivals = 1000
375-
self.priority_split = []
381+
self.priority_split: list[_Interval] = []
376382
self.data = {}
377383
self.pending_points = set()
378-
self._stack = []
379-
self.x_mapping = defaultdict(lambda: SortedSet([], key=attrgetter("rdepth")))
380-
self.ivals = set()
384+
self._stack: list[float] = []
385+
self.x_mapping: dict[float, SortedSet] = defaultdict(
386+
lambda: SortedSet([], key=attrgetter("rdepth"))
387+
)
388+
self.ivals: set[_Interval] = set()
381389
ival = _Interval.make_first(*self.bounds)
382390
self.add_ival(ival)
383391
self.first_ival = ival
@@ -387,10 +395,10 @@ def new(self) -> IntegratorLearner:
387395
return IntegratorLearner(self.function, self.bounds, self.tol)
388396

389397
@property
390-
def approximating_intervals(self):
398+
def approximating_intervals(self) -> set[_Interval]:
391399
return self.first_ival.done_leaves
392400

393-
def tell(self, point, value):
401+
def tell(self, point: float, value: float) -> None:
394402
if point not in self.x_mapping:
395403
raise ValueError(f"Point {point} doesn't belong to any interval")
396404
self.data[point] = value
@@ -426,7 +434,7 @@ def tell(self, point, value):
426434
def tell_pending(self):
427435
pass
428436

429-
def propagate_removed(self, ival):
437+
def propagate_removed(self, ival: _Interval) -> None:
430438
def _propagate_removed_down(ival):
431439
ival.removed = True
432440
self.ivals.discard(ival)
@@ -436,7 +444,7 @@ def _propagate_removed_down(ival):
436444

437445
_propagate_removed_down(ival)
438446

439-
def add_ival(self, ival):
447+
def add_ival(self, ival: _Interval) -> None:
440448
for x in ival.points():
441449
# Update the mappings
442450
self.x_mapping[x].add(ival)
@@ -447,15 +455,15 @@ def add_ival(self, ival):
447455
self._stack.append(x)
448456
self.ivals.add(ival)
449457

450-
def ask(self, n, tell_pending=True):
458+
def ask(self, n: int, tell_pending: bool = True) -> tuple[list[float], list[float]]:
451459
"""Choose points for learners."""
452460
if not tell_pending:
453461
with restore(self):
454462
return self._ask_and_tell_pending(n)
455463
else:
456464
return self._ask_and_tell_pending(n)
457465

458-
def _ask_and_tell_pending(self, n):
466+
def _ask_and_tell_pending(self, n: int) -> tuple[list[float], list[float]]:
459467
points, loss_improvements = self.pop_from_stack(n)
460468
n_left = n - len(points)
461469
while n_left > 0:
@@ -471,7 +479,7 @@ def _ask_and_tell_pending(self, n):
471479

472480
return points, loss_improvements
473481

474-
def pop_from_stack(self, n):
482+
def pop_from_stack(self, n: int) -> tuple[list[float], list[float]]:
475483
points = self._stack[:n]
476484
self._stack = self._stack[n:]
477485
loss_improvements = [
@@ -482,7 +490,7 @@ def pop_from_stack(self, n):
482490
def remove_unfinished(self):
483491
pass
484492

485-
def _fill_stack(self):
493+
def _fill_stack(self) -> list[float]:
486494
# XXX: to-do if all the ivals have err=inf, take the interval
487495
# with the lowest rdepth and no children.
488496
force_split = bool(self.priority_split)
@@ -518,16 +526,16 @@ def _fill_stack(self):
518526
return self._stack
519527

520528
@property
521-
def npoints(self):
529+
def npoints(self) -> int:
522530
"""Number of evaluated points."""
523531
return len(self.data)
524532

525533
@property
526-
def igral(self):
534+
def igral(self) -> float:
527535
return sum(i.igral for i in self.approximating_intervals)
528536

529537
@property
530-
def err(self):
538+
def err(self) -> float:
531539
if self.approximating_intervals:
532540
err = sum(i.err for i in self.approximating_intervals)
533541
if err > sys.float_info.max:

0 commit comments

Comments
 (0)