diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 9ac58ea6534..afd493d2240 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -30,6 +30,10 @@ New Features By `Martin Raspaud `_. - Improved static typing of reduction methods (:pull:`6746`). By `Richard Kleijn `_. +- Added `on_missing_core_dims` to :py:meth:`apply_ufunc` to allow for copying or + dropping a :py:class:`Dataset`'s variables with missing core dimensions. + (:pull:`8138`) + By `Maximilian Roos `_. Breaking changes ~~~~~~~~~~~~~~~~ diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 2068cdcadd2..6dfff2a24bf 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -32,6 +32,8 @@ from xarray.core.dataset import Dataset from xarray.core.types import CombineAttrsOptions, JoinOptions + MissingCoreDimOptions = Literal["raise", "copy", "drop"] + _NO_FILL_VALUE = utils.ReprObject("") _DEFAULT_NAME = utils.ReprObject("") _JOINS_WITHOUT_FILL_VALUES = frozenset({"inner", "exact"}) @@ -42,6 +44,7 @@ def _first_of_type(args, kind): for arg in args: if isinstance(arg, kind): return arg + raise ValueError("This should be unreachable.") @@ -347,7 +350,7 @@ def assert_and_return_exact_match(all_keys): if keys != first_keys: raise ValueError( "exact match required for all data variable names, " - f"but {keys!r} != {first_keys!r}" + f"but {list(keys)} != {list(first_keys)}: {set(keys) ^ set(first_keys)} are not in both." ) return first_keys @@ -376,7 +379,7 @@ def collect_dict_values( ] -def _as_variables_or_variable(arg): +def _as_variables_or_variable(arg) -> Variable | tuple[Variable]: try: return arg.variables except AttributeError: @@ -396,8 +399,39 @@ def _unpack_dict_tuples( return out +def _check_core_dims(signature, variable_args, name): + """ + Chcek if an arg has all the core dims required by the signature. + + Slightly awkward design, of returning the error message. But we want to + give a detailed error message, which requires inspecting the variable in + the inner loop. + """ + missing = [] + for i, (core_dims, variable_arg) in enumerate( + zip(signature.input_core_dims, variable_args) + ): + # Check whether all the dims are on the variable. Note that we need the + # `hasattr` to check for a dims property, to protect against the case where + # a numpy array is passed in. + if hasattr(variable_arg, "dims") and set(core_dims) - set(variable_arg.dims): + missing += [[i, variable_arg, core_dims]] + if missing: + message = "" + for i, variable_arg, core_dims in missing: + message += f"Missing core dims {set(core_dims) - set(variable_arg.dims)} from arg number {i + 1} on a variable named `{name}`:\n{variable_arg}\n\n" + message += "Either add the core dimension, or if passing a dataset alternatively pass `on_missing_core_dim` as `copy` or `drop`. " + return message + return True + + def apply_dict_of_variables_vfunc( - func, *args, signature: _UFuncSignature, join="inner", fill_value=None + func, + *args, + signature: _UFuncSignature, + join="inner", + fill_value=None, + on_missing_core_dim: MissingCoreDimOptions = "raise", ): """Apply a variable level function over dicts of DataArray, DataArray, Variable and ndarray objects. @@ -408,7 +442,20 @@ def apply_dict_of_variables_vfunc( result_vars = {} for name, variable_args in zip(names, grouped_by_name): - result_vars[name] = func(*variable_args) + core_dim_present = _check_core_dims(signature, variable_args, name) + if core_dim_present is True: + result_vars[name] = func(*variable_args) + else: + if on_missing_core_dim == "raise": + raise ValueError(core_dim_present) + elif on_missing_core_dim == "copy": + result_vars[name] = variable_args[0] + elif on_missing_core_dim == "drop": + pass + else: + raise ValueError( + f"Invalid value for `on_missing_core_dim`: {on_missing_core_dim!r}" + ) if signature.num_outputs > 1: return _unpack_dict_tuples(result_vars, signature.num_outputs) @@ -441,6 +488,7 @@ def apply_dataset_vfunc( fill_value=_NO_FILL_VALUE, exclude_dims=frozenset(), keep_attrs="override", + on_missing_core_dim: MissingCoreDimOptions = "raise", ) -> Dataset | tuple[Dataset, ...]: """Apply a variable level function over Dataset, dict of DataArray, DataArray, Variable and/or ndarray objects. @@ -467,7 +515,12 @@ def apply_dataset_vfunc( args = tuple(getattr(arg, "data_vars", arg) for arg in args) result_vars = apply_dict_of_variables_vfunc( - func, *args, signature=signature, join=dataset_join, fill_value=fill_value + func, + *args, + signature=signature, + join=dataset_join, + fill_value=fill_value, + on_missing_core_dim=on_missing_core_dim, ) out: Dataset | tuple[Dataset, ...] @@ -595,17 +648,9 @@ def broadcast_compat_data( return data set_old_dims = set(old_dims) - missing_core_dims = [d for d in core_dims if d not in set_old_dims] - if missing_core_dims: - raise ValueError( - "operand to apply_ufunc has required core dimensions {}, but " - "some of these dimensions are absent on an input variable: {}".format( - list(core_dims), missing_core_dims - ) - ) - set_new_dims = set(new_dims) unexpected_dims = [d for d in old_dims if d not in set_new_dims] + if unexpected_dims: raise ValueError( "operand to apply_ufunc encountered unexpected " @@ -851,6 +896,7 @@ def apply_ufunc( output_sizes: Mapping[Any, int] | None = None, meta: Any = None, dask_gufunc_kwargs: dict[str, Any] | None = None, + on_missing_core_dim: MissingCoreDimOptions = "raise", ) -> Any: """Apply a vectorized function for unlabeled arrays on xarray objects. @@ -964,6 +1010,8 @@ def apply_ufunc( :py:func:`dask.array.apply_gufunc`. ``meta`` should be given in the ``dask_gufunc_kwargs`` parameter . It will be removed as direct parameter a future version. + on_missing_core_dim : {"raise", "copy", "drop"}, default: "raise" + How to handle missing core dimensions on input variables. Returns ------- @@ -1192,6 +1240,7 @@ def apply_ufunc( dataset_join=dataset_join, fill_value=dataset_fill_value, keep_attrs=keep_attrs, + on_missing_core_dim=on_missing_core_dim, ) # feed DataArray apply_variable_ufunc through apply_dataarray_vfunc elif any(isinstance(a, DataArray) for a in args): diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index 372998a66bc..052672efb32 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -257,6 +257,168 @@ def func(x): assert_identical(out1, dataset) +def test_apply_missing_dims() -> None: + ## Single arg + + def add_one(a, core_dims, on_missing_core_dim): + return apply_ufunc( + lambda x: x + 1, + a, + input_core_dims=core_dims, + output_core_dims=core_dims, + on_missing_core_dim=on_missing_core_dim, + ) + + array = np.arange(6).reshape(2, 3) + variable = xr.Variable(["x", "y"], array) + variable_no_y = xr.Variable(["x", "z"], array) + + ds = xr.Dataset({"x_y": variable, "x_z": variable_no_y}) + + # Check the standard stuff works OK + assert_identical( + add_one(ds[["x_y"]], core_dims=[["y"]], on_missing_core_dim="raise"), + ds[["x_y"]] + 1, + ) + + # `raise` — should raise on a missing dim + with pytest.raises(ValueError): + add_one(ds, core_dims=[["y"]], on_missing_core_dim="raise"), + + # `drop` — should drop the var with the missing dim + assert_identical( + add_one(ds, core_dims=[["y"]], on_missing_core_dim="drop"), + (ds + 1).drop_vars("x_z"), + ) + + # `copy` — should not add one to the missing with `copy` + copy_result = add_one(ds, core_dims=[["y"]], on_missing_core_dim="copy") + assert_identical(copy_result["x_y"], (ds + 1)["x_y"]) + assert_identical(copy_result["x_z"], ds["x_z"]) + + ## Multiple args + + def sum_add(a, b, core_dims, on_missing_core_dim): + return apply_ufunc( + lambda a, b, axis=None: a.sum(axis) + b.sum(axis), + a, + b, + input_core_dims=core_dims, + on_missing_core_dim=on_missing_core_dim, + ) + + # Check the standard stuff works OK + assert_identical( + sum_add( + ds[["x_y"]], + ds[["x_y"]], + core_dims=[["x", "y"], ["x", "y"]], + on_missing_core_dim="raise", + ), + ds[["x_y"]].sum() * 2, + ) + + # `raise` — should raise on a missing dim + with pytest.raises( + ValueError, + match=r".*Missing core dims \{'y'\} from arg number 1 on a variable named `x_z`:\n.* None: data_array = xr.DataArray([[0, 1, 2], [1, 2, 3]], dims=("x", "y"))