Merged
Changes from 2 commits
129 changes: 79 additions & 50 deletions xarray/core/common.py
@@ -1,6 +1,8 @@
from collections import OrderedDict
from contextlib import suppress
from textwrap import dedent
from typing import (Callable, Iterable, Mapping, Optional, Tuple, TypeVar,
Sequence, Union, overload)

import numpy as np
import pandas as pd
@@ -15,6 +17,9 @@
ALL_DIMS = ReprObject('<all-dims>')


T = TypeVar('T')


class ImplementsArrayReduce(object):
@classmethod
def _reduce_method(cls, func, include_skipna, numeric_only):
@@ -115,7 +120,15 @@ def __iter__(self):
def T(self):
return self.transpose()

def get_axis_num(self, dim):
@overload
def get_axis_num(self, dim: str) -> int:
pass
Member
I think it's more idiomatic to use ... rather than pass when writing overloads?

Contributor Author
fixed
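
For reference, a minimal sketch of the `...` overload style being suggested (the class and placeholder dims below are purely illustrative, not the PR's code):

from typing import Iterable, Tuple, Union, overload

class Example:
    @overload
    def get_axis_num(self, dim: str) -> int:
        ...

    @overload
    def get_axis_num(self, dim: Iterable[str]) -> Tuple[int, ...]:
        ...

    def get_axis_num(self, dim: Union[str, Iterable[str]]) -> Union[int, Tuple[int, ...]]:
        # The stubs above exist only for the type checker; this single
        # implementation handles both call patterns at runtime.
        dims = ('x', 'y', 'z')  # stand-in for self.dims in this sketch
        if isinstance(dim, str):
            return dims.index(dim)
        return tuple(dims.index(d) for d in dim)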


@overload # noqa:F811
def get_axis_num(self, dim: Iterable[str]) -> Tuple[int]:
Member
Dimensions are actually allowed to be non-strings, so it might be better to keep this overload more generic, since that reflects how the code actually works:

Suggested change
def get_axis_num(self, dim: Iterable[str]) -> Tuple[int]:
def get_axis_num(self, dim: Iterable[Hashable]) -> Tuple[int, ...]:

(Also Tuple[int] should be Tuple[int, ...])

Contributor Author

If they were intended to be allowed to be non-strings, they sure aren't tested for it!

>>> xarray.DataArray([1,2], dims=[123])
TypeError: dimension 123 is not a string

Also I've never seen it mentioned in the documentation...

Contributor Author

fixed Tuple[int, ...]

Member

Hmm. It seems that you can only construct these sorts of names/dimensions with the Dataset API:

In [8]: xarray.Dataset({123: ((456,), [1, 2, 3])})
Out[8]:
<xarray.Dataset>
Dimensions:  (456: 3)
Dimensions without coordinates: 456
Data variables:
    123      (456) int64 1 2 3

pass

def get_axis_num(self, dim): # noqa:F811
"""Return axis number(s) corresponding to dimension(s) in this array.

Parameters
@@ -133,15 +146,15 @@ def get_axis_num(self, dim):
else:
return tuple(self._get_axis_num(d) for d in dim)
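
A quick interactive illustration of the two call patterns these overloads describe:

>>> import xarray
>>> arr = xarray.DataArray([[1, 2, 3], [4, 5, 6]], dims=('x', 'y'))
>>> arr.get_axis_num('y')          # str -> int
1
>>> arr.get_axis_num(['y', 'x'])   # iterable of str -> tuple of int
(1, 0)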

def _get_axis_num(self, dim):
def _get_axis_num(self, dim: str) -> int:
Member

Suggested change
def _get_axis_num(self, dim: str) -> int:
def _get_axis_num(self, dim: Hashable) -> int:

try:
return self.dims.index(dim)
except ValueError:
raise ValueError("%r not found in array dimensions %r" %
(dim, self.dims))

@property
def sizes(self):
def sizes(self) -> Frozen[str, int]:
Member

I don't think Frozen is a valid type annotation. Maybe better to stick with Mapping[str, int]?

Contributor Author

it is valid as I changed the Frozen class to be a subclass of typing.Mapping

Contributor Author

changed to Mapping[str, int] for the sake of user friendliness, as Frozen isn't really part of the public API
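
For context, a rough sketch of what a generic Frozen subclassing typing.Mapping looks like (simplified and illustrative, not the exact xarray implementation):

from typing import Iterator, Mapping, TypeVar

K = TypeVar('K')
V = TypeVar('V')

class Frozen(Mapping[K, V]):
    """Read-only view onto an underlying mapping."""

    def __init__(self, mapping: Mapping[K, V]):
        self.mapping = mapping

    def __getitem__(self, key: K) -> V:
        return self.mapping[key]

    def __iter__(self) -> Iterator[K]:
        return iter(self.mapping)

    def __len__(self) -> int:
        return len(self.mapping)

# With this base class, Frozen[str, int] becomes a valid annotation.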

"""Ordered mapping from dimension names to lengths.

Immutable.
@@ -214,36 +227,38 @@ def _ipython_key_completions_(self):
return list(set(item_lists))


def get_squeeze_dims(xarray_obj, dim, axis=None):
def get_squeeze_dims(xarray_obj, dim: Union[str, Tuple[str, ...], None] = None,
Member

Probably should be something more relaxed, perhaps dim: Union[Hashable, Iterable[Hashable]]

Contributor Author

The trouble with Iterable is that it doesn't necessarily allow you to iterate over it twice (e.g. a file handle). I've tweaked the code to allow for it.
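
To illustrate the concern, a hypothetical helper along those lines (not the PR's code, just a sketch of the "materialize first" idea):

from collections.abc import Iterable as IterableABC
from typing import Hashable, Iterable, List, Union

def normalize_dims(dim: Union[Hashable, Iterable[Hashable]]) -> List[Hashable]:
    # A generator or file object can only be consumed once, so turn any
    # non-string iterable into a list before it gets iterated over twice.
    if isinstance(dim, str) or not isinstance(dim, IterableABC):
        return [dim]
    return list(dim)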

axis: Union[int, Tuple[int, ...], None] = None
) -> Sequence[str]:
"""Get a list of dimensions to squeeze out.
"""
if dim is not None and axis is not None:
raise ValueError('cannot use both parameters `axis` and `dim`')

if dim is None and axis is None:
dim = [d for d, s in xarray_obj.sizes.items() if s == 1]
else:
if isinstance(dim, str):
dim = [dim]
if isinstance(axis, int):
axis = (axis, )
if isinstance(axis, tuple):
for a in axis:
if not isinstance(a, int):
raise ValueError(
'parameter `axis` must be int or tuple of int.')
alldims = list(xarray_obj.sizes.keys())
dim = [alldims[a] for a in axis]
if any(xarray_obj.sizes[k] > 1 for k in dim):
raise ValueError('cannot select a dimension to squeeze out '
'which has length greater than one')
return [d for d, s in xarray_obj.sizes.items() if s == 1]

elif isinstance(dim, str):
dim = [dim]
if isinstance(axis, int):
axis = (axis, )
if isinstance(axis, tuple):
if any(not isinstance(a, int) for a in axis):
raise ValueError('parameter `axis` must be int or tuple of int.')
alldims = list(xarray_obj.sizes.keys())
dim = [alldims[a] for a in axis]
if any(xarray_obj.sizes[k] > 1 for k in dim):
raise ValueError('cannot select a dimension to squeeze out '
'which has length greater than one')
return dim


class DataWithCoords(SupportsArithmetic, AttrAccessMixin):
"""Shared base class for Dataset and DataArray."""

def squeeze(self, dim=None, drop=False, axis=None):
def squeeze(self, dim: Union[str, Tuple[str, ...], None] = None,
drop: bool = False,
axis: Union[int, Tuple[int, ...], None] = None):
"""Return a new object with squeezed data.

Parameters
@@ -255,8 +270,8 @@ def squeeze(self, dim=None, drop=False, axis=None):
drop : bool, optional
If ``drop=True``, drop squeezed coordinates instead of making them
scalar.
axis : int, optional
Select the dimension to squeeze. Added for compatibility reasons.
axis : None or int or tuple of int, optional
Like dim, but positional.

Returns
-------
@@ -271,7 +286,7 @@ def squeeze(self, dim=None, drop=False, axis=None):
dims = get_squeeze_dims(self, dim, axis)
return self.isel(drop=drop, **{d: 0 for d in dims})
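
A short usage sketch of how dim and axis map onto each other here:

>>> import xarray
>>> arr = xarray.DataArray([[[1, 2, 3]]], dims=('x', 'y', 'z'))  # shape (1, 1, 3)
>>> arr.squeeze().dims           # drop every length-1 dimension
('z',)
>>> arr.squeeze(dim='x').dims    # drop only 'x'
('y', 'z')
>>> arr.squeeze(axis=1).dims     # positional equivalent of dim='y'
('x', 'z')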

def get_index(self, key):
def get_index(self, key: str) -> pd.Index:
Member

key can actually be any hashable item here.

"""Get an index for a dimension, with fall-back to a default RangeIndex
"""
if key not in self.dims:
@@ -283,7 +298,7 @@ def get_index(self, key):
# need to ensure dtype=int64 in case range is empty on Python 2
return pd.Index(range(self.sizes[key]), name=key, dtype=np.int64)

def _calc_assign_results(self, kwargs):
def _calc_assign_results(self, kwargs) -> SortedKeysDict:
results = SortedKeysDict()
for k, v in kwargs.items():
if callable(v):
@@ -372,7 +387,8 @@ def assign_attrs(self, *args, **kwargs):
out.attrs.update(*args, **kwargs)
return out

def pipe(self, func, *args, **kwargs):
def pipe(self, func: Union[Callable[..., T], Tuple[Callable[..., T], str]],
*args, **kwargs) -> T:
"""
Apply func(self, *args, **kwargs)

@@ -424,15 +440,14 @@ def pipe(self, func, *args, **kwargs):
if isinstance(func, tuple):
func, target = func
if target in kwargs:
msg = ('%s is both the pipe target and a keyword argument'
% target)
raise ValueError(msg)
raise ValueError('%s is both the pipe target and a keyword '
'argument' % target)
kwargs[target] = self
return func(*args, **kwargs)
else:
return func(self, *args, **kwargs)
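
A small usage sketch of both call forms (the scale helper is made up for illustration):

>>> import xarray
>>> def scale(data, factor):
...     return data * factor
>>> arr = xarray.DataArray([1, 2, 3], dims='x')
>>> arr.pipe(scale, 10)                   # arr is passed as the first positional argument
>>> arr.pipe((scale, 'data'), factor=10)  # arr is passed as the keyword argument 'data'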

def groupby(self, group, squeeze=True):
def groupby(self, group, squeeze: bool = True):
"""Returns a GroupBy object for performing grouped operations.

Parameters
@@ -478,8 +493,9 @@ def groupby(self, group, squeeze=True):
""" # noqa
return self._groupby_cls(self, group, squeeze=squeeze)

def groupby_bins(self, group, bins, right=True, labels=None, precision=3,
include_lowest=False, squeeze=True):
def groupby_bins(self, group, bins, right: bool = True, labels=None,
precision: int = 3, include_lowest: bool = False,
squeeze: bool = True):
"""Returns a GroupBy object for performing grouped operations.

Rather than using all unique values of `group`, the values are discretized
@@ -530,7 +546,9 @@ def groupby_bins(self, group, bins, right=True, labels=None, precision=3,
'precision': precision,
'include_lowest': include_lowest})

def rolling(self, dim=None, min_periods=None, center=False, **dim_kwargs):
def rolling(self, dim: Optional[Mapping[str, int]] = None,
min_periods: Optional[int] = None, center: bool = False,
**dim_kwargs: int):
"""
Rolling window object.

@@ -590,8 +608,11 @@ def rolling(self, dim=None, min_periods=None, center=False, **dim_kwargs):
return self._rolling_cls(self, dim, min_periods=min_periods,
center=center)
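
A usage sketch of the two equivalent ways of passing the window, per the annotated signature:

>>> import xarray
>>> arr = xarray.DataArray(range(10), dims='time')
>>> arr.rolling(time=3, center=True).mean()           # via **dim_kwargs
>>> arr.rolling(dim={'time': 3}, center=True).mean()  # via the dim mapping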

def coarsen(self, dim=None, boundary='exact', side='left',
coord_func='mean', **dim_kwargs):
def coarsen(self, dim: Optional[Mapping[str, int]] = None,
boundary: str = 'exact',
side: Union[str, Mapping[str, str]] = 'left',
coord_func: str = 'mean',
Member

It looks like coord_func is more flexible than this; it should probably be Union[str, Mapping[Any, str]].

Contributor Author

didn't you say that all dims should be Hashable?

Member

I think Mapping[Any, str] and Mapping[Hashable, str] are equivalent -- every mapping key must be hashable.

Contributor Author

No, there's absolutely nothing in Mapping saying that the keys must be hashable.

This is perfectly legit, and pretty useful too:

class AnyKeyDict(MutableMapping[K, V]):
    def __init__(self, *args: Tuple[K, V], **kwargs: V):
        self.mapping = {}
        for k, v in args:
            self[k] = v
        for k, v in kwargs.items():
            self[k] = v

    def __setitem__(self, k: K, v: V) -> None:
        try:
            self.mapping[k] = v
        except TypeError:
            self.mapping[pickle.dumps(k)] = v

    ...

This can be useful in a few borderline scenarios; for example, when de-serialising msgpack data from a non-Python producer, since in msgpack map keys can be anything. Also I'd love to see something like this:

class DataArray:
    def __setitem__(self, k, v):
        if isinstance(k, Mapping):
            view = self.sel(**k)
        else:
            view = self.isel(**k)
        view[:] = v

Member

OK, good point.

It turns out we actually do already support this for DataArray :)

>>> data = xarray.DataArray([0, 0, 0], dims='x')
>>> data[{'x': 1}] = 1
>>> data
<xarray.DataArray (x: 3)>
array([0, 1, 0])
Dimensions without coordinates: x

**dim_kwargs: int):
"""
Coarsen object.

@@ -650,8 +671,12 @@ def coarsen(self, dim=None, boundary='exact', side='left',
self, dim, boundary=boundary, side=side,
coord_func=coord_func)

def resample(self, indexer=None, skipna=None, closed=None, label=None,
base=0, keep_attrs=None, loffset=None, **indexer_kwargs):
def resample(self, indexer: Optional[Mapping[str, str]] = None,
skipna=None, closed: Optional[str] = None,
label: Optional[str] = None,
base: int = 0, keep_attrs: Optional[bool] = None,
loffset=None,
**indexer_kwargs: str):
"""Returns a Resample object for performing resampling operations.

Handles both downsampling and upsampling. If any intervals contain no
@@ -772,7 +797,7 @@ def resample(self, indexer=None, skipna=None, closed=None, label=None,

return resampler

def where(self, cond, other=dtypes.NA, drop=False):
def where(self, cond, other=dtypes.NA, drop: bool = False):
"""Filter elements from this object according to a condition.

This operation follows the normal broadcasting and alignment rules that
@@ -858,7 +883,7 @@ def where(self, cond, other=dtypes.NA, drop=False):

return ops.where_method(self, cond, other)
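
A quick illustration of the variants controlled by other and drop:

>>> import xarray
>>> arr = xarray.DataArray([1, 2, 3, 4], dims='x')
>>> arr.where(arr > 2)             # non-matching values become NaN
>>> arr.where(arr > 2, other=0)    # ... or are replaced by `other`
>>> arr.where(arr > 2, drop=True)  # ... or are dropped where possible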

def close(self):
def close(self) -> None:
"""Close any files linked to this object
"""
if self._file_obj is not None:
@@ -921,7 +946,7 @@ def __exit__(self, exc_type, exc_value, traceback):
self.close()


def full_like(other, fill_value, dtype=None):
def full_like(other, fill_value, dtype: Union[str, np.dtype, None] = None):
"""Return a new object with the same shape and type as a given object.

Parameters
Expand Down Expand Up @@ -961,7 +986,8 @@ def full_like(other, fill_value, dtype=None):
raise TypeError("Expected DataArray, Dataset, or Variable")


def _full_like_variable(other, fill_value, dtype=None):
def _full_like_variable(other, fill_value,
dtype: Union[str, np.dtype, None] = None):
"""Inner function of full_like, where other must be a variable
"""
from .variable import Variable
@@ -978,27 +1004,28 @@ def _full_like_variable(other, fill_value, dtype=None):
return Variable(dims=other.dims, data=data, attrs=other.attrs)


def zeros_like(other, dtype=None):
def zeros_like(other, dtype: Union[str, np.dtype, None] = None):
"""Shorthand for full_like(other, 0, dtype)
"""
return full_like(other, 0, dtype)


def ones_like(other, dtype=None):
def ones_like(other, dtype: Union[str, np.dtype, None] = None):
"""Shorthand for full_like(other, 1, dtype)
"""
return full_like(other, 1, dtype)
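
A usage sketch of the dtype argument these signatures annotate (a string, a numpy dtype, or omitted):

>>> import numpy as np
>>> import xarray
>>> arr = xarray.DataArray([1, 2, 3], dims='x')
>>> xarray.full_like(arr, fill_value=9)             # keeps arr's dtype
>>> xarray.zeros_like(arr, dtype='float64')         # dtype as a str
>>> xarray.ones_like(arr, dtype=np.dtype('int32'))  # dtype as an np.dtype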


def is_np_datetime_like(dtype):
def is_np_datetime_like(dtype: Union[str, np.dtype]) -> bool:
"""Check if a dtype is a subclass of the numpy datetime types
"""
return (np.issubdtype(dtype, np.datetime64) or
np.issubdtype(dtype, np.timedelta64))
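
For example:

>>> import numpy as np
>>> is_np_datetime_like(np.dtype('datetime64[ns]'))
True
>>> is_np_datetime_like(np.dtype('float64'))
False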


def _contains_cftime_datetimes(array):
"""Check if an array contains cftime.datetime objects"""
def _contains_cftime_datetimes(array) -> bool:
"""Check if an array contains cftime.datetime objects
"""
try:
from cftime import datetime as cftime_datetime
except ImportError:
@@ -1015,12 +1042,14 @@ def _contains_cftime_datetimes(array):
return False


def contains_cftime_datetimes(var):
"""Check if an xarray.Variable contains cftime.datetime objects"""
def contains_cftime_datetimes(var) -> bool:
"""Check if an xarray.Variable contains cftime.datetime objects
"""
return _contains_cftime_datetimes(var.data)


def _contains_datetime_like_objects(var):
def _contains_datetime_like_objects(var) -> bool:
"""Check if a variable contains datetime like objects (either
np.datetime64, np.timedelta64, or cftime.datetime)"""
np.datetime64, np.timedelta64, or cftime.datetime)
"""
return is_np_datetime_like(var.dtype) or contains_cftime_datetimes(var)