Commit 1d2496f

improve typing
1 parent 9bfbeb8 commit 1d2496f

File tree

1 file changed: +86 −53 lines changed


adaptive/learner/learner1D.py

Lines changed: 86 additions & 53 deletions
@@ -1,17 +1,16 @@
 import collections.abc
 import itertools
 import math
-import numbers
 from copy import deepcopy
 from typing import (
     Any,
     Callable,
     Dict,
-    Iterable,
     List,
     Literal,
     Optional,
     Sequence,
+    Set,
     Tuple,
     Union,
 )
@@ -25,14 +24,42 @@
 from adaptive.learner.learnerND import volume
 from adaptive.learner.triangulation import simplex_volume_in_embedding
 from adaptive.notebook_integration import ensure_holoviews
-from adaptive.types import Float
+from adaptive.types import Float, Int, Real
 from adaptive.utils import cache_latest
 
-Point = Tuple[Float, Float]
+# -- types --
+
+# Commonly used types
+Interval = Union[Tuple[float, float], Tuple[float, float, int]]
+NeighborsType = Dict[float, List[Optional[float]]]
+
+# Types for loss_per_interval functions
+NoneFloat = Union[Float, None]
+NoneArray = Union[np.ndarray, None]
+XsType0 = Tuple[Float, Float]
+YsType0 = Union[Tuple[Float, Float], Tuple[np.ndarray, np.ndarray]]
+XsType1 = Tuple[NoneFloat, NoneFloat, NoneFloat, NoneFloat]
+YsType1 = Union[
+    Tuple[NoneFloat, NoneFloat, NoneFloat, NoneFloat],
+    Tuple[NoneArray, NoneArray, NoneArray, NoneArray],
+]
+XsTypeN = Tuple[NoneFloat, ...]
+YsTypeN = Union[Tuple[NoneFloat, ...], Tuple[NoneArray, ...]]
+
+
+__all__ = [
+    "uniform_loss",
+    "default_loss",
+    "abs_min_log_loss",
+    "triangle_loss",
+    "resolution_loss_function",
+    "curvature_loss_function",
+    "Learner1D",
+]
 
 
 @uses_nth_neighbors(0)
-def uniform_loss(xs: Point, ys: Any) -> Float:
+def uniform_loss(xs: XsType0, ys: YsType0) -> Float:
     """Loss function that samples the domain uniformly.
 
     Works with `~adaptive.Learner1D` only.
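
A note on the new aliases (not part of this commit): XsType0/YsType0 describe what a loss function declared with nth_neighbors=0 receives, so a user-defined loss_per_interval can now be annotated against them. A minimal sketch, assuming the aliases and uses_nth_neighbors can be imported from adaptive.learner.learner1D and that the hypothetical endpoint_gap_loss name is ours, not the library's:

    import numpy as np
    from adaptive.learner.learner1D import XsType0, YsType0, uses_nth_neighbors
    from adaptive.types import Float

    @uses_nth_neighbors(0)
    def endpoint_gap_loss(xs: XsType0, ys: YsType0) -> Float:
        # Larger loss for intervals whose endpoint values differ more.
        dx = xs[1] - xs[0]
        dy = np.max(np.abs(np.asarray(ys[1]) - np.asarray(ys[0])))
        return float(np.hypot(dx, dy))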
@@ -52,10 +79,7 @@ def uniform_loss(xs: Point, ys: Any) -> Float:
 
 
 @uses_nth_neighbors(0)
-def default_loss(
-    xs: Point,
-    ys: Union[Tuple[Iterable[Float], Iterable[Float]], Point],
-) -> float:
+def default_loss(xs: XsType0, ys: YsType0) -> Float:
     """Calculate loss on a single interval.
 
     Currently returns the rescaled length of the interval. If one of the
@@ -64,28 +88,23 @@ def default_loss(
     """
     dx = xs[1] - xs[0]
     if isinstance(ys[0], collections.abc.Iterable):
-        dy_vec = [abs(a - b) for a, b in zip(*ys)]
+        dy_vec = np.array([abs(a - b) for a, b in zip(*ys)])
         return np.hypot(dx, dy_vec).max()
     else:
        dy = ys[1] - ys[0]
        return np.hypot(dx, dy)
 
 
 @uses_nth_neighbors(0)
-def abs_min_log_loss(xs, ys):
+def abs_min_log_loss(xs: XsType0, ys: YsType0) -> Float:
     """Calculate loss of a single interval that prioritizes the absolute minimum."""
-    ys = [np.log(np.abs(y).min()) for y in ys]
+    ys = tuple(np.log(np.abs(y).min()) for y in ys)
     return default_loss(xs, ys)
 
 
 @uses_nth_neighbors(1)
-def triangle_loss(
-    xs: Sequence[Optional[Float]],
-    ys: Union[
-        Iterable[Optional[Float]],
-        Iterable[Union[Iterable[Float], None]],
-    ],
-) -> float:
+def triangle_loss(xs: XsType1, ys: YsType1) -> Float:
+    assert len(xs) == 4
     xs = [x for x in xs if x is not None]
     ys = [y for y in ys if y is not None]
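
For reference (not in the diff): default_loss keeps its two branches, and the only behavioral tweak here is wrapping the per-component differences in np.array before the np.hypot call. A small worked sketch with made-up values:

    import numpy as np

    xs = (0.0, 0.1)                               # interval endpoints
    ys_scalar = (1.0, 1.3)                        # scalar-valued function
    ys_vector = (np.array([1.0, 2.0]), np.array([1.3, 1.9]))

    dx = xs[1] - xs[0]
    print(np.hypot(dx, ys_scalar[1] - ys_scalar[0]))            # scalar branch
    dy_vec = np.array([abs(a - b) for a, b in zip(*ys_vector)])
    print(np.hypot(dx, dy_vec).max())                           # vector branch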
@@ -102,7 +121,9 @@ def triangle_loss(
         return sum(vol(pts[i : i + 3]) for i in range(N)) / N
 
 
-def resolution_loss_function(min_length=0, max_length=1):
+def resolution_loss_function(
+    min_length: Real = 0, max_length: Real = 1
+) -> Callable[[XsType0, YsType0], Float]:
     """Loss function that is similar to the `default_loss` function, but you
     can set the maximum and minimum size of an interval.
@@ -125,7 +146,7 @@ def resolution_loss_function(min_length=0, max_length=1):
     """
 
     @uses_nth_neighbors(0)
-    def resolution_loss(xs, ys):
+    def resolution_loss(xs: XsType0, ys: YsType0) -> Float:
         loss = uniform_loss(xs, ys)
         if loss < min_length:
             # Return zero such that this interval won't be chosen again
@@ -140,11 +161,11 @@ def resolution_loss(xs, ys):
 
 
 def curvature_loss_function(
-    area_factor: float = 1, euclid_factor: float = 0.02, horizontal_factor: float = 0.02
-) -> Callable:
+    area_factor: Real = 1, euclid_factor: Real = 0.02, horizontal_factor: Real = 0.02
+) -> Callable[[XsType1, YsType1], Float]:
     # XXX: add a doc-string
     @uses_nth_neighbors(1)
-    def curvature_loss(xs, ys):
+    def curvature_loss(xs: XsType1, ys: YsType1) -> Float:
         xs_middle = xs[1:3]
         ys_middle = ys[1:3]
@@ -160,7 +181,7 @@ def curvature_loss(xs, ys):
     return curvature_loss
 
 
-def linspace(x_left: float, x_right: float, n: int) -> List[float]:
+def linspace(x_left: Real, x_right: Real, n: Int) -> List[Float]:
     """This is equivalent to
     'np.linspace(x_left, x_right, n, endpoint=False)[1:]',
     but it is 15-30 times faster for small 'n'."""
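
A quick check (not part of the commit) of the docstring's equivalence claim; the body below is reconstructed from the docstring and the visible return statement, so treat it as a sketch rather than the exact implementation:

    import numpy as np

    def linspace(x_left, x_right, n):
        # Same points as np.linspace(x_left, x_right, n, endpoint=False)[1:]
        step = (x_right - x_left) / n
        return [x_left + step * i for i in range(1, n)]

    assert np.allclose(
        linspace(0.0, 1.0, 4),
        np.linspace(0.0, 1.0, 4, endpoint=False)[1:],
    )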
@@ -172,7 +193,7 @@ def linspace(x_left: float, x_right: float, n: int) -> List[float]:
     return [x_left + step * i for i in range(1, n)]
 
 
-def _get_neighbors_from_list(xs: np.ndarray) -> SortedDict:
+def _get_neighbors_from_array(xs: np.ndarray) -> NeighborsType:
     xs = np.sort(xs)
     xs_left = np.roll(xs, 1).tolist()
     xs_right = np.roll(xs, -1).tolist()
@@ -182,7 +203,9 @@ def _get_neighbors_from_list(xs: np.ndarray) -> SortedDict:
     return SortedDict(neighbors)
 
 
-def _get_intervals(x: float, neighbors: SortedDict, nth_neighbors: int) -> Any:
+def _get_intervals(
+    x: float, neighbors: NeighborsType, nth_neighbors: int
+) -> List[Tuple[float, float]]:
     nn = nth_neighbors
     i = neighbors.index(x)
     start = max(0, i - nn - 1)
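
To make the new NeighborsType concrete (illustration only, not from the diff): it is the {x_n: [x_left, x_right]} layout that _get_neighbors_from_array builds and _get_intervals indexes into, with None marking the missing neighbor at a boundary; SortedDict is assumed to come from sortedcontainers, as elsewhere in this module:

    from sortedcontainers import SortedDict

    # Three sorted points; boundary points have no neighbor on one side.
    neighbors = SortedDict({
        0.0: [None, 0.5],
        0.5: [0.0, 1.0],
        1.0: [0.5, None],
    })
    assert neighbors.index(0.5) == 1   # the positional lookup _get_intervals uses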
@@ -237,10 +260,10 @@ class Learner1D(BaseLearner):
 
     def __init__(
         self,
-        function: Callable,
-        bounds: Tuple[float, float],
-        loss_per_interval: Optional[Callable] = None,
-    ) -> None:
+        function: Callable[[Real], Union[Float, np.ndarray]],
+        bounds: Tuple[Real, Real],
+        loss_per_interval: Optional[Callable[[XsTypeN, YsTypeN], Float]] = None,
+    ):
         self.function = function  # type: ignore
 
         if hasattr(loss_per_interval, "nth_neighbors"):
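
A usage sketch (not part of the commit) of the now fully-typed constructor; curvature_loss_function is defined earlier in this same file and Learner1D is exported from the top-level adaptive package:

    import numpy as np
    from adaptive import Learner1D
    from adaptive.learner.learner1D import curvature_loss_function

    def f(x: float) -> float:
        return float(np.exp(-x**2))

    learner = Learner1D(
        f, bounds=(-2.0, 2.0), loss_per_interval=curvature_loss_function()
    )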
@@ -255,13 +278,13 @@ def __init__(
         # the learners behavior in the tests.
         self._recompute_losses_factor = 2
 
-        self.data = {}
-        self.pending_points = set()
+        self.data: Dict[Real, Real] = {}
+        self.pending_points: Set[Real] = set()
 
         # A dict {x_n: [x_{n-1}, x_{n+1}]} for quick checking of local
         # properties.
-        self.neighbors = SortedDict()
-        self.neighbors_combined = SortedDict()
+        self.neighbors: NeighborsType = SortedDict()
+        self.neighbors_combined: NeighborsType = SortedDict()
 
         # Bounding box [[minx, maxx], [miny, maxy]].
         self._bbox = [list(bounds), [np.inf, -np.inf]]
@@ -319,14 +342,14 @@ def loss(self, real: bool = True) -> float:
         max_interval, max_loss = losses.peekitem(0)
         return max_loss
 
-    def _scale_x(self, x: Optional[float]) -> Optional[float]:
+    def _scale_x(self, x: Optional[Float]) -> Optional[Float]:
         if x is None:
             return None
         return x / self._scale[0]
 
     def _scale_y(
-        self, y: Optional[Union[Float, np.ndarray]]
-    ) -> Optional[Union[Float, np.ndarray]]:
+        self, y: Union[Float, np.ndarray, None]
+    ) -> Union[Float, np.ndarray, None]:
         if y is None:
             return None
         y_scale = self._scale[1] or 1
@@ -418,7 +441,7 @@ def _update_losses(self, x: float, real: bool = True) -> None:
             self.losses_combined[x, b] = float("inf")
 
     @staticmethod
-    def _find_neighbors(x: float, neighbors: SortedDict) -> Any:
+    def _find_neighbors(x: float, neighbors: NeighborsType) -> Any:
         if x in neighbors:
             return neighbors[x]
         pos = neighbors.bisect_left(x)
@@ -427,7 +450,7 @@ def _find_neighbors(x: float, neighbors: SortedDict) -> Any:
         x_right = keys[pos] if pos != len(neighbors) else None
         return x_left, x_right
 
-    def _update_neighbors(self, x: float, neighbors: SortedDict) -> None:
+    def _update_neighbors(self, x: float, neighbors: NeighborsType) -> None:
         if x not in neighbors:  # The point is new
             x_left, x_right = self._find_neighbors(x, neighbors)
             neighbors[x] = [x_left, x_right]
@@ -461,9 +484,7 @@ def _update_scale(self, x: float, y: Union[Float, np.ndarray]) -> None:
             self._bbox[1][1] = max(self._bbox[1][1], y)
             self._scale[1] = self._bbox[1][1] - self._bbox[1][0]
 
-    def tell(
-        self, x: float, y: Union[Float, Sequence[numbers.Number], np.ndarray]
-    ) -> None:
+    def tell(self, x: float, y: Union[Float, Sequence[Float], np.ndarray]) -> None:
         if x in self.data:
             # The point is already evaluated before
             return
@@ -506,7 +527,17 @@ def tell_pending(self, x: float) -> None:
         self._update_neighbors(x, self.neighbors_combined)
         self._update_losses(x, real=False)
 
-    def tell_many(self, xs: Sequence[float], ys: Sequence[Any], *, force=False) -> None:
+    def tell_many(
+        self,
+        xs: Sequence[Float],
+        ys: Union[
+            Sequence[Float],
+            Sequence[Sequence[Float]],
+            Sequence[np.ndarray],
+        ],
+        *,
+        force: bool = False
+    ) -> None:
         if not force and not (len(xs) > 0.5 * len(self.data) and len(xs) > 2):
             # Only run this more efficient method if there are
             # at least 2 points and the amount of points added are
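
The widened ys annotation spells out that tell_many accepts scalar values, sequences of values, or arrays. A hedged usage sketch with precomputed data:

    import numpy as np
    from adaptive import Learner1D

    learner = Learner1D(np.sin, bounds=(0.0, 2 * np.pi))
    xs = np.linspace(0.0, 2 * np.pi, 17)
    learner.tell_many(xs, np.sin(xs))   # ys as an np.ndarray of scalar values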
@@ -526,8 +557,8 @@ def tell_many(self, xs: Sequence[float], ys: Sequence[Any], *, force=False) -> N
         points_combined = np.hstack([points_pending, points])
 
         # Generate neighbors
-        self.neighbors = _get_neighbors_from_list(points)
-        self.neighbors_combined = _get_neighbors_from_list(points_combined)
+        self.neighbors = _get_neighbors_from_array(points)
+        self.neighbors_combined = _get_neighbors_from_array(points_combined)
 
         # Update scale
         self._bbox[0] = [points_combined.min(), points_combined.max()]
@@ -574,7 +605,7 @@ def tell_many(self, xs: Sequence[float], ys: Sequence[Any], *, force=False) -> N
             # have an inf loss.
             self._update_interpolated_loss_in_interval(*ival)
 
-    def ask(self, n: int, tell_pending: bool = True) -> Any:
+    def ask(self, n: int, tell_pending: bool = True) -> Tuple[List[float], List[float]]:
         """Return 'n' points that are expected to maximally reduce the loss."""
         points, loss_improvements = self._ask_points_without_adding(n)
@@ -584,7 +615,7 @@ def ask(self, n: int, tell_pending: bool = True) -> Any:
 
         return points, loss_improvements
 
-    def _ask_points_without_adding(self, n: int) -> Any:
+    def _ask_points_without_adding(self, n: int) -> Tuple[List[float], List[float]]:
         """Return 'n' points that are expected to maximally reduce the loss.
         Without altering the state of the learner"""
         # Find out how to divide the n points over the intervals
@@ -648,7 +679,7 @@ def _ask_points_without_adding(self, n: int) -> Any:
             quals[(*xs, n + 1)] = loss_qual * n / (n + 1)
 
         points = list(
-            itertools.chain.from_iterable(linspace(a, b, n) for ((a, b), n) in quals)
+            itertools.chain.from_iterable(linspace(*ival, n) for (*ival, n) in quals)
         )
 
         loss_improvements = list(
@@ -663,7 +694,9 @@ def _ask_points_without_adding(self, n: int) -> Any:
 
         return points, loss_improvements
 
-    def _loss(self, mapping: ItemSortedDict, ival: Any) -> Any:
+    def _loss(
+        self, mapping: Dict[Interval, float], ival: Interval
+    ) -> Tuple[float, Interval]:
         loss = mapping[ival]
         return finite_loss(ival, loss, self._scale[0])
@@ -734,7 +767,7 @@ def __setstate__(self, state):
         self.losses_combined.update(losses_combined)
 
 
-def loss_manager(x_scale: float) -> ItemSortedDict:
+def loss_manager(x_scale: float) -> Dict[Interval, float]:
     def sort_key(ival, loss):
         loss, ival = finite_loss(ival, loss, x_scale)
         return -loss, ival
@@ -743,8 +776,8 @@ def sort_key(ival, loss):
     return sorted_dict
 
 
-def finite_loss(ival: Any, loss: float, x_scale: float) -> Any:
-    """Get the socalled finite_loss of an interval in order to be able to
+def finite_loss(ival: Interval, loss: float, x_scale: float) -> Tuple[float, Interval]:
+    """Get the so-called finite_loss of an interval in order to be able to
     sort intervals that have infinite loss."""
     # If the loss is infinite we return the
     # distance between the two points.

0 commit comments
