@@ -1,17 +1,16 @@
 import collections.abc
 import itertools
 import math
-import numbers
 from copy import deepcopy
 from typing import (
     Any,
     Callable,
     Dict,
-    Iterable,
     List,
     Literal,
     Optional,
     Sequence,
+    Set,
     Tuple,
     Union,
 )
@@ -25,14 +24,42 @@
 from adaptive.learner.learnerND import volume
 from adaptive.learner.triangulation import simplex_volume_in_embedding
 from adaptive.notebook_integration import ensure_holoviews
-from adaptive.types import Float
+from adaptive.types import Float, Int, Real
 from adaptive.utils import cache_latest

-Point = Tuple[Float, Float]
+# -- types --
+
+# Commonly used types
+Interval = Union[Tuple[float, float], Tuple[float, float, int]]
+NeighborsType = Dict[float, List[Optional[float]]]
+
+# Types for loss_per_interval functions
+NoneFloat = Union[Float, None]
+NoneArray = Union[np.ndarray, None]
+XsType0 = Tuple[Float, Float]
+YsType0 = Union[Tuple[Float, Float], Tuple[np.ndarray, np.ndarray]]
+XsType1 = Tuple[NoneFloat, NoneFloat, NoneFloat, NoneFloat]
+YsType1 = Union[
+    Tuple[NoneFloat, NoneFloat, NoneFloat, NoneFloat],
+    Tuple[NoneArray, NoneArray, NoneArray, NoneArray],
+]
+XsTypeN = Tuple[NoneFloat, ...]
+YsTypeN = Union[Tuple[NoneFloat, ...], Tuple[NoneArray, ...]]
+
+
+__all__ = [
+    "uniform_loss",
+    "default_loss",
+    "abs_min_log_loss",
+    "triangle_loss",
+    "resolution_loss_function",
+    "curvature_loss_function",
+    "Learner1D",
+]


 @uses_nth_neighbors(0)
-def uniform_loss(xs: Point, ys: Any) -> Float:
+def uniform_loss(xs: XsType0, ys: YsType0) -> Float:
     """Loss function that samples the domain uniformly.

     Works with `~adaptive.Learner1D` only.
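The new `XsType0`/`YsType0` aliases describe exactly what a zeroth-neighbor loss function receives: the two x-values of an interval plus the two matching scalar or array y-values. A minimal sketch of a custom loss written against these aliases (the import path and the helper itself are illustrative, not part of this commit):

```python
import numpy as np

from adaptive.learner.learner1D import XsType0, YsType0, uses_nth_neighbors
from adaptive.types import Float


@uses_nth_neighbors(0)
def jump_loss(xs: XsType0, ys: YsType0) -> Float:
    """Toy loss: the absolute jump in y over the interval, ignoring dx."""
    if isinstance(ys[0], np.ndarray):
        return float(np.abs(ys[1] - ys[0]).max())
    return abs(ys[1] - ys[0])
```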
@@ -52,10 +79,7 @@ def uniform_loss(xs: Point, ys: Any) -> Float:


 @uses_nth_neighbors(0)
-def default_loss(
-    xs: Point,
-    ys: Union[Tuple[Iterable[Float], Iterable[Float]], Point],
-) -> float:
+def default_loss(xs: XsType0, ys: YsType0) -> Float:
     """Calculate loss on a single interval.

     Currently returns the rescaled length of the interval. If one of the
@@ -64,28 +88,23 @@ def default_loss(
     """
     dx = xs[1] - xs[0]
     if isinstance(ys[0], collections.abc.Iterable):
-        dy_vec = [abs(a - b) for a, b in zip(*ys)]
+        dy_vec = np.array([abs(a - b) for a, b in zip(*ys)])
         return np.hypot(dx, dy_vec).max()
     else:
         dy = ys[1] - ys[0]
         return np.hypot(dx, dy)


 @uses_nth_neighbors(0)
-def abs_min_log_loss(xs, ys):
+def abs_min_log_loss(xs: XsType0, ys: YsType0) -> Float:
     """Calculate loss of a single interval that prioritizes the absolute minimum."""
-    ys = [np.log(np.abs(y).min()) for y in ys]
+    ys = tuple(np.log(np.abs(y).min()) for y in ys)
     return default_loss(xs, ys)


 @uses_nth_neighbors(1)
-def triangle_loss(
-    xs: Sequence[Optional[Float]],
-    ys: Union[
-        Iterable[Optional[Float]],
-        Iterable[Union[Iterable[Float], None]],
-    ],
-) -> float:
+def triangle_loss(xs: XsType1, ys: YsType1) -> Float:
+    assert len(xs) == 4
     xs = [x for x in xs if x is not None]
     ys = [y for y in ys if y is not None]

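As the hunk above shows, `default_loss` is just the Euclidean length of the interval segment, and the vector-valued branch takes the largest per-component length. A quick illustration, assuming the function can be called directly with already-rescaled values:

```python
import numpy as np

from adaptive.learner.learner1D import default_loss

# Scalar-valued function: loss is the segment length np.hypot(dx, dy).
print(default_loss((0.0, 0.1), (0.2, 0.5)))  # np.hypot(0.1, 0.3) ~ 0.316

# Vector-valued function: loss is the largest per-component segment length.
ys = (np.array([0.2, 1.0]), np.array([0.5, 1.0]))
print(default_loss((0.0, 0.1), ys))  # also ~ 0.316; the flat component only contributes dx
```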
@@ -102,7 +121,9 @@ def triangle_loss(
    return sum(vol(pts[i : i + 3]) for i in range(N)) / N


-def resolution_loss_function(min_length=0, max_length=1):
+def resolution_loss_function(
+    min_length: Real = 0, max_length: Real = 1
+) -> Callable[[XsType0, YsType0], Float]:
     """Loss function that is similar to the `default_loss` function, but you
     can set the maximum and minimum size of an interval.

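For context, this factory is meant to be called once and its result handed to the learner. A hedged usage sketch (the function and parameter values are arbitrary):

```python
from adaptive import Learner1D
from adaptive.learner.learner1D import resolution_loss_function


def f(x):
    return x**2


# Intervals shorter than min_length get zero loss and are never refined
# further; max_length caps how coarse an interval may stay.
loss = resolution_loss_function(min_length=0.01, max_length=0.2)
learner = Learner1D(f, bounds=(-1, 1), loss_per_interval=loss)
```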
@@ -125,7 +146,7 @@ def resolution_loss_function(min_length=0, max_length=1):
     """

     @uses_nth_neighbors(0)
-    def resolution_loss(xs, ys):
+    def resolution_loss(xs: XsType0, ys: YsType0) -> Float:
         loss = uniform_loss(xs, ys)
         if loss < min_length:
             # Return zero such that this interval won't be chosen again
@@ -140,11 +161,11 @@ def resolution_loss(xs, ys):


 def curvature_loss_function(
-    area_factor: float = 1, euclid_factor: float = 0.02, horizontal_factor: float = 0.02
-) -> Callable:
+    area_factor: Real = 1, euclid_factor: Real = 0.02, horizontal_factor: Real = 0.02
+) -> Callable[[XsType1, YsType1], Float]:
     # XXX: add a doc-string
     @uses_nth_neighbors(1)
-    def curvature_loss(xs, ys):
+    def curvature_loss(xs: XsType1, ys: YsType1) -> Float:
         xs_middle = xs[1:3]
         ys_middle = ys[1:3]

@@ -160,7 +181,7 @@ def curvature_loss(xs, ys):
     return curvature_loss


-def linspace(x_left: float, x_right: float, n: int) -> List[float]:
+def linspace(x_left: Real, x_right: Real, n: Int) -> List[Float]:
     """This is equivalent to
     'np.linspace(x_left, x_right, n, endpoint=False)[1:]',
     but it is 15-30 times faster for small 'n'."""
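The docstring's equivalence claim is easy to check; the re-implementation below is a sketch derived from that claim, not the module's own body:

```python
import numpy as np


def linspace_sketch(x_left: float, x_right: float, n: int) -> list:
    """Mirror of 'np.linspace(x_left, x_right, n, endpoint=False)[1:]'."""
    step = (x_right - x_left) / n
    return [x_left + step * i for i in range(1, n)]


expected = np.linspace(0.0, 1.0, 5, endpoint=False)[1:]  # [0.2, 0.4, 0.6, 0.8]
assert np.allclose(linspace_sketch(0.0, 1.0, 5), expected)
```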
@@ -172,7 +193,7 @@ def linspace(x_left: float, x_right: float, n: int) -> List[float]:
         return [x_left + step * i for i in range(1, n)]


-def _get_neighbors_from_list(xs: np.ndarray) -> SortedDict:
+def _get_neighbors_from_array(xs: np.ndarray) -> NeighborsType:
     xs = np.sort(xs)
     xs_left = np.roll(xs, 1).tolist()
     xs_right = np.roll(xs, -1).tolist()
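The renamed helper builds the structure the new `NeighborsType` alias describes: a `SortedDict` mapping every known x to `[x_left, x_right]`. A standalone sketch of that shape (endpoint handling is assumed here, since only the `np.roll` part is visible in this hunk):

```python
from typing import Dict, List, Optional

from sortedcontainers import SortedDict


def neighbors_sketch(xs) -> Dict[float, List[Optional[float]]]:
    xs = sorted(xs)
    d = SortedDict()
    for i, x in enumerate(xs):
        left = xs[i - 1] if i > 0 else None
        right = xs[i + 1] if i < len(xs) - 1 else None
        d[x] = [left, right]
    return d


print(neighbors_sketch([0.0, 0.5, 1.0]))
# SortedDict({0.0: [None, 0.5], 0.5: [0.0, 1.0], 1.0: [0.5, None]})
```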
@@ -182,7 +203,9 @@ def _get_neighbors_from_list(xs: np.ndarray) -> SortedDict:
     return SortedDict(neighbors)


-def _get_intervals(x: float, neighbors: SortedDict, nth_neighbors: int) -> Any:
+def _get_intervals(
+    x: float, neighbors: NeighborsType, nth_neighbors: int
+) -> List[Tuple[float, float]]:
     nn = nth_neighbors
     i = neighbors.index(x)
     start = max(0, i - nn - 1)
@@ -237,10 +260,10 @@ class Learner1D(BaseLearner):

     def __init__(
         self,
-        function: Callable,
-        bounds: Tuple[float, float],
-        loss_per_interval: Optional[Callable] = None,
-    ) -> None:
+        function: Callable[[Real], Union[Float, np.ndarray]],
+        bounds: Tuple[Real, Real],
+        loss_per_interval: Optional[Callable[[XsTypeN, YsTypeN], Float]] = None,
+    ):
         self.function = function  # type: ignore

         if hasattr(loss_per_interval, "nth_neighbors"):
@@ -255,13 +278,13 @@ def __init__(
         # the learners behavior in the tests.
         self._recompute_losses_factor = 2

-        self.data = {}
-        self.pending_points = set()
+        self.data: Dict[Real, Real] = {}
+        self.pending_points: Set[Real] = set()

         # A dict {x_n: [x_{n-1}, x_{n+1}]} for quick checking of local
         # properties.
-        self.neighbors = SortedDict()
-        self.neighbors_combined = SortedDict()
+        self.neighbors: NeighborsType = SortedDict()
+        self.neighbors_combined: NeighborsType = SortedDict()

         # Bounding box [[minx, maxx], [miny, maxy]].
         self._bbox = [list(bounds), [np.inf, -np.inf]]
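The `function` annotation introduced above, `Callable[[Real], Union[Float, np.ndarray]]`, makes explicit that a learner may wrap either a scalar-valued or a vector-valued function. A small illustrative sketch (the example functions are made up):

```python
import numpy as np

from adaptive import Learner1D


def scalar_f(x):
    return np.exp(-(x**2))


def vector_f(x):
    return np.array([np.sin(x), np.cos(x)])


scalar_learner = Learner1D(scalar_f, bounds=(-2, 2))
vector_learner = Learner1D(vector_f, bounds=(-2, 2))
```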
@@ -319,14 +342,14 @@ def loss(self, real: bool = True) -> float:
         max_interval, max_loss = losses.peekitem(0)
         return max_loss

-    def _scale_x(self, x: Optional[float]) -> Optional[float]:
+    def _scale_x(self, x: Optional[Float]) -> Optional[Float]:
         if x is None:
             return None
         return x / self._scale[0]

     def _scale_y(
-        self, y: Optional[Union[Float, np.ndarray]]
-    ) -> Optional[Union[Float, np.ndarray]]:
+        self, y: Union[Float, np.ndarray, None]
+    ) -> Union[Float, np.ndarray, None]:
         if y is None:
             return None
         y_scale = self._scale[1] or 1
@@ -418,7 +441,7 @@ def _update_losses(self, x: float, real: bool = True) -> None:
             self.losses_combined[x, b] = float("inf")

     @staticmethod
-    def _find_neighbors(x: float, neighbors: SortedDict) -> Any:
+    def _find_neighbors(x: float, neighbors: NeighborsType) -> Any:
         if x in neighbors:
             return neighbors[x]
         pos = neighbors.bisect_left(x)
@@ -427,7 +450,7 @@ def _find_neighbors(x: float, neighbors: SortedDict) -> Any:
         x_right = keys[pos] if pos != len(neighbors) else None
         return x_left, x_right

-    def _update_neighbors(self, x: float, neighbors: SortedDict) -> None:
+    def _update_neighbors(self, x: float, neighbors: NeighborsType) -> None:
         if x not in neighbors:  # The point is new
             x_left, x_right = self._find_neighbors(x, neighbors)
             neighbors[x] = [x_left, x_right]
@@ -461,9 +484,7 @@ def _update_scale(self, x: float, y: Union[Float, np.ndarray]) -> None:
             self._bbox[1][1] = max(self._bbox[1][1], y)
             self._scale[1] = self._bbox[1][1] - self._bbox[1][0]

-    def tell(
-        self, x: float, y: Union[Float, Sequence[numbers.Number], np.ndarray]
-    ) -> None:
+    def tell(self, x: float, y: Union[Float, Sequence[Float], np.ndarray]) -> None:
         if x in self.data:
             # The point is already evaluated before
             return
@@ -506,7 +527,17 @@ def tell_pending(self, x: float) -> None:
         self._update_neighbors(x, self.neighbors_combined)
         self._update_losses(x, real=False)

-    def tell_many(self, xs: Sequence[float], ys: Sequence[Any], *, force=False) -> None:
+    def tell_many(
+        self,
+        xs: Sequence[Float],
+        ys: Union[
+            Sequence[Float],
+            Sequence[Sequence[Float]],
+            Sequence[np.ndarray],
+        ],
+        *,
+        force: bool = False
+    ) -> None:
         if not force and not (len(xs) > 0.5 * len(self.data) and len(xs) > 2):
             # Only run this more efficient method if there are
             # at least 2 points and the amount of points added are
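The widened `ys` annotation matches how `tell_many` is typically used: handing the learner a batch of precomputed values in one call. A short usage sketch (the function is arbitrary):

```python
import numpy as np

from adaptive import Learner1D


def f(x):
    return x**3


learner = Learner1D(f, bounds=(-1, 1))
xs = np.linspace(-1, 1, 11)
learner.tell_many(xs, [f(x) for x in xs])
```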
@@ -526,8 +557,8 @@ def tell_many(self, xs: Sequence[float], ys: Sequence[Any], *, force=False) -> N
         points_combined = np.hstack([points_pending, points])

         # Generate neighbors
-        self.neighbors = _get_neighbors_from_list(points)
-        self.neighbors_combined = _get_neighbors_from_list(points_combined)
+        self.neighbors = _get_neighbors_from_array(points)
+        self.neighbors_combined = _get_neighbors_from_array(points_combined)

         # Update scale
         self._bbox[0] = [points_combined.min(), points_combined.max()]
@@ -574,7 +605,7 @@ def tell_many(self, xs: Sequence[float], ys: Sequence[Any], *, force=False) -> N
                 # have an inf loss.
                 self._update_interpolated_loss_in_interval(*ival)

-    def ask(self, n: int, tell_pending: bool = True) -> Any:
+    def ask(self, n: int, tell_pending: bool = True) -> Tuple[List[float], List[float]]:
         """Return 'n' points that are expected to maximally reduce the loss."""
         points, loss_improvements = self._ask_points_without_adding(n)

@@ -584,7 +615,7 @@ def ask(self, n: int, tell_pending: bool = True) -> Any:

         return points, loss_improvements

-    def _ask_points_without_adding(self, n: int) -> Any:
+    def _ask_points_without_adding(self, n: int) -> Tuple[List[float], List[float]]:
         """Return 'n' points that are expected to maximally reduce the loss.
         Without altering the state of the learner"""
         # Find out how to divide the n points over the intervals
@@ -648,7 +679,7 @@ def _ask_points_without_adding(self, n: int) -> Any:
             quals[(*xs, n + 1)] = loss_qual * n / (n + 1)

         points = list(
-            itertools.chain.from_iterable(linspace(a, b, n) for ((a, b), n) in quals)
+            itertools.chain.from_iterable(linspace(*ival, n) for (*ival, n) in quals)
         )

         loss_improvements = list(
@@ -663,7 +694,9 @@ def _ask_points_without_adding(self, n: int) -> Any:

         return points, loss_improvements

-    def _loss(self, mapping: ItemSortedDict, ival: Any) -> Any:
+    def _loss(
+        self, mapping: Dict[Interval, float], ival: Interval
+    ) -> Tuple[float, Interval]:
         loss = mapping[ival]
         return finite_loss(ival, loss, self._scale[0])

@@ -734,7 +767,7 @@ def __setstate__(self, state):
         self.losses_combined.update(losses_combined)


-def loss_manager(x_scale: float) -> ItemSortedDict:
+def loss_manager(x_scale: float) -> Dict[Interval, float]:
     def sort_key(ival, loss):
         loss, ival = finite_loss(ival, loss, x_scale)
         return -loss, ival
@@ -743,8 +776,8 @@ def sort_key(ival, loss):
     return sorted_dict


-def finite_loss(ival: Any, loss: float, x_scale: float) -> Any:
-    """Get the socalled finite_loss of an interval in order to be able to
+def finite_loss(ival: Interval, loss: float, x_scale: float) -> Tuple[float, Interval]:
+    """Get the so-called finite_loss of an interval in order to be able to
     sort intervals that have infinite loss."""
     # If the loss is infinite we return the
     # distance between the two points.