Skip to content

Commit 8c71e0b

Browse files
committed
Add type-hints to adaptive/tests/algorithm_4.py
1 parent d4b7ce7 commit 8c71e0b

File tree

1 file changed

+40
-27
lines changed

1 file changed

+40
-27
lines changed

adaptive/tests/algorithm_4.py

Lines changed: 40 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
# Copyright 2017 Christoph Groth
33

44
from collections import defaultdict
5-
from fractions import Fraction as Frac
5+
from fractions import Fraction
6+
from typing import Callable, List, Tuple, Union
67

78
import numpy as np
89
from numpy.testing import assert_allclose
@@ -11,28 +12,28 @@
1112
eps = np.spacing(1)
1213

1314

14-
def legendre(n):
15+
def legendre(n: int) -> List[List[Fraction]]:
1516
"""Return the first n Legendre polynomials.
1617
1718
The polynomials have *standard* normalization, i.e.
1819
int_{-1}^1 dx L_n(x) L_m(x) = delta(m, n) * 2 / (2 * n + 1).
1920
2021
The return value is a list of list of fraction.Fraction instances.
2122
"""
22-
result = [[Frac(1)], [Frac(0), Frac(1)]]
23+
result = [[Fraction(1)], [Fraction(0), Fraction(1)]]
2324
if n <= 2:
2425
return result[:n]
2526
for i in range(2, n):
2627
# Use Bonnet's recursion formula.
27-
new = (i + 1) * [Frac(0)]
28+
new = (i + 1) * [Fraction(0)]
2829
new[1:] = (r * (2 * i - 1) for r in result[-1])
2930
new[:-2] = (n - r * (i - 1) for n, r in zip(new[:-2], result[-2]))
3031
new[:] = (n / i for n in new)
3132
result.append(new)
3233
return result
3334

3435

35-
def newton(n):
36+
def newton(n: int) -> np.ndarray:
3637
"""Compute the monomial coefficients of the Newton polynomial over the
3738
nodes of the n-point Clenshaw-Curtis quadrature rule.
3839
"""
@@ -89,7 +90,7 @@ def newton(n):
8990
return cf
9091

9192

92-
def scalar_product(a, b):
93+
def scalar_product(a: List[Fraction], b: List[Fraction]) -> Fraction:
9394
"""Compute the polynomial scalar product int_-1^1 dx a(x) b(x).
9495
9596
The args must be sequences of polynomial coefficients. This
@@ -110,7 +111,7 @@ def scalar_product(a, b):
110111
return 2 * sum(c[i] / (i + 1) for i in range(0, lc, 2))
111112

112113

113-
def calc_bdef(ns):
114+
def calc_bdef(ns: Tuple[int, int, int, int]) -> List[np.ndarray]:
114115
"""Calculate the decompositions of Newton polynomials (over the nodes
115116
of the n-point Clenshaw-Curtis quadrature rule) in terms of
116117
Legandre polynomials.
@@ -123,7 +124,7 @@ def calc_bdef(ns):
123124
result = []
124125
for n in ns:
125126
poly = []
126-
a = list(map(Frac, newton(n)))
127+
a = list(map(Fraction, newton(n)))
127128
for b in legs[: n + 1]:
128129
igral = scalar_product(a, b)
129130

@@ -145,7 +146,7 @@ def calc_bdef(ns):
145146
b_def = calc_bdef(n)
146147

147148

148-
def calc_V(xi, n):
149+
def calc_V(xi: np.ndarray, n: int) -> np.ndarray:
149150
V = [np.ones(xi.shape), xi.copy()]
150151
for i in range(2, n):
151152
V.append((2 * i - 1) / i * xi * V[-1] - (i - 1) / i * V[-2])
@@ -183,7 +184,7 @@ def calc_V(xi, n):
183184
gamma = np.concatenate([[0, 0], np.sqrt(k[2:] ** 2 / (4 * k[2:] ** 2 - 1))])
184185

185186

186-
def _downdate(c, nans, depth):
187+
def _downdate(c: np.ndarray, nans: List[int], depth: int) -> None:
187188
# This is algorithm 5 from the thesis of Pedro Gonnet.
188189
b = b_def[depth].copy()
189190
m = n[depth] - 1
@@ -200,7 +201,7 @@ def _downdate(c, nans, depth):
200201
m -= 1
201202

202203

203-
def _zero_nans(fx):
204+
def _zero_nans(fx: np.ndarray) -> List[int]:
204205
nans = []
205206
for i in range(len(fx)):
206207
if not np.isfinite(fx[i]):
@@ -209,7 +210,7 @@ def _zero_nans(fx):
209210
return nans
210211

211212

212-
def _calc_coeffs(fx, depth):
213+
def _calc_coeffs(fx: np.ndarray, depth: int) -> np.ndarray:
213214
"""Caution: this function modifies fx."""
214215
nans = _zero_nans(fx)
215216
c_new = V_inv[depth] @ fx
@@ -220,7 +221,7 @@ def _calc_coeffs(fx, depth):
220221

221222

222223
class DivergentIntegralError(ValueError):
223-
def __init__(self, msg, igral, err, nr_points):
224+
def __init__(self, msg: str, igral: float, err: None, nr_points: int) -> None:
224225
self.igral = igral
225226
self.err = err
226227
self.nr_points = nr_points
@@ -230,19 +231,23 @@ def __init__(self, msg, igral, err, nr_points):
230231
class _Interval:
231232
__slots__ = ["a", "b", "c", "fx", "igral", "err", "depth", "rdepth", "ndiv", "c00"]
232233

233-
def __init__(self, a, b, depth, rdepth):
234+
def __init__(
235+
self, a: Union[int, float], b: Union[int, float], depth: int, rdepth: int
236+
) -> None:
234237
self.a = a
235238
self.b = b
236239
self.depth = depth
237240
self.rdepth = rdepth
238241

239-
def points(self):
242+
def points(self) -> np.ndarray:
240243
a = self.a
241244
b = self.b
242245
return (a + b) / 2 + (b - a) * xi[self.depth] / 2
243246

244247
@classmethod
245-
def make_first(cls, f, a, b, depth=2):
248+
def make_first(
249+
cls, f: Callable, a: int, b: int, depth: int = 2
250+
) -> Tuple["_Interval", int]:
246251
ival = _Interval(a, b, depth, 1)
247252
fx = f(ival.points())
248253
ival.c = _calc_coeffs(fx, depth)
@@ -251,7 +256,7 @@ def make_first(cls, f, a, b, depth=2):
251256
ival.ndiv = 0
252257
return ival, n[depth]
253258

254-
def calc_igral_and_err(self, c_old):
259+
def calc_igral_and_err(self, c_old: np.ndarray) -> float:
255260
self.c = c_new = _calc_coeffs(self.fx, self.depth)
256261
c_diff = np.zeros(max(len(c_old), len(c_new)))
257262
c_diff[: len(c_old)] = c_old
@@ -262,7 +267,9 @@ def calc_igral_and_err(self, c_old):
262267
self.err = w * c_diff
263268
return c_diff
264269

265-
def split(self, f):
270+
def split(
271+
self, f: Callable
272+
) -> Union[Tuple[Tuple[float, float, float], int], Tuple[List["_Interval"], int]]:
266273
m = (self.a + self.b) / 2
267274
f_center = self.fx[(len(self.fx) - 1) // 2]
268275

@@ -287,7 +294,7 @@ def split(self, f):
287294

288295
return ivals, nr_points
289296

290-
def refine(self, f):
297+
def refine(self, f: Callable) -> Tuple[np.ndarray, bool, int]:
291298
"""Increase degree of interval."""
292299
self.depth = depth = self.depth + 1
293300
points = self.points()
@@ -299,7 +306,9 @@ def refine(self, f):
299306
return points, split, n[depth] - n[depth - 1]
300307

301308

302-
def algorithm_4(f, a, b, tol, N_loops=int(1e9)):
309+
def algorithm_4(
310+
f: Callable, a: int, b: int, tol: float, N_loops: int = int(1e9)
311+
) -> Tuple[float, float, int, List["_Interval"]]:
303312
"""ALGORITHM_4 evaluates an integral using adaptive quadrature. The
304313
algorithm uses Clenshaw-Curtis quadrature rules of increasing
305314
degree in each interval and bisects the interval if either the
@@ -403,37 +412,39 @@ def algorithm_4(f, a, b, tol, N_loops=int(1e9)):
403412
return igral, err, nr_points, ivals
404413

405414

406-
################ Tests ################
415+
# ############### Tests ################
407416

408417

409-
def f0(x):
418+
def f0(x: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
410419
return x * np.sin(1 / x) * np.sqrt(abs(1 - x))
411420

412421

413422
def f7(x):
414423
return x**-0.5
415424

416425

417-
def f24(x):
426+
def f24(x: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
418427
return np.floor(np.exp(x))
419428

420429

421-
def f21(x):
430+
def f21(x: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
422431
y = 0
423432
for i in range(1, 4):
424433
y += 1 / np.cosh(20**i * (x - 2 * i / 10))
425434
return y
426435

427436

428-
def f63(x, alpha, beta):
437+
def f63(
438+
x: Union[float, np.ndarray], alpha: float, beta: float
439+
) -> Union[float, np.ndarray]:
429440
return abs(x - beta) ** alpha
430441

431442

432443
def F63(x, alpha, beta):
433444
return (x - beta) * abs(x - beta) ** alpha / (alpha + 1)
434445

435446

436-
def fdiv(x):
447+
def fdiv(x: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
437448
return abs(x - 0.987654321) ** -1.1
438449

439450

@@ -461,7 +472,9 @@ def test_scalar_product(n=33):
461472
selection = [0, 5, 7, n - 1]
462473
for i in selection:
463474
for j in selection:
464-
assert scalar_product(legs[i], legs[j]) == ((i == j) and Frac(2, 2 * i + 1))
475+
assert scalar_product(legs[i], legs[j]) == (
476+
(i == j) and Fraction(2, 2 * i + 1)
477+
)
465478

466479

467480
def simple_newton(n):

0 commit comments

Comments
 (0)