5
5
from collections import OrderedDict
6
6
from copy import copy
7
7
from math import sqrt
8
+ from typing import Callable , Iterable
8
9
9
10
import cloudpickle
10
11
import numpy as np
11
12
from scipy import interpolate
13
+ from scipy .interpolate .interpnd import LinearNDInterpolator
12
14
13
15
from adaptive .learner .base_learner import BaseLearner
14
16
from adaptive .learner .triangulation import simplex_volume_in_embedding
15
17
from adaptive .notebook_integration import ensure_holoviews
18
+ from adaptive .types import Bool , Float , Real
16
19
from adaptive .utils import (
17
20
assign_defaults ,
18
21
cache_latest ,
30
33
# Learner2D and helper functions.
31
34
32
35
33
- def deviations (ip ) :
36
+ def deviations (ip : LinearNDInterpolator ) -> list [ np . ndarray ] :
34
37
"""Returns the deviation of the linear estimate.
35
38
36
39
Is useful when defining custom loss functions.
@@ -68,7 +71,7 @@ def deviation(p, v, g):
68
71
return devs
69
72
70
73
71
- def areas (ip ) :
74
+ def areas (ip : LinearNDInterpolator ) -> np . ndarray :
72
75
"""Returns the area per triangle of the triangulation inside
73
76
a `LinearNDInterpolator` instance.
74
77
@@ -89,7 +92,7 @@ def areas(ip):
89
92
return areas
90
93
91
94
92
- def uniform_loss (ip ) :
95
+ def uniform_loss (ip : LinearNDInterpolator ) -> np . ndarray :
93
96
"""Loss function that samples the domain uniformly.
94
97
95
98
Works with `~adaptive.Learner2D` only.
@@ -120,7 +123,9 @@ def uniform_loss(ip):
120
123
return np .sqrt (areas (ip ))
121
124
122
125
123
- def resolution_loss_function (min_distance = 0 , max_distance = 1 ):
126
+ def resolution_loss_function (
127
+ min_distance : float = 0 , max_distance : float = 1
128
+ ) -> Callable [[LinearNDInterpolator ], np .ndarray ]:
124
129
"""Loss function that is similar to the `default_loss` function, but you
125
130
can set the maximimum and minimum size of a triangle.
126
131
@@ -159,7 +164,7 @@ def resolution_loss(ip):
159
164
return resolution_loss
160
165
161
166
162
- def minimize_triangle_surface_loss (ip ) :
167
+ def minimize_triangle_surface_loss (ip : LinearNDInterpolator ) -> np . ndarray :
163
168
"""Loss function that is similar to the distance loss function in the
164
169
`~adaptive.Learner1D`. The loss is the area spanned by the 3D
165
170
vectors of the vertices.
@@ -205,7 +210,7 @@ def _get_vectors(points):
205
210
return np .linalg .norm (np .cross (a , b ) / 2 , axis = 1 )
206
211
207
212
208
- def default_loss (ip ) :
213
+ def default_loss (ip : LinearNDInterpolator ) -> np . ndarray :
209
214
"""Loss function that combines `deviations` and `areas` of the triangles.
210
215
211
216
Works with `~adaptive.Learner2D` only.
@@ -225,7 +230,7 @@ def default_loss(ip):
225
230
return losses
226
231
227
232
228
- def choose_point_in_triangle (triangle , max_badness ) :
233
+ def choose_point_in_triangle (triangle : np . ndarray , max_badness : int ) -> np . ndarray :
229
234
"""Choose a new point in inside a triangle.
230
235
231
236
If the ratio of the longest edge of the triangle squared
@@ -364,7 +369,12 @@ class Learner2D(BaseLearner):
364
369
over each triangle.
365
370
"""
366
371
367
- def __init__ (self , function , bounds , loss_per_triangle = None ):
372
+ def __init__ (
373
+ self ,
374
+ function : Callable ,
375
+ bounds : tuple [tuple [Real , Real ], tuple [Real , Real ]],
376
+ loss_per_triangle : Callable | None = None ,
377
+ ) -> None :
368
378
self .ndim = len (bounds )
369
379
self ._vdim = None
370
380
self .loss_per_triangle = loss_per_triangle or default_loss
@@ -379,7 +389,7 @@ def __init__(self, function, bounds, loss_per_triangle=None):
379
389
380
390
self ._bounds_points = list (itertools .product (* bounds ))
381
391
self ._stack .update ({p : np .inf for p in self ._bounds_points })
382
- self .function = function
392
+ self .function = function # type: ignore
383
393
self ._ip = self ._ip_combined = None
384
394
385
395
self .stack_size = 10
@@ -388,7 +398,7 @@ def new(self) -> Learner2D:
388
398
return Learner2D (self .function , self .bounds , self .loss_per_triangle )
389
399
390
400
@property
391
- def xy_scale (self ):
401
+ def xy_scale (self ) -> np . ndarray :
392
402
xy_scale = self ._xy_scale
393
403
if self .aspect_ratio == 1 :
394
404
return xy_scale
@@ -486,21 +496,21 @@ def load_dataframe(
486
496
self .function , df , function_prefix
487
497
)
488
498
489
- def _scale (self , points ) :
499
+ def _scale (self , points : list [ tuple [ float , float ]] | np . ndarray ) -> np . ndarray :
490
500
points = np .asarray (points , dtype = float )
491
501
return (points - self .xy_mean ) / self .xy_scale
492
502
493
- def _unscale (self , points ) :
503
+ def _unscale (self , points : np . ndarray ) -> np . ndarray :
494
504
points = np .asarray (points , dtype = float )
495
505
return points * self .xy_scale + self .xy_mean
496
506
497
507
@property
498
- def npoints (self ):
508
+ def npoints (self ) -> int :
499
509
"""Number of evaluated points."""
500
510
return len (self .data )
501
511
502
512
@property
503
- def vdim (self ):
513
+ def vdim (self ) -> int :
504
514
"""Length of the output of ``learner.function``.
505
515
If the output is unsized (when it's a scalar)
506
516
then `vdim = 1`.
@@ -516,12 +526,14 @@ def vdim(self):
516
526
return self ._vdim or 1
517
527
518
528
@property
519
- def bounds_are_done (self ):
529
+ def bounds_are_done (self ) -> bool :
520
530
return not any (
521
531
(p in self .pending_points or p in self ._stack ) for p in self ._bounds_points
522
532
)
523
533
524
- def interpolated_on_grid (self , n = None ):
534
+ def interpolated_on_grid (
535
+ self , n : int = None
536
+ ) -> tuple [np .ndarray , np .ndarray , np .ndarray ]:
525
537
"""Get the interpolated data on a grid.
526
538
527
539
Parameters
@@ -553,7 +565,7 @@ def interpolated_on_grid(self, n=None):
553
565
xs , ys = self ._unscale (np .vstack ([xs , ys ]).T ).T
554
566
return xs , ys , zs
555
567
556
- def _data_in_bounds (self ):
568
+ def _data_in_bounds (self ) -> tuple [ np . ndarray , np . ndarray ] :
557
569
if self .data :
558
570
points = np .array (list (self .data .keys ()))
559
571
values = np .array (list (self .data .values ()), dtype = float )
@@ -562,7 +574,7 @@ def _data_in_bounds(self):
562
574
return points [inds ], values [inds ].reshape (- 1 , self .vdim )
563
575
return np .zeros ((0 , 2 )), np .zeros ((0 , self .vdim ), dtype = float )
564
576
565
- def _data_interp (self ):
577
+ def _data_interp (self ) -> tuple [ np . ndarray | list [ tuple [ float , float ]], np . ndarray ] :
566
578
if self .pending_points :
567
579
points = list (self .pending_points )
568
580
if self .bounds_are_done :
@@ -575,7 +587,7 @@ def _data_interp(self):
575
587
return points , values
576
588
return np .zeros ((0 , 2 )), np .zeros ((0 , self .vdim ), dtype = float )
577
589
578
- def _data_combined (self ):
590
+ def _data_combined (self ) -> tuple [ np . ndarray , np . ndarray ] :
579
591
points , values = self ._data_in_bounds ()
580
592
if not self .pending_points :
581
593
return points , values
@@ -584,7 +596,7 @@ def _data_combined(self):
584
596
values_combined = np .vstack ([values , values_interp ])
585
597
return points_combined , values_combined
586
598
587
- def ip (self ):
599
+ def ip (self ) -> LinearNDInterpolator :
588
600
"""Deprecated, use `self.interpolator(scaled=True)`"""
589
601
warnings .warn (
590
602
"`learner.ip()` is deprecated, use `learner.interpolator(scaled=True)`."
@@ -593,7 +605,7 @@ def ip(self):
593
605
)
594
606
return self .interpolator (scaled = True )
595
607
596
- def interpolator (self , * , scaled = False ):
608
+ def interpolator (self , * , scaled : bool = False ) -> LinearNDInterpolator :
597
609
"""A `scipy.interpolate.LinearNDInterpolator` instance
598
610
containing the learner's data.
599
611
@@ -624,7 +636,7 @@ def interpolator(self, *, scaled=False):
624
636
points , values = self ._data_in_bounds ()
625
637
return interpolate .LinearNDInterpolator (points , values )
626
638
627
- def _interpolator_combined (self ):
639
+ def _interpolator_combined (self ) -> LinearNDInterpolator :
628
640
"""A `scipy.interpolate.LinearNDInterpolator` instance
629
641
containing the learner's data *and* interpolated data of
630
642
the `pending_points`."""
@@ -634,12 +646,12 @@ def _interpolator_combined(self):
634
646
self ._ip_combined = interpolate .LinearNDInterpolator (points , values )
635
647
return self ._ip_combined
636
648
637
- def inside_bounds (self , xy ) :
649
+ def inside_bounds (self , xy : tuple [ float , float ]) -> Bool :
638
650
x , y = xy
639
651
(xmin , xmax ), (ymin , ymax ) = self .bounds
640
652
return xmin <= x <= xmax and ymin <= y <= ymax
641
653
642
- def tell (self , point , value ) :
654
+ def tell (self , point : tuple [ float , float ], value : float | Iterable [ float ]) -> None :
643
655
point = tuple (point )
644
656
self .data [point ] = value
645
657
if not self .inside_bounds (point ):
@@ -648,15 +660,17 @@ def tell(self, point, value):
648
660
self ._ip = None
649
661
self ._stack .pop (point , None )
650
662
651
- def tell_pending (self , point ) :
663
+ def tell_pending (self , point : tuple [ float , float ]) -> None :
652
664
point = tuple (point )
653
665
if not self .inside_bounds (point ):
654
666
return
655
667
self .pending_points .add (point )
656
668
self ._ip_combined = None
657
669
self ._stack .pop (point , None )
658
670
659
- def _fill_stack (self , stack_till = 1 ):
671
+ def _fill_stack (
672
+ self , stack_till : int = 1
673
+ ) -> tuple [list [tuple [float , float ]], list [float ]]:
660
674
if len (self .data ) + len (self .pending_points ) < self .ndim + 1 :
661
675
raise ValueError ("too few points..." )
662
676
@@ -695,7 +709,9 @@ def _fill_stack(self, stack_till=1):
695
709
696
710
return points_new , losses_new
697
711
698
- def ask (self , n , tell_pending = True ):
712
+ def ask (
713
+ self , n : int , tell_pending : bool = True
714
+ ) -> tuple [list [tuple [float , float ] | np .ndarray ], list [float ]]:
699
715
# Even if tell_pending is False we add the point such that _fill_stack
700
716
# will return new points, later we remove these points if needed.
701
717
points = list (self ._stack .keys ())
@@ -726,14 +742,14 @@ def ask(self, n, tell_pending=True):
726
742
return points [:n ], loss_improvements [:n ]
727
743
728
744
@cache_latest
729
- def loss (self , real = True ):
745
+ def loss (self , real : bool = True ) -> float :
730
746
if not self .bounds_are_done :
731
747
return np .inf
732
748
ip = self .interpolator (scaled = True ) if real else self ._interpolator_combined ()
733
749
losses = self .loss_per_triangle (ip )
734
750
return losses .max ()
735
751
736
- def remove_unfinished (self ):
752
+ def remove_unfinished (self ) -> None :
737
753
self .pending_points = set ()
738
754
for p in self ._bounds_points :
739
755
if p not in self .data :
@@ -807,10 +823,10 @@ def plot(self, n=None, tri_alpha=0):
807
823
808
824
return im .opts (style = im_opts ) * tris .opts (style = tri_opts , ** no_hover )
809
825
810
- def _get_data (self ):
826
+ def _get_data (self ) -> dict [ tuple [ float , float ], Float | np . ndarray ] :
811
827
return self .data
812
828
813
- def _set_data (self , data ) :
829
+ def _set_data (self , data : dict [ tuple [ float , float ], Float | np . ndarray ]) -> None :
814
830
self .data = data
815
831
# Remove points from stack if they already exist
816
832
for point in copy (self ._stack ):
0 commit comments