
Type hint IntegratorLearner #372

Merged · 3 commits · Oct 12, 2022

11 changes: 6 additions & 5 deletions adaptive/learner/integrator_coeffs.py
@@ -1,4 +1,5 @@
# Based on an adaptive quadrature algorithm by Pedro Gonnet
from __future__ import annotations

from collections import defaultdict
from fractions import Fraction
@@ -8,7 +9,7 @@
import scipy.linalg


def legendre(n):
def legendre(n: int) -> list[list[Fraction]]:
"""Return the first n Legendre polynomials.

The polynomials have *standard* normalization, i.e.
@@ -29,7 +30,7 @@ def legendre(n):
return result


def newton(n):
def newton(n: int) -> np.ndarray:
"""Compute the monomial coefficients of the Newton polynomial over the
nodes of the n-point Clenshaw-Curtis quadrature rule.
"""
@@ -86,7 +87,7 @@ def newton(n):
return cf


def scalar_product(a, b):
def scalar_product(a: list[Fraction], b: list[Fraction]) -> Fraction:
"""Compute the polynomial scalar product int_-1^1 dx a(x) b(x).

The args must be sequences of polynomial coefficients. This
@@ -107,7 +108,7 @@ def scalar_product(a, b):
return 2 * sum(c[i] / (i + 1) for i in range(0, lc, 2))


def calc_bdef(ns):
def calc_bdef(ns: tuple[int, int, int, int]) -> list[np.ndarray]:
"""Calculate the decompositions of Newton polynomials (over the nodes
of the n-point Clenshaw-Curtis quadrature rule) in terms of
Legendre polynomials.
@@ -133,7 +134,7 @@ def calc_bdef(ns):
return result


def calc_V(x, n):
def calc_V(x: np.ndarray, n: int) -> np.ndarray:
V = [np.ones(x.shape), x.copy()]
for i in range(2, n):
V.append((2 * i - 1) / i * x * V[-1] - (i - 1) / i * V[-2])
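As a quick sanity check of the annotated helpers in `integrator_coeffs.py`, the snippet below exercises `legendre` and `scalar_product` with exact `Fraction` arithmetic. It assumes (as the `scalar_product` formula in the hunk above suggests) that coefficient lists are ordered lowest power first; treat it as an illustrative sketch, not part of the library's test suite.

```python
from fractions import Fraction

from adaptive.learner.integrator_coeffs import legendre, scalar_product

# "Standard normalization" in the legendre docstring implies
#   int_{-1}^{1} P_m(x) P_n(x) dx = 0 for m != n, and 2 / (2 n + 1) for m == n.
P0, P1, P2 = legendre(3)  # each polynomial is a list[Fraction], lowest power first
assert scalar_product(P1, P2) == 0               # orthogonality
assert scalar_product(P2, P2) == Fraction(2, 5)  # 2 / (2 * 2 + 1)
```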
86 changes: 47 additions & 39 deletions adaptive/learner/integrator_learner.py
@@ -5,6 +5,7 @@
from collections import defaultdict
from math import sqrt
from operator import attrgetter
from typing import TYPE_CHECKING, Callable

import cloudpickle
import numpy as np
@@ -25,7 +26,7 @@
with_pandas = False


def _downdate(c, nans, depth):
def _downdate(c: np.ndarray, nans: list[int], depth: int) -> np.ndarray:
# This is algorithm 5 from the thesis of Pedro Gonnet.
b = coeff.b_def[depth].copy()
m = coeff.ns[depth] - 1
@@ -45,7 +46,7 @@ def _downdate(c, nans, depth):
return c


def _zero_nans(fx):
def _zero_nans(fx: np.ndarray) -> list[int]:
"""Caution: this function modifies fx."""
nans = []
for i in range(len(fx)):
@@ -55,7 +56,7 @@ def _zero_nans(fx):
return nans


def _calc_coeffs(fx, depth):
def _calc_coeffs(fx: np.ndarray, depth: int) -> np.ndarray:
"""Caution: this function modifies fx."""
nans = _zero_nans(fx)
c_new = coeff.V_inv[depth] @ fx
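The `list[int]` return type on `_zero_nans`, together with the "modifies fx" caution, pins down its contract: non-finite samples are zeroed in place and their indices are returned so that `_calc_coeffs` can hand them to `_downdate`. A minimal, hypothetical re-implementation of that contract (not the library's code):

```python
import numpy as np


def zero_nans_sketch(fx: np.ndarray) -> list[int]:
    """Illustration only: zero non-finite entries of ``fx`` in place
    and return their indices."""
    nans = [i for i, v in enumerate(fx) if not np.isfinite(v)]
    fx[nans] = 0.0
    return nans


fx = np.array([1.0, np.nan, 2.0, -np.inf])
assert zero_nans_sketch(fx) == [1, 3]
assert np.array_equal(fx, [1.0, 0.0, 2.0, 0.0])
```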
@@ -135,27 +136,32 @@ class _Interval:
"removed",
]

def __init__(self, a, b, depth, rdepth):
self.children = []
self.data = {}
def __init__(self, a: int | float, b: int | float, depth: int, rdepth: int) -> None:
self.children: list[_Interval] = []
self.data: dict[float, float] = {}
self.a = a
self.b = b
self.depth = depth
self.rdepth = rdepth
self.done_leaves = set()
self.depth_complete = None
self.done_leaves: set[_Interval] = set()
self.depth_complete: int | None = None
self.removed = False
if TYPE_CHECKING:
self.ndiv: int
self.parent: _Interval | None
self.err: float
self.c: np.ndarray

@classmethod
def make_first(cls, a, b, depth=2):
def make_first(cls, a: int, b: int, depth: int = 2) -> _Interval:
ival = _Interval(a, b, depth, rdepth=1)
ival.ndiv = 0
ival.parent = None
ival.err = sys.float_info.max # needed because inf/2 == inf
return ival

@property
def T(self):
def T(self) -> np.ndarray:
"""Get the correct shift matrix.

Should only be called on children of a split interval.
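A side note on the `if TYPE_CHECKING:` block added to `_Interval.__init__` above: bare annotations inside a branch that is never taken at runtime declare attribute types for the type checker without creating any runtime attribute, which fits fields such as `parent` and `err` that are only assigned later (e.g. in `make_first`). A minimal sketch of the pattern with a hypothetical class:

```python
from __future__ import annotations

from typing import TYPE_CHECKING


class Node:  # hypothetical example, not part of adaptive
    def __init__(self) -> None:
        self.depth = 0
        if TYPE_CHECKING:
            # Never executed at runtime; only tells the type checker
            # that these attributes exist and what their types are.
            self.parent: Node | None
            self.err: float


root = Node()
root.parent = None  # assigned after construction, as make_first does for _Interval
```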
@@ -166,24 +172,24 @@ def T(self):
assert left != right
return coeff.T_left if left else coeff.T_right

def refinement_complete(self, depth):
def refinement_complete(self, depth: int) -> bool:
"""The interval has all the y-values to calculate the intergral."""
if len(self.data) < coeff.ns[depth]:
return False
return all(p in self.data for p in self.points(depth))

def points(self, depth=None):
def points(self, depth: int | None = None) -> np.ndarray:
if depth is None:
depth = self.depth
a = self.a
b = self.b
return (a + b) / 2 + (b - a) * coeff.xi[depth] / 2

def refine(self):
def refine(self) -> _Interval:
self.depth += 1
return self

def split(self):
def split(self) -> list[_Interval]:
points = self.points()
m = points[len(points) // 2]
ivals = [
@@ -198,10 +204,10 @@ def split(self):

return ivals

def calc_igral(self):
def calc_igral(self) -> None:
self.igral = (self.b - self.a) * self.c[0] / sqrt(2)

def update_heuristic_err(self, value):
def update_heuristic_err(self, value: float) -> None:
"""Sets the error of an interval using a heuristic (half the error of
the parent) when the actual error cannot be calculated due to its
parents not being finished yet. This error is propagated down to its
@@ -214,7 +220,7 @@ def calc_err(self, c_old):
continue
child.update_heuristic_err(value / 2)

def calc_err(self, c_old):
def calc_err(self, c_old: np.ndarray) -> float:
c_new = self.c
c_diff = np.zeros(max(len(c_old), len(c_new)))
c_diff[: len(c_old)] = c_old
@@ -226,9 +232,9 @@ def calc_ndiv(self):
child.update_heuristic_err(self.err / 2)
return c_diff

def calc_ndiv(self):
def calc_ndiv(self) -> None:
div = self.parent.c00 and self.c00 / self.parent.c00 > 2
self.ndiv += div
self.ndiv += int(div)

if self.ndiv > coeff.ndiv_max and 2 * self.ndiv > self.rdepth:
raise DivergentIntegralError
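The `self.ndiv += int(div)` line above is the only change in this hunk beyond annotations, and presumably exists because `and` returns one of its operands rather than a `bool`: when `self.parent.c00` is `0.0`, `div` is a float and `ndiv` would silently stop being an `int`. A quick illustration of that semantics:

```python
# `x and y` short-circuits to the first falsy operand, so with a zero
# parent coefficient the comparison on the right is never evaluated and
# the whole expression is 0.0, not False.
div = 0.0 and (3.0 > 2)
print(div, type(div))            # 0.0 <class 'float'>
print(int(div), type(int(div)))  # 0 <class 'int'>
```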
@@ -237,15 +243,15 @@ def calc_ndiv(self):
for child in self.children:
child.update_ndiv_recursively()

def update_ndiv_recursively(self):
def update_ndiv_recursively(self) -> None:
self.ndiv += 1
if self.ndiv > coeff.ndiv_max and 2 * self.ndiv > self.rdepth:
raise DivergentIntegralError

for child in self.children:
child.update_ndiv_recursively()

def complete_process(self, depth):
def complete_process(self, depth: int) -> tuple[bool, bool] | tuple[bool, np.bool_]:
"""Calculate the integral contribution and error from this interval,
and update the done leaves of all ancestor intervals."""
assert self.depth_complete is None or self.depth_complete == depth - 1
@@ -322,7 +328,7 @@ def complete_process(self, depth):

return force_split, remove

def __repr__(self):
def __repr__(self) -> str:
lst = [
f"(a, b)=({self.a:.5f}, {self.b:.5f})",
f"depth={self.depth}",
Expand All @@ -334,7 +340,7 @@ def __repr__(self):


class IntegratorLearner(BaseLearner):
def __init__(self, function, bounds, tol):
def __init__(self, function: Callable, bounds: tuple[int, int], tol: float) -> None:
"""
Parameters
----------
@@ -368,16 +374,18 @@ def __init__(self, function, bounds, tol):
plot : hv.Scatter
Plots all the points that are evaluated.
"""
self.function = function
self.function = function # type: ignore
self.bounds = bounds
self.tol = tol
self.max_ivals = 1000
self.priority_split = []
self.priority_split: list[_Interval] = []
self.data = {}
self.pending_points = set()
self._stack = []
self.x_mapping = defaultdict(lambda: SortedSet([], key=attrgetter("rdepth")))
self.ivals = set()
self._stack: list[float] = []
self.x_mapping: dict[float, SortedSet] = defaultdict(
lambda: SortedSet([], key=attrgetter("rdepth"))
)
self.ivals: set[_Interval] = set()
ival = _Interval.make_first(*self.bounds)
self.add_ival(ival)
self.first_ival = ival
Expand All @@ -387,10 +395,10 @@ def new(self) -> IntegratorLearner:
return IntegratorLearner(self.function, self.bounds, self.tol)

@property
def approximating_intervals(self):
def approximating_intervals(self) -> set[_Interval]:
return self.first_ival.done_leaves

def tell(self, point, value):
def tell(self, point: float, value: float) -> None:
if point not in self.x_mapping:
raise ValueError(f"Point {point} doesn't belong to any interval")
self.data[point] = value
@@ -426,7 +434,7 @@ def tell(self, point, value):
def tell_pending(self):
pass

def propagate_removed(self, ival):
def propagate_removed(self, ival: _Interval) -> None:
def _propagate_removed_down(ival):
ival.removed = True
self.ivals.discard(ival)
@@ -436,7 +444,7 @@ def _propagate_removed_down(ival):

_propagate_removed_down(ival)

def add_ival(self, ival):
def add_ival(self, ival: _Interval) -> None:
for x in ival.points():
# Update the mappings
self.x_mapping[x].add(ival)
@@ -447,15 +455,15 @@ def add_ival(self, ival):
self._stack.append(x)
self.ivals.add(ival)

def ask(self, n, tell_pending=True):
def ask(self, n: int, tell_pending: bool = True) -> tuple[list[float], list[float]]:
"""Choose points for learners."""
if not tell_pending:
with restore(self):
return self._ask_and_tell_pending(n)
else:
return self._ask_and_tell_pending(n)

def _ask_and_tell_pending(self, n):
def _ask_and_tell_pending(self, n: int) -> tuple[list[float], list[float]]:
points, loss_improvements = self.pop_from_stack(n)
n_left = n - len(points)
while n_left > 0:
@@ -471,7 +479,7 @@ def _ask_and_tell_pending(self, n):

return points, loss_improvements

def pop_from_stack(self, n):
def pop_from_stack(self, n: int) -> tuple[list[float], list[float]]:
points = self._stack[:n]
self._stack = self._stack[n:]
loss_improvements = [
@@ -482,7 +490,7 @@ def pop_from_stack(self, n):
def remove_unfinished(self):
pass

def _fill_stack(self):
def _fill_stack(self) -> list[float]:
# XXX: to-do if all the ivals have err=inf, take the interval
# with the lowest rdepth and no children.
force_split = bool(self.priority_split)
@@ -518,16 +526,16 @@ def _fill_stack(self):
return self._stack

@property
def npoints(self):
def npoints(self) -> int:
"""Number of evaluated points."""
return len(self.data)

@property
def igral(self):
def igral(self) -> float:
return sum(i.igral for i in self.approximating_intervals)

@property
def err(self):
def err(self) -> float:
if self.approximating_intervals:
err = sum(i.err for i in self.approximating_intervals)
if err > sys.float_info.max:
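For context on the public methods annotated in this PR (`ask`, `tell`, `npoints`, `igral`, `err`), here is a rough driver sketch that uses only signatures visible in the diff. The batch size and iteration count are arbitrary; in practice one of adaptive's runners would drive the learner instead.

```python
import numpy as np

from adaptive import IntegratorLearner

learner = IntegratorLearner(np.exp, bounds=(0, 1), tol=1e-10)

# Hand-rolled ask/tell loop; adaptive's runners do the same thing with
# proper stopping criteria.
for _ in range(50):
    points, _ = learner.ask(10)
    for x in points:
        learner.tell(x, np.exp(x))

print(learner.npoints)  # number of evaluated points
print(learner.igral)    # approaches e - 1 ≈ 1.7182818
print(learner.err)      # estimated error of the integral
```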