from xarray.core.variable import Variable
from xarray.core.combine import merge
from xarray.core import dtypes, utils
-from xarray.core._typed_ops import DatasetOpsMixin

from .treenode import TreeNode, PathType, _init_single_treenode

@@ -188,7 +187,7 @@ def imag(self):
        else:
            raise AttributeError("property is not defined for a node with no data")

-    # TODO .loc
+    # TODO .loc, __contains__, __iter__, __array__, '__len__',

    dims.__doc__ = Dataset.dims.__doc__
    variables.__doc__ = Dataset.variables.__doc__
@@ -207,68 +206,71 @@ def imag(self):
                                           "See the `map_over_subtree` decorator for more details.", width=117)


-def _expose_methods_wrapped_to_map_over_subtree(obj, method_name, method):
+def _wrap_then_attach_to_cls(cls_dict, methods_to_expose, wrap_func=None):
    """
-    Expose given method on node object, but wrapped to map over whole subtree, not just that node object.
-
-    Result is like having written this in obj's class definition:
+    Attach given methods to a class, optionally wrapping each method first (e.g. with map_over_subtree).

+    Result is like having written this in the class's definition:
    ```
-    @map_over_subtree
+    @wrap_func
    def method_name(self, *args, **kwargs):
        return self.method(*args, **kwargs)
    ```
-    """
-
-    # Expose Dataset method, but wrapped to map over whole subtree when called
-    # TODO should we be using functools.partialmethod here instead?
-    mapped_over_tree = functools.partial(map_over_subtree(method), obj)
-    setattr(obj, method_name, mapped_over_tree)
-
-    # TODO do we really need this for ops like __add__?
-    # Add a line to the method's docstring explaining how it's been mapped
-    method_docstring = method.__doc__
-    if method_docstring is not None:
-        updated_method_docstring = method_docstring.replace('\n', _MAPPED_DOCSTRING_ADDENDUM, 1)
-        obj_method = getattr(obj, method_name)
-        setattr(obj_method, '__doc__', updated_method_docstring)
-
+
+    Parameters
+    ----------
+    cls_dict
+        The __dict__ attribute of a class, which can also be accessed by calling vars() from within that class's
+        definition.
+    methods_to_expose : Iterable[Tuple[str, callable]]
+        The method names and definitions supplied as a list of (method_name_string, method) pairs.
+        This format matches the output of inspect.getmembers().
+    """
+    for method_name, method in methods_to_expose:
+        wrapped_method = wrap_func(method) if wrap_func is not None else method
+        cls_dict[method_name] = wrapped_method

-# TODO equals, broadcast_equals etc.
-# TODO do dask-related private methods need to be exposed?
-_DATASET_DASK_METHODS_TO_EXPOSE = ['load', 'compute', 'persist', 'unify_chunks', 'chunk', 'map_blocks']
-_DATASET_METHODS_TO_EXPOSE = ['copy', 'as_numpy', '__copy__', '__deepcopy__', '__contains__', '__len__',
-                              '__bool__', '__iter__', '__array__', 'set_coords', 'reset_coords', 'info',
-                              'isel', 'sel', 'head', 'tail', 'thin', 'broadcast_like', 'reindex_like',
-                              'reindex', 'interp', 'interp_like', 'rename', 'rename_dims', 'rename_vars',
-                              'swap_dims', 'expand_dims', 'set_index', 'reset_index', 'reorder_levels', 'stack',
-                              'unstack', 'update', 'merge', 'drop_vars', 'drop_sel', 'drop_isel', 'drop_dims',
-                              'transpose', 'dropna', 'fillna', 'interpolate_na', 'ffill', 'bfill', 'combine_first',
-                              'reduce', 'map', 'assign', 'diff', 'shift', 'roll', 'sortby', 'quantile', 'rank',
-                              'differentiate', 'integrate', 'cumulative_integrate', 'filter_by_attrs', 'polyfit',
-                              'pad', 'idxmin', 'idxmax', 'argmin', 'argmax', 'query', 'curvefit']
-_DATASET_OPS_TO_EXPOSE = ['_unary_op', '_binary_op', '_inplace_binary_op']
-_ALL_DATASET_METHODS_TO_EXPOSE = _DATASET_DASK_METHODS_TO_EXPOSE + _DATASET_METHODS_TO_EXPOSE + _DATASET_OPS_TO_EXPOSE
-
-# TODO methods which should not or cannot act over the whole tree, such as .to_array
-
+        # TODO do we really need this for ops like __add__?
+        # Add a line to the method's docstring explaining how it's been mapped
+        method_docstring = method.__doc__
+        if method_docstring is not None:
+            updated_method_docstring = method_docstring.replace('\n', _MAPPED_DOCSTRING_ADDENDUM, 1)
+            setattr(cls_dict[method_name], '__doc__', updated_method_docstring)

-class DatasetMethodsMixin:
-    """Mixin to add Dataset methods like .mean(), but wrapped to map over all nodes in the subtree."""

-    # TODO is there a way to put this code in the class definition so we don't have to specifically call this method?
-    def _add_dataset_methods(self):
-        methods_to_expose = [(method_name, getattr(Dataset, method_name))
-                             for method_name in _ALL_DATASET_METHODS_TO_EXPOSE]
+class MappedDatasetMethodsMixin:
+    """
+    Mixin to add Dataset methods like .mean(), but wrapped to map over all nodes in the subtree.

-        for method_name, method in methods_to_expose:
-            _expose_methods_wrapped_to_map_over_subtree(self, method_name, method)
+    Every method wrapped here needs to have a return value of Dataset or DataArray in order to construct a new tree.
+    """
+    # TODO equals, broadcast_equals etc.
+    # TODO do dask-related private methods need to be exposed?
+    _DATASET_DASK_METHODS_TO_EXPOSE = ['load', 'compute', 'persist', 'unify_chunks', 'chunk', 'map_blocks']
+    _DATASET_METHODS_TO_EXPOSE = ['copy', 'as_numpy', '__copy__', '__deepcopy__', 'set_coords', 'reset_coords', 'info',
+                                  'isel', 'sel', 'head', 'tail', 'thin', 'broadcast_like', 'reindex_like',
+                                  'reindex', 'interp', 'interp_like', 'rename', 'rename_dims', 'rename_vars',
+                                  'swap_dims', 'expand_dims', 'set_index', 'reset_index', 'reorder_levels', 'stack',
+                                  'unstack', 'update', 'merge', 'drop_vars', 'drop_sel', 'drop_isel', 'drop_dims',
+                                  'transpose', 'dropna', 'fillna', 'interpolate_na', 'ffill', 'bfill', 'combine_first',
+                                  'reduce', 'map', 'assign', 'diff', 'shift', 'roll', 'sortby', 'quantile', 'rank',
+                                  'differentiate', 'integrate', 'cumulative_integrate', 'filter_by_attrs', 'polyfit',
+                                  'pad', 'idxmin', 'idxmax', 'argmin', 'argmax', 'query', 'curvefit']
+    # TODO unsure if these are called by external functions or not?
+    _DATASET_OPS_TO_EXPOSE = ['_unary_op', '_binary_op', '_inplace_binary_op']
+    _ALL_DATASET_METHODS_TO_EXPOSE = _DATASET_DASK_METHODS_TO_EXPOSE + _DATASET_METHODS_TO_EXPOSE + _DATASET_OPS_TO_EXPOSE
+
+    # TODO methods which should not or cannot act over the whole tree, such as .to_array
+
+    methods_to_expose = [(method_name, getattr(Dataset, method_name))
+                         for method_name in _ALL_DATASET_METHODS_TO_EXPOSE]
+    _wrap_then_attach_to_cls(vars(), methods_to_expose, wrap_func=map_over_subtree)


# TODO implement ArrayReduce type methods


-class DataTree(TreeNode, DatasetPropertiesMixin, DatasetMethodsMixin):
+class DataTree(TreeNode, DatasetPropertiesMixin, MappedDatasetMethodsMixin):
    """
    A tree-like hierarchical collection of xarray objects.

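The class-body call to `_wrap_then_attach_to_cls(vars(), ...)` above is the heart of this change: `vars()` inside a class body returns the namespace being built, so the wrapped Dataset methods become ordinary class attributes with no per-instance setup (which is why the `_add_all_dataset_api` plumbing removed below is no longer needed). A minimal, self-contained sketch of that pattern, using toy names (`shout`, `Plain`, `Wrapped`) rather than datatree's real API:

```python
import functools


def shout(method):
    """Toy stand-in for map_over_subtree: wrap a method so its result is upper-cased."""
    @functools.wraps(method)
    def wrapped(self, *args, **kwargs):
        return method(self, *args, **kwargs).upper()
    return wrapped


def _wrap_then_attach_to_cls(cls_dict, methods_to_expose, wrap_func=None):
    # Same shape as the helper in the diff: mutate the class namespace while
    # the class body is still being executed.
    for method_name, method in methods_to_expose:
        cls_dict[method_name] = wrap_func(method) if wrap_func is not None else method


class Plain:
    def greet(self):
        return "hello"


class Wrapped:
    # vars() here is the namespace of the class under construction, so the
    # helper can inject wrapped copies of Plain's methods at definition time.
    _wrap_then_attach_to_cls(vars(), [("greet", Plain.greet)], wrap_func=shout)


print(Plain().greet())    # hello
print(Wrapped().greet())  # HELLO
```

In the diff itself `wrap_func=map_over_subtree`, so each attached method is applied to every node's Dataset and the results are reassembled into a new tree, which is why the mixin's docstring requires every wrapped method to return a Dataset or DataArray.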
@@ -339,15 +341,6 @@ def __init__(
        new_node = self.get_node(path)
        new_node[path] = data

-        # TODO this has to be
-        self._add_all_dataset_api()
-
-    def _add_all_dataset_api(self):
-        # Add methods like .isel(), but wrapped to map over subtrees
-        self._add_dataset_methods()
-
-        # TODO add dataset ops here
-
    @property
    def ds(self) -> Dataset:
        return self._ds
@@ -396,9 +389,6 @@ def _init_single_datatree_node(
        obj = object.__new__(cls)
        obj = _init_single_treenode(obj, name=name, parent=parent, children=children)
        obj.ds = data
-
-        obj._add_all_dataset_api()
-
        return obj

    def __str__(self):
@@ -435,7 +425,7 @@ def _single_node_repr(self):
    def __repr__(self):
        """Information about this node, including its relationships to other nodes."""
        # TODO redo this to look like the Dataset repr, but just with child and parent info
-        parent = self.parent.name if self.parent else "None"
+        parent = self.parent.name if self.parent is not None else "None"
        node_str = f"DataNode(name='{self.name}', parent='{parent}', children={[c.name for c in self.children]},"

        if self.has_data:
@@ -554,7 +544,7 @@ def __setitem__(
        except anytree.resolver.ResolverError:
            existing_node = None

-        if existing_node:
+        if existing_node is not None:
            if isinstance(value, Dataset):
                # replace whole dataset
                existing_node.ds = Dataset
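Both one-line behaviour fixes in this diff (`self.parent` in `__repr__` above and `existing_node` here in `__setitem__`) replace a truthiness test with an explicit `is not None` check. That matters because these objects can be falsy while still being perfectly valid, for example a node wrapping an empty Dataset once Dataset's `__len__`/`__bool__` are forwarded, so `if obj:` would silently skip the branch. A small illustration with a plain xarray Dataset (editorial example, not part of the patch):

```python
import xarray as xr

empty = xr.Dataset()  # a valid, existing object with no data variables

# Truthiness falls back to __len__ (the number of data variables), so an
# empty Dataset is falsy and this branch is skipped even though the object exists.
if empty:
    print("truthiness check: dataset found")

# The identity check asks the question the diff actually cares about:
# "did the lookup return an object at all?"
if empty is not None:
    print("identity check: dataset found")
```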