Skip to content

Commit f32346a

Browse files
authored
Add type-hints to adaptive/learner/data_saver.py (#373)
1 parent 6127d72 commit f32346a

File tree

1 file changed

+16
-12
lines changed

1 file changed

+16
-12
lines changed

adaptive/learner/data_saver.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import functools
44
from collections import OrderedDict
5+
from typing import Any, Callable
56

67
from adaptive.learner.base_learner import BaseLearner
78
from adaptive.utils import copy_docstring_from
@@ -39,7 +40,7 @@ class DataSaver:
3940
>>> learner = DataSaver(_learner, arg_picker=itemgetter('y'))
4041
"""
4142

42-
def __init__(self, learner, arg_picker):
43+
def __init__(self, learner: BaseLearner, arg_picker: Callable) -> None:
4344
self.learner = learner
4445
self.extra_data = OrderedDict()
4546
self.function = learner.function
@@ -49,21 +50,21 @@ def new(self) -> DataSaver:
4950
"""Return a new `DataSaver` with the same `arg_picker` and `learner`."""
5051
return DataSaver(self.learner.new(), self.arg_picker)
5152

52-
def __getattr__(self, attr):
53+
def __getattr__(self, attr: str) -> Any:
5354
return getattr(self.learner, attr)
5455

5556
@copy_docstring_from(BaseLearner.tell)
56-
def tell(self, x, result):
57+
def tell(self, x: Any, result: Any) -> None:
5758
y = self.arg_picker(result)
5859
self.extra_data[x] = result
5960
self.learner.tell(x, y)
6061

6162
@copy_docstring_from(BaseLearner.tell_pending)
62-
def tell_pending(self, x):
63+
def tell_pending(self, x: Any) -> None:
6364
self.learner.tell_pending(x)
6465

6566
def to_dataframe(
66-
self, extra_data_name: str = "extra_data", **kwargs
67+
self, extra_data_name: str = "extra_data", **kwargs: Any
6768
) -> pandas.DataFrame:
6869
"""Return the data as a concatenated `pandas.DataFrame` from child learners.
6970
@@ -98,7 +99,7 @@ def load_dataframe(
9899
extra_data_name: str = "extra_data",
99100
input_names: tuple[str] = (),
100101
**kwargs,
101-
):
102+
) -> None:
102103
"""Load the data from a `pandas.DataFrame` into the learner.
103104
104105
Parameters
@@ -122,33 +123,36 @@ def load_dataframe(
122123
key = _to_key(x[:-1])
123124
self.extra_data[key] = x[-1]
124125

125-
def _get_data(self):
126+
def _get_data(self) -> tuple[Any, OrderedDict]:
126127
return self.learner._get_data(), self.extra_data
127128

128-
def _set_data(self, data):
129+
def _set_data(
130+
self,
131+
data: tuple[Any, OrderedDict],
132+
) -> None:
129133
learner_data, self.extra_data = data
130134
self.learner._set_data(learner_data)
131135

132-
def __getstate__(self):
136+
def __getstate__(self) -> tuple[BaseLearner, Callable, OrderedDict]:
133137
return (
134138
self.learner,
135139
self.arg_picker,
136140
self.extra_data,
137141
)
138142

139-
def __setstate__(self, state):
143+
def __setstate__(self, state: tuple[BaseLearner, Callable, OrderedDict]) -> None:
140144
learner, arg_picker, extra_data = state
141145
self.__init__(learner, arg_picker)
142146
self.extra_data = extra_data
143147

144148
@copy_docstring_from(BaseLearner.save)
145-
def save(self, fname, compress=True):
149+
def save(self, fname, compress=True) -> None:
146150
# We copy this method because the 'DataSaver' is not a
147151
# subclass of the 'BaseLearner'.
148152
BaseLearner.save(self, fname, compress)
149153

150154
@copy_docstring_from(BaseLearner.load)
151-
def load(self, fname, compress=True):
155+
def load(self, fname, compress=True) -> None:
152156
# We copy this method because the 'DataSaver' is not a
153157
# subclass of the 'BaseLearner'.
154158
BaseLearner.load(self, fname, compress)

0 commit comments

Comments
 (0)