Skip to content
This repository was archived by the owner on Oct 24, 2024. It is now read-only.

Commit 888c629

Browse files
authored
Merge pull request #19 from TomNicholas/add_api_in_class_definition
Add API methods in class definition
2 parents 415cbb7 + ae3a38a commit 888c629

File tree

1 file changed

+52
-62
lines changed

1 file changed

+52
-62
lines changed

datatree/datatree.py

Lines changed: 52 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from xarray.core.variable import Variable
1313
from xarray.core.combine import merge
1414
from xarray.core import dtypes, utils
15-
from xarray.core._typed_ops import DatasetOpsMixin
1615

1716
from .treenode import TreeNode, PathType, _init_single_treenode
1817

@@ -188,7 +187,7 @@ def imag(self):
188187
else:
189188
raise AttributeError("property is not defined for a node with no data")
190189

191-
# TODO .loc
190+
# TODO .loc, __contains__, __iter__, __array__, __len__
192191

193192
dims.__doc__ = Dataset.dims.__doc__
194193
variables.__doc__ = Dataset.variables.__doc__
@@ -207,68 +206,71 @@ def imag(self):
207206
"See the `map_over_subtree` decorator for more details.", width=117)
208207

209208

210-
def _expose_methods_wrapped_to_map_over_subtree(obj, method_name, method):
209+
def _wrap_then_attach_to_cls(cls_dict, methods_to_expose, wrap_func=None):
211210
"""
212-
Expose given method on node object, but wrapped to map over whole subtree, not just that node object.
213-
214-
Result is like having written this in obj's class definition:
211+
Attach the given methods to a class, optionally wrapping each method first (e.g. with map_over_subtree).
215212
213+
Result is like having written this in the class's definition:
216214
```
217-
@map_over_subtree
215+
@wrap_func
218216
def method_name(self, *args, **kwargs):
219217
return self.method(*args, **kwargs)
220218
```
221-
"""
222-
223-
# Expose Dataset method, but wrapped to map over whole subtree when called
224-
# TODO should we be using functools.partialmethod here instead?
225-
mapped_over_tree = functools.partial(map_over_subtree(method), obj)
226-
setattr(obj, method_name, mapped_over_tree)
227-
228-
# TODO do we really need this for ops like __add__?
229-
# Add a line to the method's docstring explaining how it's been mapped
230-
method_docstring = method.__doc__
231-
if method_docstring is not None:
232-
updated_method_docstring = method_docstring.replace('\n', _MAPPED_DOCSTRING_ADDENDUM, 1)
233-
obj_method = getattr(obj, method_name)
234-
setattr(obj_method, '__doc__', updated_method_docstring)
235219
220+
Parameters
221+
----------
222+
cls_dict
223+
The __dict__ attribute of a class, which can also be accessed by calling vars() from within that class's
224+
definition.
225+
methods_to_expose : Iterable[Tuple[str, callable]]
226+
The method names and definitions supplied as a list of (method_name_string, method) pairs.
227+
This format matches the output of inspect.getmembers().
228+
"""
229+
for method_name, method in methods_to_expose:
230+
wrapped_method = wrap_func(method) if wrap_func is not None else method
231+
cls_dict[method_name] = wrapped_method
236232

237-
# TODO equals, broadcast_equals etc.
238-
# TODO do dask-related private methods need to be exposed?
239-
_DATASET_DASK_METHODS_TO_EXPOSE = ['load', 'compute', 'persist', 'unify_chunks', 'chunk', 'map_blocks']
240-
_DATASET_METHODS_TO_EXPOSE = ['copy', 'as_numpy', '__copy__', '__deepcopy__', '__contains__', '__len__',
241-
'__bool__', '__iter__', '__array__', 'set_coords', 'reset_coords', 'info',
242-
'isel', 'sel', 'head', 'tail', 'thin', 'broadcast_like', 'reindex_like',
243-
'reindex', 'interp', 'interp_like', 'rename', 'rename_dims', 'rename_vars',
244-
'swap_dims', 'expand_dims', 'set_index', 'reset_index', 'reorder_levels', 'stack',
245-
'unstack', 'update', 'merge', 'drop_vars', 'drop_sel', 'drop_isel', 'drop_dims',
246-
'transpose', 'dropna', 'fillna', 'interpolate_na', 'ffill', 'bfill', 'combine_first',
247-
'reduce', 'map', 'assign', 'diff', 'shift', 'roll', 'sortby', 'quantile', 'rank',
248-
'differentiate', 'integrate', 'cumulative_integrate', 'filter_by_attrs', 'polyfit',
249-
'pad', 'idxmin', 'idxmax', 'argmin', 'argmax', 'query', 'curvefit']
250-
_DATASET_OPS_TO_EXPOSE = ['_unary_op', '_binary_op', '_inplace_binary_op']
251-
_ALL_DATASET_METHODS_TO_EXPOSE = _DATASET_DASK_METHODS_TO_EXPOSE + _DATASET_METHODS_TO_EXPOSE + _DATASET_OPS_TO_EXPOSE
252-
253-
# TODO methods which should not or cannot act over the whole tree, such as .to_array
254-
233+
# TODO do we really need this for ops like __add__?
234+
# Add a line to the method's docstring explaining how it's been mapped
235+
method_docstring = method.__doc__
236+
if method_docstring is not None:
237+
updated_method_docstring = method_docstring.replace('\n', _MAPPED_DOCSTRING_ADDENDUM, 1)
238+
setattr(cls_dict[method_name], '__doc__', updated_method_docstring)
255239

256-
class DatasetMethodsMixin:
257-
"""Mixin to add Dataset methods like .mean(), but wrapped to map over all nodes in the subtree."""
258240

259-
# TODO is there a way to put this code in the class definition so we don't have to specifically call this method?
260-
def _add_dataset_methods(self):
261-
methods_to_expose = [(method_name, getattr(Dataset, method_name))
262-
for method_name in _ALL_DATASET_METHODS_TO_EXPOSE]
241+
class MappedDatasetMethodsMixin:
242+
"""
243+
Mixin to add Dataset methods like .mean(), but wrapped to map over all nodes in the subtree.
263244
264-
for method_name, method in methods_to_expose:
265-
_expose_methods_wrapped_to_map_over_subtree(self, method_name, method)
245+
Every method wrapped here needs to have a return value of Dataset or DataArray in order to construct a new tree.
246+
"""
247+
# TODO equals, broadcast_equals etc.
248+
# TODO do dask-related private methods need to be exposed?
249+
_DATASET_DASK_METHODS_TO_EXPOSE = ['load', 'compute', 'persist', 'unify_chunks', 'chunk', 'map_blocks']
250+
_DATASET_METHODS_TO_EXPOSE = ['copy', 'as_numpy', '__copy__', '__deepcopy__', 'set_coords', 'reset_coords', 'info',
251+
'isel', 'sel', 'head', 'tail', 'thin', 'broadcast_like', 'reindex_like',
252+
'reindex', 'interp', 'interp_like', 'rename', 'rename_dims', 'rename_vars',
253+
'swap_dims', 'expand_dims', 'set_index', 'reset_index', 'reorder_levels', 'stack',
254+
'unstack', 'update', 'merge', 'drop_vars', 'drop_sel', 'drop_isel', 'drop_dims',
255+
'transpose', 'dropna', 'fillna', 'interpolate_na', 'ffill', 'bfill', 'combine_first',
256+
'reduce', 'map', 'assign', 'diff', 'shift', 'roll', 'sortby', 'quantile', 'rank',
257+
'differentiate', 'integrate', 'cumulative_integrate', 'filter_by_attrs', 'polyfit',
258+
'pad', 'idxmin', 'idxmax', 'argmin', 'argmax', 'query', 'curvefit']
259+
# TODO unsure if these are called by external functions or not?
260+
_DATASET_OPS_TO_EXPOSE = ['_unary_op', '_binary_op', '_inplace_binary_op']
261+
_ALL_DATASET_METHODS_TO_EXPOSE = _DATASET_DASK_METHODS_TO_EXPOSE + _DATASET_METHODS_TO_EXPOSE + _DATASET_OPS_TO_EXPOSE
262+
263+
# TODO methods which should not or cannot act over the whole tree, such as .to_array
264+
265+
methods_to_expose = [(method_name, getattr(Dataset, method_name))
266+
for method_name in _ALL_DATASET_METHODS_TO_EXPOSE]
267+
_wrap_then_attach_to_cls(vars(), methods_to_expose, wrap_func=map_over_subtree)
266268

267269

268270
# TODO implement ArrayReduce type methods
269271

270272

271-
class DataTree(TreeNode, DatasetPropertiesMixin, DatasetMethodsMixin):
273+
class DataTree(TreeNode, DatasetPropertiesMixin, MappedDatasetMethodsMixin):
272274
"""
273275
A tree-like hierarchical collection of xarray objects.
274276
@@ -339,15 +341,6 @@ def __init__(
339341
new_node = self.get_node(path)
340342
new_node[path] = data
341343

342-
# TODO this has to be
343-
self._add_all_dataset_api()
344-
345-
def _add_all_dataset_api(self):
346-
# Add methods like .isel(), but wrapped to map over subtrees
347-
self._add_dataset_methods()
348-
349-
# TODO add dataset ops here
350-
351344
@property
352345
def ds(self) -> Dataset:
353346
return self._ds
@@ -396,9 +389,6 @@ def _init_single_datatree_node(
396389
obj = object.__new__(cls)
397390
obj = _init_single_treenode(obj, name=name, parent=parent, children=children)
398391
obj.ds = data
399-
400-
obj._add_all_dataset_api()
401-
402392
return obj
403393

404394
def __str__(self):
@@ -435,7 +425,7 @@ def _single_node_repr(self):
435425
def __repr__(self):
436426
"""Information about this node, including its relationships to other nodes."""
437427
# TODO redo this to look like the Dataset repr, but just with child and parent info
438-
parent = self.parent.name if self.parent else "None"
428+
parent = self.parent.name if self.parent is not None else "None"
439429
node_str = f"DataNode(name='{self.name}', parent='{parent}', children={[c.name for c in self.children]},"
440430

441431
if self.has_data:
@@ -554,7 +544,7 @@ def __setitem__(
554544
except anytree.resolver.ResolverError:
555545
existing_node = None
556546

557-
if existing_node:
547+
if existing_node is not None:
558548
if isinstance(value, Dataset):
559549
# replace whole dataset
560550
existing_node.ds = Dataset

0 commit comments

Comments
 (0)