Skip to content

Commit

Permalink
Add drop=True argument to isel, sel and squeeze (#1153)
Browse files Browse the repository at this point in the history
* Add drop=True argument to isel, sel and squeeze

Fixes GH242

This is useful for getting rid of extraneous scalar variables that arise from
indexing, and in particular will resolve an issue for optional indexes:
#1017 (comment)

* More tests for Dataset.squeeze(drop=True)

* Add two more tests, for drop=True without coords
  • Loading branch information
shoyer authored Dec 16, 2016
1 parent 260674d commit 89a6732
Show file tree
Hide file tree
Showing 7 changed files with 155 additions and 27 deletions.
6 changes: 6 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,12 @@ Enhancements
providing consistent access to dimension length on both ``Dataset`` and
``DataArray`` (:issue:`921`).
By `Stephan Hoyer <https://github.com/shoyer>`_.
- New keyword argument ``drop=True`` for :py:meth:`~DataArray.sel`,
:py:meth:`~DataArray.isel` and :py:meth:`~DataArray.squeeze` for dropping
scalar coordinates that arise from indexing.
``DataArray`` (:issue:`242`).
By `Stephan Hoyer <https://github.com/shoyer>`_.

- New top-level functions :py:func:`~xarray.full_like`,
:py:func:`~xarray.zeros_like`, and :py:func:`~xarray.ones_like`
By `Guido Imperiale <https://github.com/crusaderky>`_.
Expand Down
38 changes: 22 additions & 16 deletions xarray/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,10 +249,24 @@ def __dir__(self):
return sorted(set(dir(type(self)) + extra_attrs))


class SharedMethodsMixin(object):
"""Shared methods for Dataset, DataArray and Variable."""
def get_squeeze_dims(xarray_obj, dim):
"""Get a list of dimensions to squeeze out.
"""
if dim is None:
dim = [d for d, s in xarray_obj.sizes.items() if s == 1]
else:
if isinstance(dim, basestring):
dim = [dim]
if any(xarray_obj.sizes[k] > 1 for k in dim):
raise ValueError('cannot select a dimension to squeeze out '
'which has length greater than one')
return dim


class BaseDataObject(AttrAccessMixin):
"""Shared base class for Dataset and DataArray."""

def squeeze(self, dim=None):
def squeeze(self, dim=None, drop=False):
"""Return a new object with squeezed data.
Parameters
Expand All @@ -261,6 +275,9 @@ def squeeze(self, dim=None):
Selects a subset of the length one dimensions. If a dimension is
selected with length greater than one, an error is raised. If
None, all length one dimensions are squeezed.
drop : bool, optional
If ``drop=True``, drop squeezed coordinates instead of making them
scalar.
Returns
-------
Expand All @@ -272,19 +289,8 @@ def squeeze(self, dim=None):
--------
numpy.squeeze
"""
if dim is None:
dim = [d for d, s in self.sizes.items() if s == 1]
else:
if isinstance(dim, basestring):
dim = [dim]
if any(self.sizes[k] > 1 for k in dim):
raise ValueError('cannot select a dimension to squeeze out '
'which has length greater than one')
return self.isel(**{d: 0 for d in dim})


class BaseDataObject(SharedMethodsMixin, AttrAccessMixin):
"""Shared base class for Dataset and DataArray."""
dims = get_squeeze_dims(self, dim)
return self.isel(drop=drop, **{d: 0 for d in dims})

def get_index(self, key):
"""Get an index for a dimension, with fall-back to a default RangeIndex
Expand Down
9 changes: 5 additions & 4 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,7 +645,7 @@ def chunk(self, chunks=None):
ds = self._to_temp_dataset().chunk(chunks)
return self._from_temp_dataset(ds)

def isel(self, **indexers):
def isel(self, drop=False, **indexers):
"""Return a new DataArray whose dataset is given by integer indexing
along the specified dimension(s).
Expand All @@ -654,10 +654,10 @@ def isel(self, **indexers):
Dataset.isel
DataArray.sel
"""
ds = self._to_temp_dataset().isel(**indexers)
ds = self._to_temp_dataset().isel(drop=drop, **indexers)
return self._from_temp_dataset(ds)

def sel(self, method=None, tolerance=None, **indexers):
def sel(self, method=None, tolerance=None, drop=False, **indexers):
"""Return a new DataArray whose dataset is given by selecting
index labels along the specified dimension(s).
Expand All @@ -669,7 +669,8 @@ def sel(self, method=None, tolerance=None, **indexers):
pos_indexers, new_indexes = indexing.remap_label_indexers(
self, indexers, method=method, tolerance=tolerance
)
return self.isel(**pos_indexers)._replace_indexes(new_indexes)
result = self.isel(drop=drop, **pos_indexers)
return result._replace_indexes(new_indexes)

def isel_points(self, dim='points', **indexers):
"""Return a new DataArray whose dataset is given by pointwise integer
Expand Down
20 changes: 15 additions & 5 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -880,7 +880,7 @@ def maybe_chunk(name, var, chunks):
for k, v in self.variables.items()])
return self._replace_vars_and_dims(variables)

def isel(self, **indexers):
def isel(self, drop=False, **indexers):
"""Returns a new dataset with each array indexed along the specified
dimension(s).
Expand All @@ -890,6 +890,9 @@ def isel(self, **indexers):
Parameters
----------
drop : bool, optional
If ``drop=True``, drop coordinates variables indexed by integers
instead of making them scalar.
**indexers : {dim: indexer, ...}
Keyword arguments with names matching dimensions and values given
by integers, slice objects or arrays.
Expand Down Expand Up @@ -923,10 +926,13 @@ def isel(self, **indexers):
variables = OrderedDict()
for name, var in iteritems(self._variables):
var_indexers = dict((k, v) for k, v in indexers if k in var.dims)
variables[name] = var.isel(**var_indexers)
return self._replace_vars_and_dims(variables)
new_var = var.isel(**var_indexers)
if not (drop and name in var_indexers):
variables[name] = new_var
coord_names = set(self._coord_names) & set(variables)
return self._replace_vars_and_dims(variables, coord_names=coord_names)

def sel(self, method=None, tolerance=None, **indexers):
def sel(self, method=None, tolerance=None, drop=False, **indexers):
"""Returns a new dataset with each array indexed by tick labels
along the specified dimension(s).
Expand Down Expand Up @@ -957,6 +963,9 @@ def sel(self, method=None, tolerance=None, **indexers):
matches. The values of the index at the matching locations most
satisfy the equation ``abs(index[indexer] - target) <= tolerance``.
Requires pandas>=0.17.
drop : bool, optional
If ``drop=True``, drop coordinates variables in `indexers` instead
of making them scalar.
**indexers : {dim: indexer, ...}
Keyword arguments with names matching dimensions and values given
by scalars, slices or arrays of tick labels. For dimensions with
Expand All @@ -982,7 +991,8 @@ def sel(self, method=None, tolerance=None, **indexers):
pos_indexers, new_indexes = indexing.remap_label_indexers(
self, indexers, method=method, tolerance=tolerance
)
return self.isel(**pos_indexers)._replace_indexes(new_indexes)
result = self.isel(drop=drop, **pos_indexers)
return result._replace_indexes(new_indexes)

def isel_points(self, dim='points', **indexers):
"""Returns a new dataset with each array indexed pointwise along the
Expand Down
26 changes: 24 additions & 2 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,7 @@ def _as_array_or_item(data):
return data


class Variable(common.AbstractArray, common.SharedMethodsMixin,
utils.NdimSizeLenMixin):
class Variable(common.AbstractArray, utils.NdimSizeLenMixin):

"""A netcdf-like variable consisting of dimensions, data and attributes
which describe a single Array. A single Variable object is not fully
Expand Down Expand Up @@ -544,6 +543,29 @@ def isel(self, **indexers):
key[i] = indexers[dim]
return self[tuple(key)]

def squeeze(self, dim=None):
"""Return a new object with squeezed data.
Parameters
----------
dim : None or str or tuple of str, optional
Selects a subset of the length one dimensions. If a dimension is
selected with length greater than one, an error is raised. If
None, all length one dimensions are squeezed.
Returns
-------
squeezed : same type as caller
This object, but with with all or a subset of the dimensions of
length 1 removed.
See Also
--------
numpy.squeeze
"""
dims = common.get_squeeze_dims(self, dim)
return self.isel(**{d: 0 for d in dims})

def _shift_one_dim(self, dim, count):
axis = self.get_axis_num(dim)

Expand Down
35 changes: 35 additions & 0 deletions xarray/test/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,31 @@ def test_sel_method(self):
with self.assertRaisesRegexp(TypeError, 'tolerance'):
data.sel(x=[0.9, 1.9], method='backfill', tolerance=1)

def test_sel_drop(self):
data = DataArray([1, 2, 3], [('x', [0, 1, 2])])
expected = DataArray(1)
selected = data.sel(x=0, drop=True)
self.assertDataArrayIdentical(expected, selected)

expected = DataArray(1, {'x': 0})
selected = data.sel(x=0, drop=False)
self.assertDataArrayIdentical(expected, selected)

data = DataArray([1, 2, 3], dims=['x'])
expected = DataArray(1)
selected = data.sel(x=0, drop=True)
self.assertDataArrayIdentical(expected, selected)

def test_isel_drop(self):
data = DataArray([1, 2, 3], [('x', [0, 1, 2])])
expected = DataArray(1)
selected = data.isel(x=0, drop=True)
self.assertDataArrayIdentical(expected, selected)

expected = DataArray(1, {'x': 0})
selected = data.isel(x=0, drop=False)
self.assertDataArrayIdentical(expected, selected)

def test_isel_points(self):
shape = (10, 5, 6)
np_array = np.random.random(shape)
Expand Down Expand Up @@ -1117,6 +1142,16 @@ def test_transpose(self):
def test_squeeze(self):
self.assertVariableEqual(self.dv.variable.squeeze(), self.dv.squeeze())

def test_squeeze_drop(self):
array = DataArray([1], [('x', [0])])
expected = DataArray(1)
actual = array.squeeze(drop=True)
self.assertDataArrayIdentical(expected, actual)

expected = DataArray(1, {'x': 0})
actual = array.squeeze(drop=False)
self.assertDataArrayIdentical(expected, actual)

def test_drop_coordinates(self):
expected = DataArray(np.random.randn(2, 3), dims=['x', 'y'])
arr = expected.copy()
Expand Down
48 changes: 48 additions & 0 deletions xarray/test/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,6 +824,31 @@ def test_sel(self):
self.assertDatasetEqual(data.isel(td=slice(1, 3)),
data.sel(td=slice('1 days', '2 days')))

def test_sel_drop(self):
data = Dataset({'foo': ('x', [1, 2, 3])}, {'x': [0, 1, 2]})
expected = Dataset({'foo': 1})
selected = data.sel(x=0, drop=True)
self.assertDatasetIdentical(expected, selected)

expected = Dataset({'foo': 1}, {'x': 0})
selected = data.sel(x=0, drop=False)
self.assertDatasetIdentical(expected, selected)

data = Dataset({'foo': ('x', [1, 2, 3])})
expected = Dataset({'foo': 1})
selected = data.sel(x=0, drop=True)
self.assertDatasetIdentical(expected, selected)

def test_isel_drop(self):
data = Dataset({'foo': ('x', [1, 2, 3])}, {'x': [0, 1, 2]})
expected = Dataset({'foo': 1})
selected = data.isel(x=0, drop=True)
self.assertDatasetIdentical(expected, selected)

expected = Dataset({'foo': 1}, {'x': 0})
selected = data.isel(x=0, drop=False)
self.assertDatasetIdentical(expected, selected)

def test_isel_points(self):
data = create_test_data()

Expand Down Expand Up @@ -1793,6 +1818,29 @@ def get_args(v):
with self.assertRaisesRegexp(ValueError, 'cannot select a dimension'):
data.squeeze('y')

def test_squeeze_drop(self):
data = Dataset({'foo': ('x', [1])}, {'x': [0]})
expected = Dataset({'foo': 1})
selected = data.squeeze(drop=True)
self.assertDatasetIdentical(expected, selected)

expected = Dataset({'foo': 1}, {'x': 0})
selected = data.squeeze(drop=False)
self.assertDatasetIdentical(expected, selected)

data = Dataset({'foo': (('x', 'y'), [[1]])}, {'x': [0], 'y': [0]})
expected = Dataset({'foo': 1})
selected = data.squeeze(drop=True)
self.assertDatasetIdentical(expected, selected)

expected = Dataset({'foo': ('x', [1])}, {'x': [0]})
selected = data.squeeze(dim='y', drop=True)
self.assertDatasetIdentical(expected, selected)

data = Dataset({'foo': (('x',), [])}, {'x': []})
selected = data.squeeze(drop=True)
self.assertDatasetIdentical(data, selected)

def test_groupby(self):
data = Dataset({'z': (['x', 'y'], np.random.randn(3, 5))},
{'x': ('x', list('abc')),
Expand Down

0 comments on commit 89a6732

Please sign in to comment.