Add Learner.new() method that returns an empty copy of the learner #365


Merged · 1 commit · Oct 11, 2022
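For context, the new API works the same way on every learner: `learner.new()` returns an empty learner built from the same function and parameters. A minimal usage sketch (the function `f` below is a hypothetical example, not part of this PR):

```python
import adaptive

def f(x):
    # hypothetical test function, for illustration only
    return x**2

learner = adaptive.Learner1D(f, bounds=(-1, 1))
learner.tell(0.5, f(0.5))  # give the learner one data point

empty = learner.new()  # same function, bounds, and loss, but no data
assert empty.function is learner.function
assert empty.npoints == 0 and learner.npoints == 1
```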
4 changes: 4 additions & 0 deletions adaptive/learner/average_learner.py
```diff
@@ -74,6 +74,10 @@ def __init__(
         self.sum_f: Real = 0.0
         self.sum_f_sq: Real = 0.0
 
+    def new(self) -> AverageLearner:
+        """Create a copy of `~adaptive.AverageLearner` without the data."""
+        return AverageLearner(self.function, self.atol, self.rtol, self.min_npoints)
+
     @property
     def n_requested(self) -> int:
         return self.npoints + len(self.pending_points)
```
14 changes: 14 additions & 0 deletions adaptive/learner/average_learner1D.py
```diff
@@ -125,6 +125,20 @@ def __init__(
         # {xii: error[xii]/min(_distances[xi], _distances[xii], ...}
         self.rescaled_error: dict[Real, float] = decreasing_dict()
 
+    def new(self) -> AverageLearner1D:
+        """Create a copy of `~adaptive.AverageLearner1D` without the data."""
+        return AverageLearner1D(
+            self.function,
+            self.bounds,
+            self.loss_per_interval,
+            self.delta,
+            self.alpha,
+            self.neighbor_sampling,
+            self.min_samples,
+            self.max_samples,
+            self.min_error,
+        )
+
     @property
     def nsamples(self) -> int:
         """Returns the total number of samples"""
```
10 changes: 10 additions & 0 deletions adaptive/learner/balancing_learner.py
```diff
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import itertools
 from collections import defaultdict
 from collections.abc import Iterable
@@ -96,6 +98,14 @@ def __init__(self, learners, *, cdims=None, strategy="loss_improvements"):
 
         self.strategy = strategy
 
+    def new(self) -> BalancingLearner:
+        """Create a new `BalancingLearner` with the same parameters."""
+        return BalancingLearner(
+            [learner.new() for learner in self.learners],
+            cdims=self._cdims_default,
+            strategy=self.strategy,
+        )
+
     @property
     def data(self):
         data = {}
```
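Note that `BalancingLearner.new()` recurses: each child learner is copied via its own `new()`, so the result holds fresh, data-free children rather than references to the originals. A short sketch, using the same hypothetical `f` as above:

```python
import adaptive

def f(x):
    return x**2  # hypothetical, for illustration only

bl = adaptive.BalancingLearner(
    [adaptive.Learner1D(f, bounds=(0, 1)) for _ in range(3)]
)
bl2 = bl.new()
# the copy contains new child learners, not the same objects
assert all(a is not b for a, b in zip(bl.learners, bl2.learners))
```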
5 changes: 5 additions & 0 deletions adaptive/learner/base_learner.py
```diff
@@ -149,6 +149,11 @@ def _get_data(self):
     def _set_data(self):
         pass
 
+    @abc.abstractmethod
+    def new(self):
+        """Return a new learner with the same function and parameters."""
+        pass
+
     def copy_from(self, other):
         """Copy over the data from another learner.
 
```
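Because `new` is declared with `@abc.abstractmethod` on `BaseLearner`, any third-party learner subclass now has to implement it. A minimal sketch of the contract (the `MyLearner` class and its `resolution` parameter are hypothetical, and the remaining abstract methods are omitted):

```python
from adaptive.learner.base_learner import BaseLearner

class MyLearner(BaseLearner):
    """Hypothetical learner, shown only to illustrate the new() contract."""

    def __init__(self, function, resolution):
        self.function = function
        self.resolution = resolution
        self.data = {}

    def new(self):
        # Empty copy: same function and parameters, none of the data.
        return MyLearner(self.function, self.resolution)

    # the other abstract methods (ask, loss, _get_data, _set_data, ...)
    # would also need implementations before this class is instantiable
```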
4 changes: 4 additions & 0 deletions adaptive/learner/data_saver.py
```diff
@@ -45,6 +45,10 @@ def __init__(self, learner, arg_picker):
         self.function = learner.function
         self.arg_picker = arg_picker
 
+    def new(self) -> DataSaver:
+        """Return a new `DataSaver` with the same `arg_picker` and `learner`."""
+        return DataSaver(self.learner.new(), self.arg_picker)
+
     def __getattr__(self, attr):
         return getattr(self.learner, attr)
 
```
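The wrapper follows the same pattern one level up: `DataSaver.new()` rebuilds the wrapped learner through its own `new()` and reuses the `arg_picker`. For illustration (the dict-returning function `g` is hypothetical):

```python
import operator
import adaptive

def g(x):
    # hypothetical function returning a dict
    return {"y": x**2, "t": 0.0}

ds = adaptive.DataSaver(adaptive.Learner1D(g, bounds=(0, 1)),
                        arg_picker=operator.itemgetter("y"))
ds2 = ds.new()  # contains a fresh Learner1D, same arg_picker
assert ds2.arg_picker is ds.arg_picker
```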
5 changes: 5 additions & 0 deletions adaptive/learner/integrator_learner.py
```diff
@@ -1,4 +1,5 @@
 # Based on an adaptive quadrature algorithm by Pedro Gonnet
+from __future__ import annotations
 
 import sys
 from collections import defaultdict
@@ -381,6 +382,10 @@ def __init__(self, function, bounds, tol):
         self.add_ival(ival)
         self.first_ival = ival
 
+    def new(self) -> IntegratorLearner:
+        """Create a copy of `~adaptive.IntegratorLearner` without the data."""
+        return IntegratorLearner(self.function, self.bounds, self.tol)
+
     @property
     def approximating_intervals(self):
         return self.first_ival.done_leaves
```
6 changes: 5 additions & 1 deletion adaptive/learner/learner1D.py
```diff
@@ -303,11 +303,15 @@ def __init__(
         # The precision in 'x' below which we set losses to 0.
         self._dx_eps = 2 * max(np.abs(bounds)) * np.finfo(float).eps
 
-        self.bounds = list(bounds)
+        self.bounds = tuple(bounds)
         self.__missing_bounds = set(self.bounds)  # cache of missing bounds
 
         self._vdim: int | None = None
 
+    def new(self) -> Learner1D:
+        """Create a copy of `~adaptive.Learner1D` without the data."""
+        return Learner1D(self.function, self.bounds, self.loss_per_interval)
+
     @property
     def vdim(self) -> int:
         """Length of the output of ``learner.function``.
```
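Aside from adding `new()`, this diff also changes `self.bounds` from a list to a tuple, presumably so that the copy returned by `new()` cannot share (and mutate) a bounds object with the original learner.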
3 changes: 3 additions & 0 deletions adaptive/learner/learner2D.py
```diff
@@ -384,6 +384,9 @@ def __init__(self, function, bounds, loss_per_triangle=None):
 
         self.stack_size = 10
 
+    def new(self) -> Learner2D:
+        return Learner2D(self.function, self.bounds, self.loss_per_triangle)
+
     @property
     def xy_scale(self):
         xy_scale = self._xy_scale
```
4 changes: 4 additions & 0 deletions adaptive/learner/learnerND.py
```diff
@@ -376,6 +376,10 @@ def __init__(self, func, bounds, loss_per_simplex=None):
         # _pop_highest_existing_simplex
         self._simplex_queue = SortedKeyList(key=_simplex_evaluation_priority)
 
+    def new(self) -> LearnerND:
+        """Create a new learner with the same function and bounds."""
+        return LearnerND(self.function, self.bounds, self.loss_per_simplex)
+
     @property
     def npoints(self):
         """Number of evaluated points."""
```
4 changes: 4 additions & 0 deletions adaptive/learner/sequence_learner.py
```diff
@@ -77,6 +77,10 @@ def __init__(self, function, sequence):
         self.data = SortedDict()
         self.pending_points = set()
 
+    def new(self) -> SequenceLearner:
+        """Return a new `~adaptive.SequenceLearner` without the data."""
+        return SequenceLearner(self._original_function, self.sequence)
+
     def ask(self, n, tell_pending=True):
         indices = []
         points = []
```
7 changes: 7 additions & 0 deletions adaptive/learner/skopt_learner.py
```diff
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import collections
 
 import numpy as np
@@ -27,8 +29,13 @@ def __init__(self, function, **kwargs):
         self.function = function
         self.pending_points = set()
         self.data = collections.OrderedDict()
+        self._kwargs = kwargs
         super().__init__(**kwargs)
 
+    def new(self) -> SKOptLearner:
+        """Return a new `~adaptive.SKOptLearner` without the data."""
+        return SKOptLearner(self.function, **self._kwargs)
+
     def tell(self, x, y, fit=True):
         if isinstance(x, collections.abc.Iterable):
             self.pending_points.discard(tuple(x))
```
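`SKOptLearner` needs one supporting change: its extra constructor arguments are forwarded to the underlying `skopt` optimizer via `**kwargs`, so they are now stashed on the instance for `new()` to replay. A generic sketch of this pattern (all names hypothetical):

```python
class Parent:
    """Hypothetical stand-in for a base class that consumes **kwargs."""
    def __init__(self, **kwargs):
        self.config = kwargs

class Wrapper(Parent):
    """Hypothetical learner that forwards **kwargs to its parent."""
    def __init__(self, function, **kwargs):
        self.function = function
        self._kwargs = kwargs        # stash the constructor arguments...
        super().__init__(**kwargs)

    def new(self):
        # ...so an empty copy can be built with exactly the same settings
        return Wrapper(self.function, **self._kwargs)
```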
17 changes: 8 additions & 9 deletions adaptive/tests/test_learners.py
```diff
@@ -294,7 +294,7 @@ def test_adding_existing_data_is_idempotent(learner_type, f, learner_kwargs):
     """
     f = generate_random_parametrization(f)
     learner = learner_type(f, **learner_kwargs)
-    control = learner_type(f, **learner_kwargs)
+    control = learner.new()
     if learner_type in (Learner1D, AverageLearner1D):
         learner._recompute_losses_factor = 1
         control._recompute_losses_factor = 1
@@ -345,7 +345,7 @@ def test_adding_non_chosen_data(learner_type, f, learner_kwargs):
     # XXX: learner, control and bounds are not defined
     f = generate_random_parametrization(f)
     learner = learner_type(f, **learner_kwargs)
-    control = learner_type(f, **learner_kwargs)
+    control = learner.new()
 
     if learner_type is Learner2D:
         # If the stack_size is bigger then the number of points added,
@@ -395,7 +395,7 @@ def test_point_adding_order_is_irrelevant(learner_type, f, learner_kwargs):
     """
     f = generate_random_parametrization(f)
     learner = learner_type(f, **learner_kwargs)
-    control = learner_type(f, **learner_kwargs)
+    control = learner.new()
 
     if learner_type in (Learner1D, AverageLearner1D):
         learner._recompute_losses_factor = 1
@@ -581,7 +581,7 @@ def test_balancing_learner(learner_type, f, learner_kwargs):
 def test_saving(learner_type, f, learner_kwargs):
     f = generate_random_parametrization(f)
     learner = learner_type(f, **learner_kwargs)
-    control = learner_type(f, **learner_kwargs)
+    control = learner.new()
     if learner_type in (Learner1D, AverageLearner1D):
         learner._recompute_losses_factor = 1
         control._recompute_losses_factor = 1
@@ -614,7 +614,7 @@ def test_saving(learner_type, f, learner_kwargs):
 def test_saving_of_balancing_learner(learner_type, f, learner_kwargs):
     f = generate_random_parametrization(f)
     learner = BalancingLearner([learner_type(f, **learner_kwargs)])
-    control = BalancingLearner([learner_type(f, **learner_kwargs)])
+    control = learner.new()
 
     if learner_type in (Learner1D, AverageLearner1D):
         for l, c in zip(learner.learners, control.learners):
@@ -654,7 +654,7 @@ def test_saving_with_datasaver(learner_type, f, learner_kwargs):
     g = lambda x: {"y": f(x), "t": random.random()}  # noqa: E731
     arg_picker = operator.itemgetter("y")
     learner = DataSaver(learner_type(g, **learner_kwargs), arg_picker)
-    control = DataSaver(learner_type(g, **learner_kwargs), arg_picker)
+    control = learner.new()
 
     if learner_type in (Learner1D, AverageLearner1D):
         learner.learner._recompute_losses_factor = 1
@@ -742,7 +742,7 @@ def test_to_dataframe(learner_type, f, learner_kwargs):
     assert len(df) == learner.npoints
 
     # Add points from the DataFrame to a new empty learner
-    learner2 = learner_type(learner.function, **learner_kwargs)
+    learner2 = learner.new()
     learner2.load_dataframe(df, **kw)
     assert learner2.npoints == learner.npoints
 
@@ -787,8 +787,7 @@ def test_to_dataframe(learner_type, f, learner_kwargs):
     assert len(df) == data_saver.npoints
 
     # Test loading from a DataFrame into a new DataSaver
-    learner2 = learner_type(learner.function, **learner_kwargs)
-    data_saver2 = DataSaver(learner2, operator.itemgetter("result"))
+    data_saver2 = data_saver.new()
     data_saver2.load_dataframe(df, **kw)
     assert data_saver2.extra_data.keys() == data_saver.extra_data.keys()
     assert all(
```
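The net effect in the test suite: every `control` learner is now derived from the learner under test with `learner.new()` instead of repeating the constructor call by hand, so the control can no longer drift out of sync with how `learner` itself is built.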