Skip to content

Commit 82ed0a4

Browse files
authored
Remove _RequireAttrsABCMeta metaclass and replace with simple check (#409)
* Remove _RequireAttrsABCMeta metaclass and replace with simple check * make BaseLearner a ABC
1 parent b16f0e5 commit 82ed0a4

11 files changed

+25
-20
lines changed

adaptive/learner/average_learner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ def __init__(
7575
self.min_npoints = max(min_npoints, 2)
7676
self.sum_f: Real = 0.0
7777
self.sum_f_sq: Real = 0.0
78+
self._check_required_attributes()
7879

7980
def new(self) -> AverageLearner:
8081
"""Create a copy of `~adaptive.AverageLearner` without the data."""

adaptive/learner/average_learner1D.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ def __init__(
125125
self._distances: dict[Real, float] = decreasing_dict()
126126
# {xii: error[xii]/min(_distances[xi], _distances[xii], ...}
127127
self.rescaled_error: dict[Real, float] = decreasing_dict()
128+
self._check_required_attributes()
128129

129130
def new(self) -> AverageLearner1D:
130131
"""Create a copy of `~adaptive.AverageLearner1D` without the data."""

adaptive/learner/balancing_learner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ def __init__(
118118
)
119119

120120
self.strategy: STRATEGY_TYPE = strategy
121+
self._check_required_attributes()
121122

122123
def new(self) -> BalancingLearner:
123124
"""Create a new `BalancingLearner` with the same parameters."""

adaptive/learner/base_learner.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import cloudpickle
55

6-
from adaptive.utils import _RequireAttrsABCMeta, load, save
6+
from adaptive.utils import load, save
77

88

99
def uses_nth_neighbors(n: int):
@@ -60,7 +60,7 @@ def _wrapped(loss_per_interval):
6060
return _wrapped
6161

6262

63-
class BaseLearner(metaclass=_RequireAttrsABCMeta):
63+
class BaseLearner(abc.ABC):
6464
"""Base class for algorithms for learning a function 'f: X → Y'.
6565
6666
Attributes
@@ -198,3 +198,16 @@ def __getstate__(self):
198198

199199
def __setstate__(self, state):
200200
self.__dict__ = cloudpickle.loads(state)
201+
202+
def _check_required_attributes(self):
203+
for name, type_ in self.__annotations__.items():
204+
try:
205+
x = getattr(self, name)
206+
except AttributeError:
207+
raise AttributeError(
208+
f"Required attribute {name} not set in __init__."
209+
) from None
210+
else:
211+
if not isinstance(x, type_):
212+
msg = f"The attribute '{name}' should be of type {type_}, not {type(x)}."
213+
raise TypeError(msg)

adaptive/learner/data_saver.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def __init__(self, learner: BaseLearner, arg_picker: Callable) -> None:
4545
self.extra_data = OrderedDict()
4646
self.function = learner.function
4747
self.arg_picker = arg_picker
48+
self._check_required_attributes()
4849

4950
def new(self) -> DataSaver:
5051
"""Return a new `DataSaver` with the same `arg_picker` and `learner`."""

adaptive/learner/integrator_learner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,7 @@ def __init__(self, function: Callable, bounds: tuple[int, int], tol: float) -> N
389389
ival = _Interval.make_first(*self.bounds)
390390
self.add_ival(ival)
391391
self.first_ival = ival
392+
self._check_required_attributes()
392393

393394
def new(self) -> IntegratorLearner:
394395
"""Create a copy of `~adaptive.Learner2D` without the data."""

adaptive/learner/learner1D.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,7 @@ def __init__(
315315
self.__missing_bounds = set(self.bounds) # cache of missing bounds
316316

317317
self._vdim: int | None = None
318+
self._check_required_attributes()
318319

319320
def new(self) -> Learner1D:
320321
"""Create a copy of `~adaptive.Learner1D` without the data."""

adaptive/learner/learner2D.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,7 @@ def __init__(
393393
self._ip = self._ip_combined = None
394394

395395
self.stack_size = 10
396+
self._check_required_attributes()
396397

397398
def new(self) -> Learner2D:
398399
return Learner2D(self.function, self.bounds, self.loss_per_triangle)

adaptive/learner/learnerND.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,8 @@ def __init__(self, func, bounds, loss_per_simplex=None):
376376
# _pop_highest_existing_simplex
377377
self._simplex_queue = SortedKeyList(key=_simplex_evaluation_priority)
378378

379+
self._check_required_attributes()
380+
379381
def new(self) -> LearnerND:
380382
"""Create a new learner with the same function and bounds."""
381383
return LearnerND(self.function, self.bounds, self.loss_per_simplex)

adaptive/learner/sequence_learner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ def __init__(self, function, sequence):
9292
self.sequence = copy(sequence)
9393
self.data = SortedDict()
9494
self.pending_points = set()
95+
self._check_required_attributes()
9596

9697
def new(self) -> SequenceLearner:
9798
"""Return a new `~adaptive.SequenceLearner` without the data."""

0 commit comments

Comments
 (0)