Skip to content

Commit ef0dbf7

Browse files
committed
Add type-hints to adaptive/learner/learner2D.py
1 parent 1b7e84d commit ef0dbf7

File tree

4 files changed

+50
-33
lines changed

4 files changed

+50
-33
lines changed

adaptive/learner/learner2D.py

+47-31
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,17 @@
55
from collections import OrderedDict
66
from copy import copy
77
from math import sqrt
8+
from typing import Callable, Iterable
89

910
import cloudpickle
1011
import numpy as np
1112
from scipy import interpolate
13+
from scipy.interpolate.interpnd import LinearNDInterpolator
1214

1315
from adaptive.learner.base_learner import BaseLearner
1416
from adaptive.learner.triangulation import simplex_volume_in_embedding
1517
from adaptive.notebook_integration import ensure_holoviews
18+
from adaptive.types import Bool, Float, Real
1619
from adaptive.utils import (
1720
assign_defaults,
1821
cache_latest,
@@ -30,7 +33,7 @@
3033
# Learner2D and helper functions.
3134

3235

33-
def deviations(ip):
36+
def deviations(ip: LinearNDInterpolator) -> list[np.ndarray]:
3437
"""Returns the deviation of the linear estimate.
3538
3639
Is useful when defining custom loss functions.
@@ -68,7 +71,7 @@ def deviation(p, v, g):
6871
return devs
6972

7073

71-
def areas(ip):
74+
def areas(ip: LinearNDInterpolator) -> np.ndarray:
7275
"""Returns the area per triangle of the triangulation inside
7376
a `LinearNDInterpolator` instance.
7477
@@ -89,7 +92,7 @@ def areas(ip):
8992
return areas
9093

9194

92-
def uniform_loss(ip):
95+
def uniform_loss(ip: LinearNDInterpolator) -> np.ndarray:
9396
"""Loss function that samples the domain uniformly.
9497
9598
Works with `~adaptive.Learner2D` only.
@@ -120,7 +123,9 @@ def uniform_loss(ip):
120123
return np.sqrt(areas(ip))
121124

122125

123-
def resolution_loss_function(min_distance=0, max_distance=1):
126+
def resolution_loss_function(
127+
min_distance: float = 0, max_distance: float = 1
128+
) -> Callable[[LinearNDInterpolator], np.ndarray]:
124129
"""Loss function that is similar to the `default_loss` function, but you
125130
can set the maximimum and minimum size of a triangle.
126131
@@ -159,7 +164,7 @@ def resolution_loss(ip):
159164
return resolution_loss
160165

161166

162-
def minimize_triangle_surface_loss(ip):
167+
def minimize_triangle_surface_loss(ip: LinearNDInterpolator) -> np.ndarray:
163168
"""Loss function that is similar to the distance loss function in the
164169
`~adaptive.Learner1D`. The loss is the area spanned by the 3D
165170
vectors of the vertices.
@@ -205,7 +210,7 @@ def _get_vectors(points):
205210
return np.linalg.norm(np.cross(a, b) / 2, axis=1)
206211

207212

208-
def default_loss(ip):
213+
def default_loss(ip: LinearNDInterpolator) -> np.ndarray:
209214
"""Loss function that combines `deviations` and `areas` of the triangles.
210215
211216
Works with `~adaptive.Learner2D` only.
@@ -225,7 +230,7 @@ def default_loss(ip):
225230
return losses
226231

227232

228-
def choose_point_in_triangle(triangle, max_badness):
233+
def choose_point_in_triangle(triangle: np.ndarray, max_badness: int) -> np.ndarray:
229234
"""Choose a new point in inside a triangle.
230235
231236
If the ratio of the longest edge of the triangle squared
@@ -364,7 +369,12 @@ class Learner2D(BaseLearner):
364369
over each triangle.
365370
"""
366371

367-
def __init__(self, function, bounds, loss_per_triangle=None):
372+
def __init__(
373+
self,
374+
function: Callable,
375+
bounds: tuple[tuple[Real, Real], tuple[Real, Real]],
376+
loss_per_triangle: Callable | None = None,
377+
) -> None:
368378
self.ndim = len(bounds)
369379
self._vdim = None
370380
self.loss_per_triangle = loss_per_triangle or default_loss
@@ -379,7 +389,7 @@ def __init__(self, function, bounds, loss_per_triangle=None):
379389

380390
self._bounds_points = list(itertools.product(*bounds))
381391
self._stack.update({p: np.inf for p in self._bounds_points})
382-
self.function = function
392+
self.function = function # type: ignore
383393
self._ip = self._ip_combined = None
384394

385395
self.stack_size = 10
@@ -388,7 +398,7 @@ def new(self) -> Learner2D:
388398
return Learner2D(self.function, self.bounds, self.loss_per_triangle)
389399

390400
@property
391-
def xy_scale(self):
401+
def xy_scale(self) -> np.ndarray:
392402
xy_scale = self._xy_scale
393403
if self.aspect_ratio == 1:
394404
return xy_scale
@@ -486,21 +496,21 @@ def load_dataframe(
486496
self.function, df, function_prefix
487497
)
488498

489-
def _scale(self, points):
499+
def _scale(self, points: list[tuple[float, float]] | np.ndarray) -> np.ndarray:
490500
points = np.asarray(points, dtype=float)
491501
return (points - self.xy_mean) / self.xy_scale
492502

493-
def _unscale(self, points):
503+
def _unscale(self, points: np.ndarray) -> np.ndarray:
494504
points = np.asarray(points, dtype=float)
495505
return points * self.xy_scale + self.xy_mean
496506

497507
@property
498-
def npoints(self):
508+
def npoints(self) -> int:
499509
"""Number of evaluated points."""
500510
return len(self.data)
501511

502512
@property
503-
def vdim(self):
513+
def vdim(self) -> int:
504514
"""Length of the output of ``learner.function``.
505515
If the output is unsized (when it's a scalar)
506516
then `vdim = 1`.
@@ -516,12 +526,14 @@ def vdim(self):
516526
return self._vdim or 1
517527

518528
@property
519-
def bounds_are_done(self):
529+
def bounds_are_done(self) -> bool:
520530
return not any(
521531
(p in self.pending_points or p in self._stack) for p in self._bounds_points
522532
)
523533

524-
def interpolated_on_grid(self, n=None):
534+
def interpolated_on_grid(
535+
self, n: int = None
536+
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
525537
"""Get the interpolated data on a grid.
526538
527539
Parameters
@@ -553,7 +565,7 @@ def interpolated_on_grid(self, n=None):
553565
xs, ys = self._unscale(np.vstack([xs, ys]).T).T
554566
return xs, ys, zs
555567

556-
def _data_in_bounds(self):
568+
def _data_in_bounds(self) -> tuple[np.ndarray, np.ndarray]:
557569
if self.data:
558570
points = np.array(list(self.data.keys()))
559571
values = np.array(list(self.data.values()), dtype=float)
@@ -562,7 +574,7 @@ def _data_in_bounds(self):
562574
return points[inds], values[inds].reshape(-1, self.vdim)
563575
return np.zeros((0, 2)), np.zeros((0, self.vdim), dtype=float)
564576

565-
def _data_interp(self):
577+
def _data_interp(self) -> tuple[np.ndarray | list[tuple[float, float]], np.ndarray]:
566578
if self.pending_points:
567579
points = list(self.pending_points)
568580
if self.bounds_are_done:
@@ -575,7 +587,7 @@ def _data_interp(self):
575587
return points, values
576588
return np.zeros((0, 2)), np.zeros((0, self.vdim), dtype=float)
577589

578-
def _data_combined(self):
590+
def _data_combined(self) -> tuple[np.ndarray, np.ndarray]:
579591
points, values = self._data_in_bounds()
580592
if not self.pending_points:
581593
return points, values
@@ -584,7 +596,7 @@ def _data_combined(self):
584596
values_combined = np.vstack([values, values_interp])
585597
return points_combined, values_combined
586598

587-
def ip(self):
599+
def ip(self) -> LinearNDInterpolator:
588600
"""Deprecated, use `self.interpolator(scaled=True)`"""
589601
warnings.warn(
590602
"`learner.ip()` is deprecated, use `learner.interpolator(scaled=True)`."
@@ -593,7 +605,7 @@ def ip(self):
593605
)
594606
return self.interpolator(scaled=True)
595607

596-
def interpolator(self, *, scaled=False):
608+
def interpolator(self, *, scaled: bool = False) -> LinearNDInterpolator:
597609
"""A `scipy.interpolate.LinearNDInterpolator` instance
598610
containing the learner's data.
599611
@@ -624,7 +636,7 @@ def interpolator(self, *, scaled=False):
624636
points, values = self._data_in_bounds()
625637
return interpolate.LinearNDInterpolator(points, values)
626638

627-
def _interpolator_combined(self):
639+
def _interpolator_combined(self) -> LinearNDInterpolator:
628640
"""A `scipy.interpolate.LinearNDInterpolator` instance
629641
containing the learner's data *and* interpolated data of
630642
the `pending_points`."""
@@ -634,12 +646,12 @@ def _interpolator_combined(self):
634646
self._ip_combined = interpolate.LinearNDInterpolator(points, values)
635647
return self._ip_combined
636648

637-
def inside_bounds(self, xy):
649+
def inside_bounds(self, xy: tuple[float, float]) -> Bool:
638650
x, y = xy
639651
(xmin, xmax), (ymin, ymax) = self.bounds
640652
return xmin <= x <= xmax and ymin <= y <= ymax
641653

642-
def tell(self, point, value):
654+
def tell(self, point: tuple[float, float], value: float | Iterable[float]) -> None:
643655
point = tuple(point)
644656
self.data[point] = value
645657
if not self.inside_bounds(point):
@@ -648,15 +660,17 @@ def tell(self, point, value):
648660
self._ip = None
649661
self._stack.pop(point, None)
650662

651-
def tell_pending(self, point):
663+
def tell_pending(self, point: tuple[float, float]) -> None:
652664
point = tuple(point)
653665
if not self.inside_bounds(point):
654666
return
655667
self.pending_points.add(point)
656668
self._ip_combined = None
657669
self._stack.pop(point, None)
658670

659-
def _fill_stack(self, stack_till=1):
671+
def _fill_stack(
672+
self, stack_till: int = 1
673+
) -> tuple[list[tuple[float, float]], list[float]]:
660674
if len(self.data) + len(self.pending_points) < self.ndim + 1:
661675
raise ValueError("too few points...")
662676

@@ -695,7 +709,9 @@ def _fill_stack(self, stack_till=1):
695709

696710
return points_new, losses_new
697711

698-
def ask(self, n, tell_pending=True):
712+
def ask(
713+
self, n: int, tell_pending: bool = True
714+
) -> tuple[list[tuple[float, float] | np.ndarray], list[float]]:
699715
# Even if tell_pending is False we add the point such that _fill_stack
700716
# will return new points, later we remove these points if needed.
701717
points = list(self._stack.keys())
@@ -726,14 +742,14 @@ def ask(self, n, tell_pending=True):
726742
return points[:n], loss_improvements[:n]
727743

728744
@cache_latest
729-
def loss(self, real=True):
745+
def loss(self, real: bool = True) -> float:
730746
if not self.bounds_are_done:
731747
return np.inf
732748
ip = self.interpolator(scaled=True) if real else self._interpolator_combined()
733749
losses = self.loss_per_triangle(ip)
734750
return losses.max()
735751

736-
def remove_unfinished(self):
752+
def remove_unfinished(self) -> None:
737753
self.pending_points = set()
738754
for p in self._bounds_points:
739755
if p not in self.data:
@@ -807,10 +823,10 @@ def plot(self, n=None, tri_alpha=0):
807823

808824
return im.opts(style=im_opts) * tris.opts(style=tri_opts, **no_hover)
809825

810-
def _get_data(self):
826+
def _get_data(self) -> dict[tuple[float, float], Float | np.ndarray]:
811827
return self.data
812828

813-
def _set_data(self, data):
829+
def _set_data(self, data: dict[tuple[float, float], Float | np.ndarray]) -> None:
814830
self.data = data
815831
# Remove points from stack if they already exist
816832
for point in copy(self._stack):

adaptive/tests/test_pickling.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def balancing_learner(f, learner_type, learner_kwargs):
6262

6363
learners_pairs = [
6464
(Learner1D, dict(bounds=(-1, 1))),
65-
(Learner2D, dict(bounds=[(-1, 1), (-1, 1)])),
65+
(Learner2D, dict(bounds=((-1, 1), (-1, 1)))),
6666
(SequenceLearner, dict(sequence=list(range(100)))),
6767
(IntegratorLearner, dict(bounds=(0, 1), tol=1e-3)),
6868
(AverageLearner, dict(atol=0.1)),

adaptive/tests/test_runner.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def test_nonconforming_output(runner):
5757
def f(x):
5858
return [0]
5959

60-
runner(Learner2D(f, [(-1, 1), (-1, 1)]), trivial_goal)
60+
runner(Learner2D(f, ((-1, 1), (-1, 1))), trivial_goal)
6161

6262

6363
def test_aync_def_function():

adaptive/types.py

+1
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,4 @@
1111
Float: TypeAlias = Union[float, np.float_]
1212
Int: TypeAlias = Union[int, np.int_]
1313
Real: TypeAlias = Union[Float, Int]
14+
Bool: TypeAlias = Union[bool, np.bool_]

0 commit comments

Comments
 (0)