Skip to content

Commit

Permalink
Allow expanding single keys to multiple keys.
Browse files Browse the repository at this point in the history
.isel(X=5, Y=5) → .isel(xi_rho=5, xi_u=5, xi_v=5, eta_rho=10, ...)

Closes xarray-contrib#13
  • Loading branch information
dcherian committed Jun 12, 2020
1 parent e4cce59 commit 7efa2fe
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 5 deletions.
25 changes: 20 additions & 5 deletions cf_xarray/accessor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import functools
import inspect
from collections import ChainMap
from typing import Any, List, Optional, Set, Union

import xarray as xr
Expand Down Expand Up @@ -95,7 +96,7 @@ def _get_axis_coord_single(var, key, *args):
results = _get_axis_coord(var, key, *args)
if len(results) > 1:
raise ValueError(
"Multiple results for {key!r} found: {results!r}. Is this valid CF? Please open an issue."
f"Multiple results for {key!r} found: {results!r}. Is this valid CF? Please open an issue."
)
else:
return results[0]
Expand Down Expand Up @@ -335,20 +336,34 @@ def _process_signature(self, func, args, kwargs, key_mappers):
def _rewrite_values(self, kwargs, key_mappers: dict, var_kws):
""" rewrites 'dim' for example using 'mapper' """
updates: dict = {}
key_mappers.update(dict.fromkeys(var_kws, _get_axis_coord_single))

# allow multiple return values here.
# these are valid for .sel, .isel, .coarsen
key_mappers.update(dict.fromkeys(var_kws, _get_axis_coord))

for key, mapper in key_mappers.items():
value = kwargs.get(key, None)

if value is not None:
if isinstance(value, str):
value = [value]

if isinstance(value, dict):
# this for things like isel where **kwargs captures things like T=5
updates[key] = {
mapper(self._obj, k, False, k): v for k, v in value.items()
}
# .sel, .isel, .rolling
# Account for multiple names matching the key.
# e.g. .isel(X=5) → .isel(xi_rho=5, xi_u=5, xi_v=5, xi_psi=5)
# where xi_* have attrs["axis"] = "X"
updates[key] = ChainMap(
*[
dict.fromkeys(mapper(self._obj, k, False, k), v)
for k, v in value.items()
]
)

elif value is Ellipsis:
pass

else:
# things like sum which have dim
updates[key] = [mapper(self._obj, v, False, v) for v in value]
Expand Down
20 changes: 20 additions & 0 deletions cf_xarray/tests/test_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,26 @@ def test_kwargs_methods(obj):
assert_identical(expected, actual)


def test_kwargs_expand_key_to_multiple_keys():

ds = xr.Dataset()
ds.coords["x1"] = ("x1", range(30), {"axis": "X"})
ds.coords["y1"] = ("y1", range(20), {"axis": "Y"})
ds.coords["x2"] = ("x2", range(10), {"axis": "X"})
ds.coords["y2"] = ("y2", range(5), {"axis": "Y"})

ds["v1"] = (("x1", "y1"), np.ones((30, 20)) * 15)
ds["v2"] = (("x2", "y2"), np.ones((10, 5)) * 15)

actual = ds.cf.isel(X=5, Y=10)
expected = ds.isel(x1=5, y1=10, x2=5, y2=10)
assert_identical(actual, expected)

actual = ds.cf.coarsen(X=10, Y=5)
expected = ds.coarsen(x1=10, y1=5, x2=10, y2=5)
assert_identical(actual.mean(), expected.mean())


@pytest.mark.parametrize("obj", objects)
def test_args_methods(obj):
with raise_if_dask_computes():
Expand Down

0 comments on commit 7efa2fe

Please sign in to comment.