From 889feaa5c72dbe71f6df438eb4d12b4f0d9b653f Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Wed, 25 Sep 2024 21:41:10 +0200 Subject: [PATCH] more places to simplify --- .../_array_api/_searching_functions.py | 16 ++- .../_array_api/_utility_functions.py | 20 ++-- xarray/namedarray/_array_api/_utils.py | 100 ++++++------------ 3 files changed, 50 insertions(+), 86 deletions(-) diff --git a/xarray/namedarray/_array_api/_searching_functions.py b/xarray/namedarray/_array_api/_searching_functions.py index 38e851bf5f8..755ee16b588 100644 --- a/xarray/namedarray/_array_api/_searching_functions.py +++ b/xarray/namedarray/_array_api/_searching_functions.py @@ -5,8 +5,8 @@ from xarray.namedarray._array_api._utils import ( _dim_to_optional_axis, _get_data_namespace, - _get_remaining_dims, _infer_dims, + _reduce_dims, ) from xarray.namedarray._typing import ( Default, @@ -32,10 +32,9 @@ def argmax( ) -> NamedArray[Any, Any]: xp = _get_data_namespace(x) _axis = _dim_to_optional_axis(x, dim, axis) - _data = xp.argmax(x._data, axis=_axis, keepdims=False) # We fix keepdims later - # TODO: Why do we need to do the keepdims ourselves? - _dims, data_ = _get_remaining_dims(x, _data, _axis, keepdims=keepdims) - return x._new(dims=_dims, data=data_) + _data = xp.argmax(x._data, axis=_axis, keepdims=keepdims) + _dims = _reduce_dims(x.dims, axis=_axis, keepdims=keepdims) + return x._new(dims=_dims, data=_data) def argmin( @@ -48,10 +47,9 @@ def argmin( ) -> NamedArray[Any, Any]: xp = _get_data_namespace(x) _axis = _dim_to_optional_axis(x, dim, axis) - _data = xp.argmin(x._data, axis=_axis, keepdims=False) # We fix keepdims later - # TODO: Why do we need to do the keepdims ourselves? - _dims, data_ = _get_remaining_dims(x, _data, _axis, keepdims=keepdims) - return x._new(dims=_dims, data=data_) + _data = xp.argmin(x._data, axis=_axis, keepdims=keepdims) + _dims = _reduce_dims(x.dims, axis=_axis, keepdims=keepdims) + return x._new(dims=_dims, data=_data) def nonzero(x: NamedArray[Any, Any], /) -> tuple[NamedArray[Any, Any], ...]: diff --git a/xarray/namedarray/_array_api/_utility_functions.py b/xarray/namedarray/_array_api/_utility_functions.py index 17cd5ad03c7..211046e04f2 100644 --- a/xarray/namedarray/_array_api/_utility_functions.py +++ b/xarray/namedarray/_array_api/_utility_functions.py @@ -5,7 +5,7 @@ from xarray.namedarray._array_api._utils import ( _dims_to_axis, _get_data_namespace, - _get_remaining_dims, + _reduce_dims, ) from xarray.namedarray._typing import ( Default, @@ -25,11 +25,10 @@ def all( axis: _AxisLike | None = None, ) -> NamedArray[Any, Any]: xp = _get_data_namespace(x) - axis_ = _dims_to_axis(x, dims, axis) - d = xp.all(x._data, axis=axis_, keepdims=False) - dims_, data_ = _get_remaining_dims(x, d, axis_, keepdims=keepdims) - out = x._new(dims=dims_, data=data_) - return out + _axis = _dims_to_axis(x, dims, axis) + _data = xp.all(x._data, axis=_axis, keepdims=keepdims) + _dims = _reduce_dims(x.dims, axis=_axis, keepdims=keepdims) + return x._new(dims=_dims, data=_data) def any( @@ -41,8 +40,7 @@ def any( axis: _AxisLike | None = None, ) -> NamedArray[Any, Any]: xp = _get_data_namespace(x) - axis_ = _dims_to_axis(x, dims, axis) - d = xp.any(x._data, axis=axis_, keepdims=False) - dims_, data_ = _get_remaining_dims(x, d, axis_, keepdims=keepdims) - out = x._new(dims=dims_, data=data_) - return out + _axis = _dims_to_axis(x, dims, axis) + _data = xp.any(x._data, axis=_axis, keepdims=keepdims) + _dims = _reduce_dims(x.dims, axis=_axis, keepdims=keepdims) + return x._new(dims=_dims, data=_data) diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index 03ecbe8954d..8257d471e7b 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -337,72 +337,6 @@ def _dim_to_axis(x: NamedArray[Any, Any], dim: _Dim | Default, axis: int) -> int return _axis -def _get_remaining_dims( - x: NamedArray[Any, _DType], - data: duckarray[Any, _DType], - axis: _AxisLike | None, - *, - keepdims: bool, -) -> tuple[_Dims, duckarray[Any, _DType]]: - """ - Get the reamining dims after a reduce operation. - """ - if data.shape == x.shape: - return x.dims, data - - removed_axes: tuple[int, ...] - if axis is None: - removed_axes = tuple(v for v in range(x.ndim)) - else: - removed_axes = _normalize_axis_tuple(axis, x.ndim) - - if keepdims: - # Insert None (aka newaxis) for removed dims - slices = tuple( - None if i in removed_axes else slice(None, None) for i in range(x.ndim) - ) - data = data[slices] - dims = x.dims - else: - dims = tuple(adim for n, adim in enumerate(x.dims) if n not in removed_axes) - - return dims, data - - -def _reduce_dims(dims: _Dims, *, axis: _AxisLike | None, keepdims: False) -> _Dims: - """ - Reduce dims according to axis. - - Examples - -------- - >>> _reduce_dims(("x", "y", "z"), axis=None, keepdims=False) - () - >>> _reduce_dims(("x", "y", "z"), axis=1, keepdims=False) - ('x', 'z') - >>> _reduce_dims(("x", "y", "z"), axis=-1, keepdims=False) - ('x', 'y') - - keepdims retains the same dims - - >>> _reduce_dims(("x", "y", "z"), axis=-1, keepdims=True) - ('x', 'y', 'z') - """ - if keepdims: - return dims - - ndim = len(dims) - if axis is None: - _axis = tuple(v for v in range(ndim)) - else: - _axis = _normalize_axis_tuple(axis, ndim) - - key = [slice(None)] * ndim - for i, v in enumerate(_axis): - key[v] = 0 - - return _dims_from_tuple_indexing(dims, tuple(key)) - - def _new_unique_dim_name(dims: _Dims, i: int | None = None) -> _Dim: """ Get a new unique dimension name. @@ -565,6 +499,40 @@ def _atleast1d_dims(dims: _Dims) -> _Dims: return (_new_unique_dim_name(dims),) if len(dims) < 1 else dims +def _reduce_dims(dims: _Dims, *, axis: _AxisLike | None, keepdims: False) -> _Dims: + """ + Reduce dims according to axis. + + Examples + -------- + >>> _reduce_dims(("x", "y", "z"), axis=None, keepdims=False) + () + >>> _reduce_dims(("x", "y", "z"), axis=1, keepdims=False) + ('x', 'z') + >>> _reduce_dims(("x", "y", "z"), axis=-1, keepdims=False) + ('x', 'y') + + keepdims retains the same dims + + >>> _reduce_dims(("x", "y", "z"), axis=-1, keepdims=True) + ('x', 'y', 'z') + """ + if keepdims: + return dims + + ndim = len(dims) + if axis is None: + _axis = tuple(v for v in range(ndim)) + else: + _axis = _normalize_axis_tuple(axis, ndim) + + key = [slice(None)] * ndim + for i, v in enumerate(_axis): + key[v] = 0 + + return _dims_from_tuple_indexing(dims, tuple(key)) + + def _raise_if_any_duplicate_dimensions( dims: _Dims, err_context: str = "This function" ) -> None: