Skip to content

Commit 0464abe

Browse files
authored
Add Learner.new() method that returns the same but empty learner (#365)
1 parent 914495c commit 0464abe

12 files changed

+73
-10
lines changed

adaptive/learner/average_learner.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,10 @@ def __init__(
7474
self.sum_f: Real = 0.0
7575
self.sum_f_sq: Real = 0.0
7676

77+
def new(self) -> AverageLearner:
78+
"""Create a copy of `~adaptive.AverageLearner` without the data."""
79+
return AverageLearner(self.function, self.atol, self.rtol, self.min_npoints)
80+
7781
@property
7882
def n_requested(self) -> int:
7983
return self.npoints + len(self.pending_points)

adaptive/learner/average_learner1D.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,20 @@ def __init__(
125125
# {xii: error[xii]/min(_distances[xi], _distances[xii], ...}
126126
self.rescaled_error: dict[Real, float] = decreasing_dict()
127127

128+
def new(self) -> AverageLearner1D:
129+
"""Create a copy of `~adaptive.AverageLearner1D` without the data."""
130+
return AverageLearner1D(
131+
self.function,
132+
self.bounds,
133+
self.loss_per_interval,
134+
self.delta,
135+
self.alpha,
136+
self.neighbor_sampling,
137+
self.min_samples,
138+
self.max_samples,
139+
self.min_error,
140+
)
141+
128142
@property
129143
def nsamples(self) -> int:
130144
"""Returns the total number of samples"""

adaptive/learner/balancing_learner.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import itertools
24
from collections import defaultdict
35
from collections.abc import Iterable
@@ -96,6 +98,14 @@ def __init__(self, learners, *, cdims=None, strategy="loss_improvements"):
9698

9799
self.strategy = strategy
98100

101+
def new(self) -> BalancingLearner:
102+
"""Create a new `BalancingLearner` with the same parameters."""
103+
return BalancingLearner(
104+
[learner.new() for learner in self.learners],
105+
cdims=self._cdims_default,
106+
strategy=self.strategy,
107+
)
108+
99109
@property
100110
def data(self):
101111
data = {}

adaptive/learner/base_learner.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,11 @@ def _get_data(self):
149149
def _set_data(self):
150150
pass
151151

152+
@abc.abstractmethod
153+
def new(self):
154+
"""Return a new learner with the same function and parameters."""
155+
pass
156+
152157
def copy_from(self, other):
153158
"""Copy over the data from another learner.
154159

adaptive/learner/data_saver.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,10 @@ def __init__(self, learner, arg_picker):
4545
self.function = learner.function
4646
self.arg_picker = arg_picker
4747

48+
def new(self) -> DataSaver:
49+
"""Return a new `DataSaver` with the same `arg_picker` and `learner`."""
50+
return DataSaver(self.learner.new(), self.arg_picker)
51+
4852
def __getattr__(self, attr):
4953
return getattr(self.learner, attr)
5054

adaptive/learner/integrator_learner.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# Based on an adaptive quadrature algorithm by Pedro Gonnet
2+
from __future__ import annotations
23

34
import sys
45
from collections import defaultdict
@@ -381,6 +382,10 @@ def __init__(self, function, bounds, tol):
381382
self.add_ival(ival)
382383
self.first_ival = ival
383384

385+
def new(self) -> IntegratorLearner:
386+
"""Create a copy of `~adaptive.Learner2D` without the data."""
387+
return IntegratorLearner(self.function, self.bounds, self.tol)
388+
384389
@property
385390
def approximating_intervals(self):
386391
return self.first_ival.done_leaves

adaptive/learner/learner1D.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -303,11 +303,15 @@ def __init__(
303303
# The precision in 'x' below which we set losses to 0.
304304
self._dx_eps = 2 * max(np.abs(bounds)) * np.finfo(float).eps
305305

306-
self.bounds = list(bounds)
306+
self.bounds = tuple(bounds)
307307
self.__missing_bounds = set(self.bounds) # cache of missing bounds
308308

309309
self._vdim: int | None = None
310310

311+
def new(self) -> Learner1D:
312+
"""Create a copy of `~adaptive.Learner1D` without the data."""
313+
return Learner1D(self.function, self.bounds, self.loss_per_interval)
314+
311315
@property
312316
def vdim(self) -> int:
313317
"""Length of the output of ``learner.function``.

adaptive/learner/learner2D.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,9 @@ def __init__(self, function, bounds, loss_per_triangle=None):
384384

385385
self.stack_size = 10
386386

387+
def new(self) -> Learner2D:
388+
return Learner2D(self.function, self.bounds, self.loss_per_triangle)
389+
387390
@property
388391
def xy_scale(self):
389392
xy_scale = self._xy_scale

adaptive/learner/learnerND.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,10 @@ def __init__(self, func, bounds, loss_per_simplex=None):
376376
# _pop_highest_existing_simplex
377377
self._simplex_queue = SortedKeyList(key=_simplex_evaluation_priority)
378378

379+
def new(self) -> LearnerND:
380+
"""Create a new learner with the same function and bounds."""
381+
return LearnerND(self.function, self.bounds, self.loss_per_simplex)
382+
379383
@property
380384
def npoints(self):
381385
"""Number of evaluated points."""

adaptive/learner/sequence_learner.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,10 @@ def __init__(self, function, sequence):
7777
self.data = SortedDict()
7878
self.pending_points = set()
7979

80+
def new(self) -> SequenceLearner:
81+
"""Return a new `~adaptive.SequenceLearner` without the data."""
82+
return SequenceLearner(self._original_function, self.sequence)
83+
8084
def ask(self, n, tell_pending=True):
8185
indices = []
8286
points = []

0 commit comments

Comments
 (0)