Skip to content

Add type-hints to BalancingLearner #371

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 12, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 78 additions & 28 deletions adaptive/learner/balancing_learner.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,47 @@
from __future__ import annotations

import itertools
import numbers
from collections import defaultdict
from collections.abc import Iterable
from contextlib import suppress
from functools import partial
from operator import itemgetter
from typing import Any, Callable, Dict, Sequence, Tuple, Union

import numpy as np

from adaptive.learner.base_learner import BaseLearner
from adaptive.notebook_integration import ensure_holoviews
from adaptive.utils import cache_latest, named_product, restore

try:
from typing import Literal, TypeAlias
except ImportError:
from typing_extensions import Literal, TypeAlias

try:
import pandas

with_pandas = True

except ModuleNotFoundError:
with_pandas = False


def dispatch(child_functions, arg):
def dispatch(child_functions: list[Callable], arg: Any) -> Any:
index, x = arg
return child_functions[index](x)


STRATEGY_TYPE: TypeAlias = Literal["loss_improvements", "loss", "npoints", "cycle"]

CDIMS_TYPE: TypeAlias = Union[
Sequence[Dict[str, Any]],
Tuple[Sequence[str], Sequence[Tuple[Any, ...]]],
None,
]


class BalancingLearner(BaseLearner):
r"""Choose the optimal points from a set of learners.

Expand Down Expand Up @@ -78,13 +93,19 @@ class BalancingLearner(BaseLearner):
behave in an undefined way. Change the `strategy` in that case.
"""

def __init__(self, learners, *, cdims=None, strategy="loss_improvements"):
def __init__(
self,
learners: list[BaseLearner],
*,
cdims: CDIMS_TYPE = None,
strategy: STRATEGY_TYPE = "loss_improvements",
) -> None:
self.learners = learners

# Naively we would make 'function' a method, but this causes problems
# when using executors from 'concurrent.futures' because we have to
# pickle the whole learner.
self.function = partial(dispatch, [l.function for l in self.learners])
self.function = partial(dispatch, [l.function for l in self.learners]) # type: ignore

self._ask_cache = {}
self._loss = {}
Expand All @@ -96,7 +117,7 @@ def __init__(self, learners, *, cdims=None, strategy="loss_improvements"):
"A BalacingLearner can handle only one type" " of learners."
)

self.strategy = strategy
self.strategy: STRATEGY_TYPE = strategy

def new(self) -> BalancingLearner:
"""Create a new `BalancingLearner` with the same parameters."""
Expand All @@ -107,21 +128,21 @@ def new(self) -> BalancingLearner:
)

@property
def data(self):
def data(self) -> dict[tuple[int, Any], Any]:
data = {}
for i, l in enumerate(self.learners):
data.update({(i, p): v for p, v in l.data.items()})
return data

@property
def pending_points(self):
def pending_points(self) -> set[tuple[int, Any]]:
pending_points = set()
for i, l in enumerate(self.learners):
pending_points.update({(i, p) for p in l.pending_points})
return pending_points

@property
def npoints(self):
def npoints(self) -> int:
return sum(l.npoints for l in self.learners)

@property
Expand All @@ -134,7 +155,7 @@ def nsamples(self):
)

@property
def strategy(self):
def strategy(self) -> STRATEGY_TYPE:
"""Can be either 'loss_improvements' (default), 'loss', 'npoints', or
'cycle'. The points that the `BalancingLearner` choses can be either
based on: the best 'loss_improvements', the smallest total 'loss' of
Expand All @@ -145,7 +166,7 @@ def strategy(self):
return self._strategy

@strategy.setter
def strategy(self, strategy):
def strategy(self, strategy: STRATEGY_TYPE) -> None:
self._strategy = strategy
if strategy == "loss_improvements":
self._ask_and_tell = self._ask_and_tell_based_on_loss_improvements
Expand All @@ -162,7 +183,9 @@ def strategy(self, strategy):
' strategy="npoints", or strategy="cycle" is implemented.'
)

def _ask_and_tell_based_on_loss_improvements(self, n):
def _ask_and_tell_based_on_loss_improvements(
self, n: int
) -> tuple[list[tuple[int, Any]], list[float]]:
selected = [] # tuples ((learner_index, point), loss_improvement)
total_points = [l.npoints + len(l.pending_points) for l in self.learners]
for _ in range(n):
Expand All @@ -185,7 +208,9 @@ def _ask_and_tell_based_on_loss_improvements(self, n):
points, loss_improvements = map(list, zip(*selected))
return points, loss_improvements

def _ask_and_tell_based_on_loss(self, n):
def _ask_and_tell_based_on_loss(
self, n: int
) -> tuple[list[tuple[int, Any]], list[float]]:
selected = [] # tuples ((learner_index, point), loss_improvement)
total_points = [l.npoints + len(l.pending_points) for l in self.learners]
for _ in range(n):
Expand All @@ -206,7 +231,9 @@ def _ask_and_tell_based_on_loss(self, n):
points, loss_improvements = map(list, zip(*selected))
return points, loss_improvements

def _ask_and_tell_based_on_npoints(self, n):
def _ask_and_tell_based_on_npoints(
self, n: numbers.Integral
) -> tuple[list[tuple[numbers.Integral, Any]], list[float]]:
selected = [] # tuples ((learner_index, point), loss_improvement)
total_points = [l.npoints + len(l.pending_points) for l in self.learners]
for _ in range(n):
Expand All @@ -222,7 +249,9 @@ def _ask_and_tell_based_on_npoints(self, n):
points, loss_improvements = map(list, zip(*selected))
return points, loss_improvements

def _ask_and_tell_based_on_cycle(self, n):
def _ask_and_tell_based_on_cycle(
self, n: int
) -> tuple[list[tuple[numbers.Integral, Any]], list[float]]:
points, loss_improvements = [], []
for _ in range(n):
index = next(self._cycle)
Expand All @@ -233,7 +262,9 @@ def _ask_and_tell_based_on_cycle(self, n):

return points, loss_improvements

def ask(self, n, tell_pending=True):
def ask(
self, n: int, tell_pending: bool = True
) -> tuple[list[tuple[numbers.Integral, Any]], list[float]]:
"""Chose points for learners."""
if n == 0:
return [], []
Expand All @@ -244,20 +275,20 @@ def ask(self, n, tell_pending=True):
else:
return self._ask_and_tell(n)

def tell(self, x, y):
def tell(self, x: tuple[numbers.Integral, Any], y: Any) -> None:
index, x = x
self._ask_cache.pop(index, None)
self._loss.pop(index, None)
self._pending_loss.pop(index, None)
self.learners[index].tell(x, y)

def tell_pending(self, x):
def tell_pending(self, x: tuple[numbers.Integral, Any]) -> None:
index, x = x
self._ask_cache.pop(index, None)
self._loss.pop(index, None)
self.learners[index].tell_pending(x)

def _losses(self, real=True):
def _losses(self, real: bool = True) -> list[float]:
losses = []
loss_dict = self._loss if real else self._pending_loss

Expand All @@ -269,11 +300,16 @@ def _losses(self, real=True):
return losses

@cache_latest
def loss(self, real=True):
def loss(self, real: bool = True) -> float:
losses = self._losses(real)
return max(losses)

def plot(self, cdims=None, plotter=None, dynamic=True):
def plot(
self,
cdims: CDIMS_TYPE = None,
plotter: Callable[[BaseLearner], Any] | None = None,
dynamic: bool = True,
):
"""Returns a DynamicMap with sliders.

Parameters
Expand Down Expand Up @@ -346,13 +382,19 @@ def plot_function(*args):
vals = {d.name: d.values for d in dm.dimensions() if d.values}
return hv.HoloMap(dm.select(**vals))

def remove_unfinished(self):
def remove_unfinished(self) -> None:
"""Remove uncomputed data from the learners."""
for learner in self.learners:
learner.remove_unfinished()

@classmethod
def from_product(cls, f, learner_type, learner_kwargs, combos):
def from_product(
cls,
f,
learner_type: BaseLearner,
learner_kwargs: dict[str, Any],
combos: dict[str, Sequence[Any]],
) -> BalancingLearner:
"""Create a `BalancingLearner` with learners of all combinations of
named variables’ values. The `cdims` will be set correctly, so calling
`learner.plot` will be a `holoviews.core.HoloMap` with the correct labels.
Expand Down Expand Up @@ -448,7 +490,11 @@ def load_dataframe(
for i, gr in df.groupby(index_name):
self.learners[i].load_dataframe(gr, **kwargs)

def save(self, fname, compress=True):
def save(
self,
fname: Callable[[BaseLearner], str] | Sequence[str],
compress: bool = True,
) -> None:
"""Save the data of the child learners into pickle files
in a directory.

Expand Down Expand Up @@ -486,7 +532,11 @@ def save(self, fname, compress=True):
for l in self.learners:
l.save(fname(l), compress=compress)

def load(self, fname, compress=True):
def load(
self,
fname: Callable[[BaseLearner], str] | Sequence[str],
compress: bool = True,
) -> None:
"""Load the data of the child learners from pickle files
in a directory.

Expand All @@ -510,20 +560,20 @@ def load(self, fname, compress=True):
for l in self.learners:
l.load(fname(l), compress=compress)

def _get_data(self):
def _get_data(self) -> list[Any]:
return [l._get_data() for l in self.learners]

def _set_data(self, data):
def _set_data(self, data: list[Any]):
for l, _data in zip(self.learners, data):
l._set_data(_data)

def __getstate__(self):
def __getstate__(self) -> tuple[list[BaseLearner], CDIMS_TYPE, STRATEGY_TYPE]:
return (
self.learners,
self._cdims_default,
self.strategy,
)

def __setstate__(self, state):
def __setstate__(self, state: tuple[list[BaseLearner], CDIMS_TYPE, STRATEGY_TYPE]):
learners, cdims, strategy = state
self.__init__(learners, cdims=cdims, strategy=strategy)