Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use ._replace rather than reconstructing vars #5181

Merged
merged 2 commits into from
Apr 18, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 15 additions & 11 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,7 +798,7 @@ def __getitem__(self: VariableType, key) -> VariableType:

def _finalize_indexing_result(self: VariableType, dims, data) -> VariableType:
"""Used by IndexVariable to return IndexVariable objects when possible."""
return type(self)(dims, data, self._attrs, self._encoding, fastpath=True)
return self._replace(dims=dims, data=data)

def _getitem_with_mask(self, key, fill_value=dtypes.NA):
"""Index this Variable with -1 remapped to fill_value."""
Expand Down Expand Up @@ -977,8 +977,12 @@ def copy(self, deep=True, data=None):
return self._replace(data=data)

def _replace(
self, dims=_default, data=_default, attrs=_default, encoding=_default
) -> "Variable":
self: VariableType,
dims=_default,
data=_default,
attrs=_default,
encoding=_default,
) -> VariableType:
if dims is _default:
dims = copy.copy(self._dims)
if data is _default:
Expand Down Expand Up @@ -1081,7 +1085,7 @@ def chunk(self, chunks={}, name=None, lock=False):

data = da.from_array(data, chunks, name=name, lock=lock, **kwargs)

return type(self)(self.dims, data, self._attrs, self._encoding, fastpath=True)
return self._replace(data=data)

def _as_sparse(self, sparse_format=_default, fill_value=dtypes.NA):
"""
Expand Down Expand Up @@ -1205,7 +1209,7 @@ def _shift_one_dim(self, dim, count, fill_value=dtypes.NA):
# TODO: remove this once dask.array automatically aligns chunks
data = data.rechunk(self.data.chunks)

return type(self)(self.dims, data, self._attrs, fastpath=True)
return self._replace(data=data)

def shift(self, shifts=None, fill_value=dtypes.NA, **shifts_kwargs):
"""
Expand Down Expand Up @@ -1364,7 +1368,7 @@ def _roll_one_dim(self, dim, count):
# TODO: remove this once dask.array automatically aligns chunks
data = data.rechunk(self.data.chunks)

return type(self)(self.dims, data, self._attrs, fastpath=True)
return self._replace(data=data)

def roll(self, shifts=None, **shifts_kwargs):
"""
Expand Down Expand Up @@ -1426,7 +1430,7 @@ def transpose(self, *dims) -> "Variable":
return self.copy(deep=False)

data = as_indexable(self._data).transpose(axes)
return type(self)(dims, data, self._attrs, self._encoding, fastpath=True)
return self._replace(dims=dims, data=data)

@property
def T(self) -> "Variable":
Expand Down Expand Up @@ -2276,11 +2280,11 @@ def notnull(self, keep_attrs: bool = None):

@property
def real(self):
return type(self)(self.dims, self.data.real, self._attrs)
return self._replace(data=self.data.real)

@property
def imag(self):
return type(self)(self.dims, self.data.imag, self._attrs)
return self._replace(data=self.data.imag)

def __array_wrap__(self, obj, context=None):
return Variable(self.dims, obj)
Expand Down Expand Up @@ -2555,7 +2559,7 @@ def _finalize_indexing_result(self, dims, data):
# returns Variable rather than IndexVariable if multi-dimensional
return Variable(dims, data, self._attrs, self._encoding)
else:
return type(self)(dims, data, self._attrs, self._encoding, fastpath=True)
return self._replace(dims=dims, data=data)

def __setitem__(self, key, value):
raise TypeError("%s values cannot be modified" % type(self).__name__)
Expand Down Expand Up @@ -2636,7 +2640,7 @@ def copy(self, deep=True, data=None):
data.shape, self.shape
)
)
return type(self)(self.dims, data, self._attrs, self._encoding, fastpath=True)
return self._replace(data=data)

def equals(self, other, equiv=None):
# if equiv is specified, super up
Expand Down