From 7efa2fe4ecfd93015695ef6b95bb4c6ad36238b7 Mon Sep 17 00:00:00 2001 From: dcherian Date: Fri, 12 Jun 2020 16:51:33 -0600 Subject: [PATCH] Allow expanding single keys to multiple keys. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit .isel(X=5, Y=5) → .isel(xi_rho=5, xi_u=5, xi_v=5, eta_rho=10, ...) Closes #13 --- cf_xarray/accessor.py | 25 ++++++++++++++++++++----- cf_xarray/tests/test_accessor.py | 20 ++++++++++++++++++++ 2 files changed, 40 insertions(+), 5 deletions(-) diff --git a/cf_xarray/accessor.py b/cf_xarray/accessor.py index a2c46c6c..13267625 100644 --- a/cf_xarray/accessor.py +++ b/cf_xarray/accessor.py @@ -1,5 +1,6 @@ import functools import inspect +from collections import ChainMap from typing import Any, List, Optional, Set, Union import xarray as xr @@ -95,7 +96,7 @@ def _get_axis_coord_single(var, key, *args): results = _get_axis_coord(var, key, *args) if len(results) > 1: raise ValueError( - "Multiple results for {key!r} found: {results!r}. Is this valid CF? Please open an issue." + f"Multiple results for {key!r} found: {results!r}. Is this valid CF? Please open an issue." ) else: return results[0] @@ -335,20 +336,34 @@ def _process_signature(self, func, args, kwargs, key_mappers): def _rewrite_values(self, kwargs, key_mappers: dict, var_kws): """ rewrites 'dim' for example using 'mapper' """ updates: dict = {} - key_mappers.update(dict.fromkeys(var_kws, _get_axis_coord_single)) + + # allow multiple return values here. + # these are valid for .sel, .isel, .coarsen + key_mappers.update(dict.fromkeys(var_kws, _get_axis_coord)) + for key, mapper in key_mappers.items(): value = kwargs.get(key, None) + if value is not None: if isinstance(value, str): value = [value] if isinstance(value, dict): # this for things like isel where **kwargs captures things like T=5 - updates[key] = { - mapper(self._obj, k, False, k): v for k, v in value.items() - } + # .sel, .isel, .rolling + # Account for multiple names matching the key. + # e.g. .isel(X=5) → .isel(xi_rho=5, xi_u=5, xi_v=5, xi_psi=5) + # where xi_* have attrs["axis"] = "X" + updates[key] = ChainMap( + *[ + dict.fromkeys(mapper(self._obj, k, False, k), v) + for k, v in value.items() + ] + ) + elif value is Ellipsis: pass + else: # things like sum which have dim updates[key] = [mapper(self._obj, v, False, v) for v in value] diff --git a/cf_xarray/tests/test_accessor.py b/cf_xarray/tests/test_accessor.py index 287e7fe4..8e3dab30 100644 --- a/cf_xarray/tests/test_accessor.py +++ b/cf_xarray/tests/test_accessor.py @@ -90,6 +90,26 @@ def test_kwargs_methods(obj): assert_identical(expected, actual) +def test_kwargs_expand_key_to_multiple_keys(): + + ds = xr.Dataset() + ds.coords["x1"] = ("x1", range(30), {"axis": "X"}) + ds.coords["y1"] = ("y1", range(20), {"axis": "Y"}) + ds.coords["x2"] = ("x2", range(10), {"axis": "X"}) + ds.coords["y2"] = ("y2", range(5), {"axis": "Y"}) + + ds["v1"] = (("x1", "y1"), np.ones((30, 20)) * 15) + ds["v2"] = (("x2", "y2"), np.ones((10, 5)) * 15) + + actual = ds.cf.isel(X=5, Y=10) + expected = ds.isel(x1=5, y1=10, x2=5, y2=10) + assert_identical(actual, expected) + + actual = ds.cf.coarsen(X=10, Y=5) + expected = ds.coarsen(x1=10, y1=5, x2=10, y2=5) + assert_identical(actual.mean(), expected.mean()) + + @pytest.mark.parametrize("obj", objects) def test_args_methods(obj): with raise_if_dask_computes():