Skip to content

Commit 2e623a0

Browse files
committed
Allow expand_dims() method to support inserting/broadcasting dimensions with size>1 (pydata#2757)
* Make using dim_kwargs for python 3.5 illegal -- a ValueError is thrown * dataset.expand_dims() method take dict like object where values represent length of dimensions or coordinates of dimesnsions * dataarray.expand_dims() method take dict like object where values represent length of dimensions or coordinates of dimesnsions * Add alternative option to passing a dict to the dim argument, which is now an optional kwarg, passing in each new dimension as its own kwarg * Add expand_dims enhancement from issue 2710 to whats-new.rst * Fix test_dataarray.TestDataArray.test_expand_dims_with_greater_dim_size tests to pass in python 3.5 using ordered dicts instead of regular dicts. This was needed because python 3.5 and earlier did not maintain insertion order for dicts * Restrict core logic to use 'dim' as a dict--it will be converted into a dict on entry if it is a str or a sequence of str * Don't cast dim values (coords) as a list since IndexVariable/Variable will internally convert it into a numpy.ndarray. So just use IndexVariable((k,), v) * TypeErrors should be raised for invalid input types, rather than ValueErrors. * Force 'dim' to be OrderedDict for python 3.5
1 parent a74ecd6 commit 2e623a0

File tree

5 files changed

+216
-36
lines changed

5 files changed

+216
-36
lines changed

doc/whats-new.rst

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -111,28 +111,13 @@ Other enhancements
111111
See :ref:`compute.using_coordinates` for the detail.
112112
(:issue:`1332`)
113113
By `Keisuke Fujii <https://github.com/fujiisoup>`_.
114-
- :py:meth:`pandas.Series.dropna` is now supported for a
115-
:py:class:`pandas.Series` indexed by a :py:class:`~xarray.CFTimeIndex`
116-
(:issue:`2688`). By `Spencer Clark <https://github.com/spencerkclark>`_.
117-
- :py:meth:`~xarray.cftime_range` now supports QuarterBegin and QuarterEnd offsets (:issue:`2663`).
118-
By `Jwen Fai Low <https://github.com/jwenfai>`_
119-
- :py:meth:`~xarray.open_dataset` now accepts a ``use_cftime`` argument, which
120-
can be used to require that ``cftime.datetime`` objects are always used, or
121-
never used when decoding dates encoded with a standard calendar. This can be
122-
used to ensure consistent date types are returned when using
123-
:py:meth:`~xarray.open_mfdataset` (:issue:`1263`) and/or to silence
124-
serialization warnings raised if dates from a standard calendar are found to
125-
be outside the :py:class:`pandas.Timestamp`-valid range (:issue:`2754`). By
126-
`Spencer Clark <https://github.com/spencerkclark>`_.
127-
114+
- Added :py:meth:`~xarray.Dataset.drop_dims` (:issue:`1949`).
115+
By `Kevin Squire <https://github.com/kmsquire>`_.
128116
- Allow ``expand_dims`` method to support inserting/broadcasting dimensions
129117
with size > 1. (:issue:`2710`)
130118
By `Martin Pletcher <https://github.com/pletchm>`_.
131119
`Spencer Clark <https://github.com/spencerkclark>`_.
132120

133-
- Added :py:meth:`~xarray.Dataset.drop_dims` (:issue:`1949`).
134-
By `Kevin Squire <https://github.com/kmsquire>`_.
135-
136121
Bug fixes
137122
~~~~~~~~~
138123

xarray/core/dataarray.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import functools
2+
import sys
23
import warnings
34
from collections import OrderedDict
45

@@ -1138,7 +1139,7 @@ def swap_dims(self, dims_dict):
11381139
ds = self._to_temp_dataset().swap_dims(dims_dict)
11391140
return self._from_temp_dataset(ds)
11401141

1141-
def expand_dims(self, dim, axis=None):
1142+
def expand_dims(self, dim=None, axis=None, **dim_kwargs):
11421143
"""Return a new object with an additional axis (or axes) inserted at
11431144
the corresponding position in the array shape.
11441145
@@ -1147,21 +1148,53 @@ def expand_dims(self, dim, axis=None):
11471148
11481149
Parameters
11491150
----------
1150-
dim : str or sequence of str.
1151+
dim : str, sequence of str, dict, or None
11511152
Dimensions to include on the new variable.
1152-
dimensions are inserted with length 1.
1153+
If provided as str or sequence of str, then dimensions are inserted
1154+
with length 1. If provided as a dict, then the keys are the new
1155+
dimensions and the values are either integers (giving the length of
1156+
the new dimensions) or sequence/ndarray (giving the coordinates of
1157+
the new dimensions). **WARNING** for python 3.5, if ``dim`` is
1158+
dict-like, then it must be an ``OrderedDict``. This is to ensure
1159+
that the order in which the dims are given is maintained.
11531160
axis : integer, list (or tuple) of integers, or None
11541161
Axis position(s) where new axis is to be inserted (position(s) on
11551162
the result array). If a list (or tuple) of integers is passed,
11561163
multiple axes are inserted. In this case, dim arguments should be
11571164
same length list. If axis=None is passed, all the axes will be
11581165
inserted to the start of the result array.
1166+
**dim_kwargs : int or sequence/ndarray
1167+
The keywords are arbitrary dimensions being inserted and the values
1168+
are either the lengths of the new dims (if int is given), or their
1169+
coordinates. Note, this is an alternative to passing a dict to the
1170+
dim kwarg and will only be used if dim is None. **WARNING** for
1171+
python 3.5 ``dim_kwargs`` is not available.
11591172
11601173
Returns
11611174
-------
11621175
expanded : same type as caller
11631176
This object, but with an additional dimension(s).
11641177
"""
1178+
if isinstance(dim, int):
1179+
raise TypeError('dim should be str or sequence of strs or dict')
1180+
elif isinstance(dim, str):
1181+
dim = OrderedDict(((dim, 1),))
1182+
elif isinstance(dim, (list, tuple)):
1183+
if len(dim) != len(set(dim)):
1184+
raise ValueError('dims should not contain duplicate values.')
1185+
dim = OrderedDict(((d, 1) for d in dim))
1186+
1187+
# TODO: get rid of the below code block when python 3.5 is no longer
1188+
# supported.
1189+
python36_plus = sys.version_info[0] == 3 and sys.version_info[1] > 5
1190+
not_ordereddict = dim is not None and not isinstance(dim, OrderedDict)
1191+
if not python36_plus and not_ordereddict:
1192+
raise TypeError("dim must be an OrderedDict for python <3.6")
1193+
elif not python36_plus and dim_kwargs:
1194+
raise ValueError("dim_kwargs isn't available for python <3.6")
1195+
dim_kwargs = OrderedDict(dim_kwargs)
1196+
1197+
dim = either_dict_or_kwargs(dim, dim_kwargs, 'expand_dims')
11651198
ds = self._to_temp_dataset().expand_dims(dim, axis)
11661199
return self._from_temp_dataset(ds)
11671200

xarray/core/dataset.py

Lines changed: 58 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2329,7 +2329,7 @@ def swap_dims(self, dims_dict, inplace=None):
23292329
return self._replace_with_new_dims(variables, coord_names,
23302330
indexes=indexes, inplace=inplace)
23312331

2332-
def expand_dims(self, dim, axis=None):
2332+
def expand_dims(self, dim=None, axis=None, **dim_kwargs):
23332333
"""Return a new object with an additional axis (or axes) inserted at
23342334
the corresponding position in the array shape.
23352335
@@ -2338,26 +2338,53 @@ def expand_dims(self, dim, axis=None):
23382338
23392339
Parameters
23402340
----------
2341-
dim : str or sequence of str.
2341+
dim : str, sequence of str, dict, or None
23422342
Dimensions to include on the new variable.
2343-
dimensions are inserted with length 1.
2343+
If provided as str or sequence of str, then dimensions are inserted
2344+
with length 1. If provided as a dict, then the keys are the new
2345+
dimensions and the values are either integers (giving the length of
2346+
the new dimensions) or sequence/ndarray (giving the coordinates of
2347+
the new dimensions). **WARNING** for python 3.5, if ``dim`` is
2348+
dict-like, then it must be an ``OrderedDict``. This is to ensure
2349+
that the order in which the dims are given is maintained.
23442350
axis : integer, list (or tuple) of integers, or None
23452351
Axis position(s) where new axis is to be inserted (position(s) on
23462352
the result array). If a list (or tuple) of integers is passed,
23472353
multiple axes are inserted. In this case, dim arguments should be
2348-
the same length list. If axis=None is passed, all the axes will
2349-
be inserted to the start of the result array.
2354+
same length list. If axis=None is passed, all the axes will be
2355+
inserted to the start of the result array.
2356+
**dim_kwargs : int or sequence/ndarray
2357+
The keywords are arbitrary dimensions being inserted and the values
2358+
are either the lengths of the new dims (if int is given), or their
2359+
coordinates. Note, this is an alternative to passing a dict to the
2360+
dim kwarg and will only be used if dim is None. **WARNING** for
2361+
python 3.5 ``dim_kwargs`` is not available.
23502362
23512363
Returns
23522364
-------
23532365
expanded : same type as caller
23542366
This object, but with an additional dimension(s).
23552367
"""
23562368
if isinstance(dim, int):
2357-
raise ValueError('dim should be str or sequence of strs or dict')
2369+
raise TypeError('dim should be str or sequence of strs or dict')
2370+
elif isinstance(dim, str):
2371+
dim = OrderedDict(((dim, 1),))
2372+
elif isinstance(dim, (list, tuple)):
2373+
if len(dim) != len(set(dim)):
2374+
raise ValueError('dims should not contain duplicate values.')
2375+
dim = OrderedDict(((d, 1) for d in dim))
2376+
2377+
# TODO: get rid of the below code block when python 3.5 is no longer
2378+
# supported.
2379+
python36_plus = sys.version_info[0] == 3 and sys.version_info[1] > 5
2380+
not_ordereddict = dim is not None and not isinstance(dim, OrderedDict)
2381+
if not python36_plus and not_ordereddict:
2382+
raise TypeError("dim must be an OrderedDict for python <3.6")
2383+
elif not python36_plus and dim_kwargs:
2384+
raise ValueError("dim_kwargs isn't available for python <3.6")
2385+
2386+
dim = either_dict_or_kwargs(dim, dim_kwargs, 'expand_dims')
23582387

2359-
if isinstance(dim, str):
2360-
dim = [dim]
23612388
if axis is not None and not isinstance(axis, (list, tuple)):
23622389
axis = [axis]
23632390

@@ -2376,10 +2403,24 @@ def expand_dims(self, dim, axis=None):
23762403
'{dim} already exists as coordinate or'
23772404
' variable name.'.format(dim=d))
23782405

2379-
if len(dim) != len(set(dim)):
2380-
raise ValueError('dims should not contain duplicate values.')
2381-
23822406
variables = OrderedDict()
2407+
# If dim is a dict, then ensure that the values are either integers
2408+
# or iterables.
2409+
for k, v in dim.items():
2410+
if hasattr(v, "__iter__"):
2411+
# If the value for the new dimension is an iterable, then
2412+
# save the coordinates to the variables dict, and set the
2413+
# value within the dim dict to the length of the iterable
2414+
# for later use.
2415+
variables[k] = xr.IndexVariable((k,), v)
2416+
self._coord_names.add(k)
2417+
dim[k] = len(list(v))
2418+
elif isinstance(v, int):
2419+
pass # Do nothing if the dimensions value is just an int
2420+
else:
2421+
raise TypeError('The value of new dimension {k} must be '
2422+
'an iterable or an int'.format(k=k))
2423+
23832424
for k, v in self._variables.items():
23842425
if k not in dim:
23852426
if k in self._coord_names: # Do not change coordinates
@@ -2400,11 +2441,13 @@ def expand_dims(self, dim, axis=None):
24002441
' values.')
24012442
# We need to sort them to make sure `axis` equals to the
24022443
# axis positions of the result array.
2403-
zip_axis_dim = sorted(zip(axis_pos, dim))
2444+
zip_axis_dim = sorted(zip(axis_pos, dim.items()))
2445+
2446+
all_dims = list(zip(v.dims, v.shape))
2447+
for d, c in zip_axis_dim:
2448+
all_dims.insert(d, c)
2449+
all_dims = OrderedDict(all_dims)
24042450

2405-
all_dims = list(v.dims)
2406-
for a, d in zip_axis_dim:
2407-
all_dims.insert(a, d)
24082451
variables[k] = v.set_dims(all_dims)
24092452
else:
24102453
# If dims includes a label of a non-dimension coordinate,

xarray/tests/test_dataarray.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from collections import OrderedDict
44
from copy import deepcopy
55
from textwrap import dedent
6+
import sys
67

78
import numpy as np
89
import pandas as pd
@@ -1303,7 +1304,7 @@ def test_expand_dims_error(self):
13031304
coords={'x': np.linspace(0.0, 1.0, 3)},
13041305
attrs={'key': 'entry'})
13051306

1306-
with raises_regex(ValueError, 'dim should be str or'):
1307+
with raises_regex(TypeError, 'dim should be str or'):
13071308
array.expand_dims(0)
13081309
with raises_regex(ValueError, 'lengths of dim and axis'):
13091310
# dims and axis argument should be the same length
@@ -1328,6 +1329,16 @@ def test_expand_dims_error(self):
13281329
array.expand_dims(dim=['y', 'z'], axis=[2, -4])
13291330
array.expand_dims(dim=['y', 'z'], axis=[2, 3])
13301331

1332+
array = DataArray(np.random.randn(3, 4), dims=['x', 'dim_0'],
1333+
coords={'x': np.linspace(0.0, 1.0, 3)},
1334+
attrs={'key': 'entry'})
1335+
with pytest.raises(TypeError):
1336+
array.expand_dims(OrderedDict((("new_dim", 3.2),)))
1337+
1338+
# Attempt to use both dim and kwargs
1339+
with pytest.raises(ValueError):
1340+
array.expand_dims(OrderedDict((("d", 4),)), e=4)
1341+
13311342
def test_expand_dims(self):
13321343
array = DataArray(np.random.randn(3, 4), dims=['x', 'dim_0'],
13331344
coords={'x': np.linspace(0.0, 1.0, 3)},
@@ -1392,6 +1403,46 @@ def test_expand_dims_with_scalar_coordinate(self):
13921403
roundtripped = actual.squeeze(['z'], drop=False)
13931404
assert_identical(array, roundtripped)
13941405

1406+
def test_expand_dims_with_greater_dim_size(self):
1407+
array = DataArray(np.random.randn(3, 4), dims=['x', 'dim_0'],
1408+
coords={'x': np.linspace(0.0, 1.0, 3), 'z': 1.0},
1409+
attrs={'key': 'entry'})
1410+
# For python 3.5 and earlier this has to be an ordered dict, to
1411+
# maintain insertion order.
1412+
actual = array.expand_dims(
1413+
OrderedDict((('y', 2), ('z', 1), ('dim_1', ['a', 'b', 'c']))))
1414+
1415+
expected_coords = OrderedDict((
1416+
('y', [0, 1]), ('z', [1.0]), ('dim_1', ['a', 'b', 'c']),
1417+
('x', np.linspace(0, 1, 3)), ('dim_0', range(4))))
1418+
expected = DataArray(array.values * np.ones([2, 1, 3, 3, 4]),
1419+
coords=expected_coords,
1420+
dims=list(expected_coords.keys()),
1421+
attrs={'key': 'entry'}
1422+
).drop(['y', 'dim_0'])
1423+
assert_identical(expected, actual)
1424+
1425+
# Test with kwargs instead of passing dict to dim arg.
1426+
1427+
# TODO: only the code under the if-statement is needed when python 3.5
1428+
# is no longer supported.
1429+
python36_plus = sys.version_info[0] == 3 and sys.version_info[1] > 5
1430+
if python36_plus:
1431+
other_way = array.expand_dims(dim_1=['a', 'b', 'c'])
1432+
1433+
other_way_expected = DataArray(
1434+
array.values * np.ones([3, 3, 4]),
1435+
coords={'dim_1': ['a', 'b', 'c'],
1436+
'x': np.linspace(0, 1, 3),
1437+
'dim_0': range(4), 'z': 1.0},
1438+
dims=['dim_1', 'x', 'dim_0'],
1439+
attrs={'key': 'entry'}).drop('dim_0')
1440+
assert_identical(other_way_expected, other_way)
1441+
else:
1442+
# In python 3.5, using dim_kwargs should raise a ValueError.
1443+
with raises_regex(ValueError, "dim_kwargs isn't"):
1444+
array.expand_dims(e=["l", "m", "n"])
1445+
13951446
def test_set_index(self):
13961447
indexes = [self.mindex.get_level_values(n) for n in self.mindex.names]
13971448
coords = {idx.name: ('x', idx) for idx in indexes}

xarray/tests/test_dataset.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2033,6 +2033,27 @@ def test_expand_dims_error(self):
20332033
with raises_regex(ValueError, 'already exists'):
20342034
original.expand_dims(dim=['z'])
20352035

2036+
original = Dataset({'x': ('a', np.random.randn(3)),
2037+
'y': (['b', 'a'], np.random.randn(4, 3)),
2038+
'z': ('a', np.random.randn(3))},
2039+
coords={'a': np.linspace(0, 1, 3),
2040+
'b': np.linspace(0, 1, 4),
2041+
'c': np.linspace(0, 1, 5)},
2042+
attrs={'key': 'entry'})
2043+
with raises_regex(TypeError, 'value of new dimension'):
2044+
original.expand_dims(OrderedDict((("d", 3.2),)))
2045+
2046+
# TODO: only the code under the if-statement is needed when python 3.5
2047+
# is no longer supported.
2048+
python36_plus = sys.version_info[0] == 3 and sys.version_info[1] > 5
2049+
if python36_plus:
2050+
with raises_regex(ValueError, 'both keyword and positional'):
2051+
original.expand_dims(OrderedDict((("d", 4),)), e=4)
2052+
else:
2053+
# In python 3.5, using dim_kwargs should raise a ValueError.
2054+
with raises_regex(ValueError, "dim_kwargs isn't"):
2055+
original.expand_dims(OrderedDict((("d", 4),)), e=4)
2056+
20362057
def test_expand_dims(self):
20372058
original = Dataset({'x': ('a', np.random.randn(3)),
20382059
'y': (['b', 'a'], np.random.randn(4, 3))},
@@ -2066,6 +2087,53 @@ def test_expand_dims(self):
20662087
roundtripped = actual.squeeze('z')
20672088
assert_identical(original, roundtripped)
20682089

2090+
# Test expanding one dimension to have size > 1 that doesn't have
2091+
# coordinates, and also expanding another dimension to have size > 1
2092+
# that DOES have coordinates.
2093+
actual = original.expand_dims(
2094+
OrderedDict((("d", 4), ("e", ["l", "m", "n"]))))
2095+
2096+
expected = Dataset(
2097+
{'x': xr.DataArray(original['x'].values * np.ones([4, 3, 3]),
2098+
coords=dict(d=range(4),
2099+
e=['l', 'm', 'n'],
2100+
a=np.linspace(0, 1, 3)),
2101+
dims=['d', 'e', 'a']).drop('d'),
2102+
'y': xr.DataArray(original['y'].values * np.ones([4, 3, 4, 3]),
2103+
coords=dict(d=range(4),
2104+
e=['l', 'm', 'n'],
2105+
b=np.linspace(0, 1, 4),
2106+
a=np.linspace(0, 1, 3)),
2107+
dims=['d', 'e', 'b', 'a']).drop('d')},
2108+
coords={'c': np.linspace(0, 1, 5)},
2109+
attrs={'key': 'entry'})
2110+
assert_identical(actual, expected)
2111+
2112+
# Test with kwargs instead of passing dict to dim arg.
2113+
2114+
# TODO: only the code under the if-statement is needed when python 3.5
2115+
# is no longer supported.
2116+
python36_plus = sys.version_info[0] == 3 and sys.version_info[1] > 5
2117+
if python36_plus:
2118+
other_way = original.expand_dims(e=["l", "m", "n"])
2119+
other_way_expected = Dataset(
2120+
{'x': xr.DataArray(original['x'].values * np.ones([3, 3]),
2121+
coords=dict(e=['l', 'm', 'n'],
2122+
a=np.linspace(0, 1, 3)),
2123+
dims=['e', 'a']),
2124+
'y': xr.DataArray(original['y'].values * np.ones([3, 4, 3]),
2125+
coords=dict(e=['l', 'm', 'n'],
2126+
b=np.linspace(0, 1, 4),
2127+
a=np.linspace(0, 1, 3)),
2128+
dims=['e', 'b', 'a'])},
2129+
coords={'c': np.linspace(0, 1, 5)},
2130+
attrs={'key': 'entry'})
2131+
assert_identical(other_way_expected, other_way)
2132+
else:
2133+
# In python 3.5, using dim_kwargs should raise a ValueError.
2134+
with raises_regex(ValueError, "dim_kwargs isn't"):
2135+
original.expand_dims(e=["l", "m", "n"])
2136+
20692137
def test_set_index(self):
20702138
expected = create_test_multiindex()
20712139
mindex = expected['x'].to_index()

0 commit comments

Comments
 (0)