Skip to content

Remove _RequireAttrsABCMeta metaclass and replace with simple check #409

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions adaptive/learner/average_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def __init__(
self.min_npoints = max(min_npoints, 2)
self.sum_f: Real = 0.0
self.sum_f_sq: Real = 0.0
self._check_required_attributes()

def new(self) -> AverageLearner:
"""Create a copy of `~adaptive.AverageLearner` without the data."""
Expand Down
1 change: 1 addition & 0 deletions adaptive/learner/average_learner1D.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def __init__(
self._distances: dict[Real, float] = decreasing_dict()
# {xii: error[xii]/min(_distances[xi], _distances[xii], ...}
self.rescaled_error: dict[Real, float] = decreasing_dict()
self._check_required_attributes()

def new(self) -> AverageLearner1D:
"""Create a copy of `~adaptive.AverageLearner1D` without the data."""
Expand Down
1 change: 1 addition & 0 deletions adaptive/learner/balancing_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def __init__(
)

self.strategy: STRATEGY_TYPE = strategy
self._check_required_attributes()

def new(self) -> BalancingLearner:
"""Create a new `BalancingLearner` with the same parameters."""
Expand Down
17 changes: 15 additions & 2 deletions adaptive/learner/base_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import cloudpickle

from adaptive.utils import _RequireAttrsABCMeta, load, save
from adaptive.utils import load, save


def uses_nth_neighbors(n: int):
Expand Down Expand Up @@ -60,7 +60,7 @@ def _wrapped(loss_per_interval):
return _wrapped


class BaseLearner(metaclass=_RequireAttrsABCMeta):
class BaseLearner(abc.ABC):
"""Base class for algorithms for learning a function 'f: X → Y'.

Attributes
Expand Down Expand Up @@ -198,3 +198,16 @@ def __getstate__(self):

def __setstate__(self, state):
self.__dict__ = cloudpickle.loads(state)

def _check_required_attributes(self):
for name, type_ in self.__annotations__.items():
try:
x = getattr(self, name)
except AttributeError:
raise AttributeError(
f"Required attribute {name} not set in __init__."
) from None
else:
if not isinstance(x, type_):
msg = f"The attribute '{name}' should be of type {type_}, not {type(x)}."
raise TypeError(msg)
1 change: 1 addition & 0 deletions adaptive/learner/data_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __init__(self, learner: BaseLearner, arg_picker: Callable) -> None:
self.extra_data = OrderedDict()
self.function = learner.function
self.arg_picker = arg_picker
self._check_required_attributes()

def new(self) -> DataSaver:
"""Return a new `DataSaver` with the same `arg_picker` and `learner`."""
Expand Down
1 change: 1 addition & 0 deletions adaptive/learner/integrator_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,7 @@ def __init__(self, function: Callable, bounds: tuple[int, int], tol: float) -> N
ival = _Interval.make_first(*self.bounds)
self.add_ival(ival)
self.first_ival = ival
self._check_required_attributes()

def new(self) -> IntegratorLearner:
"""Create a copy of `~adaptive.Learner2D` without the data."""
Expand Down
1 change: 1 addition & 0 deletions adaptive/learner/learner1D.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,7 @@ def __init__(
self.__missing_bounds = set(self.bounds) # cache of missing bounds

self._vdim: int | None = None
self._check_required_attributes()

def new(self) -> Learner1D:
"""Create a copy of `~adaptive.Learner1D` without the data."""
Expand Down
1 change: 1 addition & 0 deletions adaptive/learner/learner2D.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,7 @@ def __init__(
self._ip = self._ip_combined = None

self.stack_size = 10
self._check_required_attributes()

def new(self) -> Learner2D:
return Learner2D(self.function, self.bounds, self.loss_per_triangle)
Expand Down
2 changes: 2 additions & 0 deletions adaptive/learner/learnerND.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,8 @@ def __init__(self, func, bounds, loss_per_simplex=None):
# _pop_highest_existing_simplex
self._simplex_queue = SortedKeyList(key=_simplex_evaluation_priority)

self._check_required_attributes()

def new(self) -> LearnerND:
"""Create a new learner with the same function and bounds."""
return LearnerND(self.function, self.bounds, self.loss_per_simplex)
Expand Down
1 change: 1 addition & 0 deletions adaptive/learner/sequence_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def __init__(self, function, sequence):
self.sequence = copy(sequence)
self.data = SortedDict()
self.pending_points = set()
self._check_required_attributes()

def new(self) -> SequenceLearner:
"""Return a new `~adaptive.SequenceLearner` without the data."""
Expand Down
18 changes: 0 additions & 18 deletions adaptive/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import abc
import concurrent.futures as concurrent
import functools
import gzip
Expand Down Expand Up @@ -90,23 +89,6 @@ def decorator(method):
return decorator


class _RequireAttrsABCMeta(abc.ABCMeta):
def __call__(self, *args, **kwargs):
obj = super().__call__(*args, **kwargs)
for name, type_ in obj.__annotations__.items():
try:
x = getattr(obj, name)
except AttributeError:
raise AttributeError(
f"Required attribute {name} not set in __init__."
) from None
else:
if not isinstance(x, type_):
msg = f"The attribute '{name}' should be of type {type_}, not {type(x)}."
raise TypeError(msg)
return obj


def _default_parameters(function, function_prefix: str = "function."):
sig = inspect.signature(function)
defaults = {
Expand Down