Skip to content

Commit d33586e

Browse files
author
ahuang11
committed
Fix based on sugggestion
1 parent e307041 commit d33586e

File tree

3 files changed

+86
-45
lines changed

3 files changed

+86
-45
lines changed

xarray/core/dataset.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7170,14 +7170,6 @@ def drop_duplicate_coords(
71707170
)
71717171
elif coord not in new.dims:
71727172
dims = new[coord].dims
7173-
for dim in dims:
7174-
if any(new.get_index(dim).duplicated()):
7175-
raise ValueError(
7176-
f"Cannot have duplicate dimension values in "
7177-
f"'{dim}' when dropping duplicate coordinate "
7178-
f"values for '{coord}' due to ambiguity "
7179-
f"in unstack."
7180-
)
71817173

71827174
# stack the coord's dimensions
71837175
tmp_dim = "__tmp_dim__"
@@ -7190,22 +7182,29 @@ def drop_duplicate_coords(
71907182
# can actually find duplicates
71917183
new[tmp_dim] = new[coord].values.ravel()
71927184

7185+
base_coord = coord
7186+
71937187
# replace coord with tmp_dim for use with get_index()
71947188
# as to not repeat the same call in the dimension clause
71957189
coord = tmp_dim
7196-
7197-
unstack_tmp_dim = True
71987190
else:
7199-
unstack_tmp_dim = False
7191+
base_coord = None
72007192

72017193
index = new.get_index(coord).duplicated(keep=keep)
72027194
new = new.isel({coord: ~index})
72037195

7204-
if unstack_tmp_dim:
7205-
# return everything to normal
7206-
new[coord] = stacked_coord_indices.isel({coord: ~index})
7207-
new = new.unstack(coord)
7196+
if base_coord is not None:
7197+
# remove tmp_dim
7198+
new = new.swap_dims({tmp_dim: base_coord}).drop(tmp_dim)
7199+
7200+
# get associated coordinates with the stacked dim
7201+
tmp_index = stacked_coord_indices.isel({coord: ~index}).indexes[tmp_dim]
72087202

7203+
# unpack the coordinates and add back to dataset
7204+
keys = tmp_index.names
7205+
values = list(zip(*tmp_index.values))
7206+
for key, value in dict(zip(keys, values)).items():
7207+
new.coords[key] = base_coord, list(value)
72097208
return new
72107209

72117210
def curvefit(

xarray/tests/test_dataarray.py

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7371,26 +7371,26 @@ def test_drop_duplicate_coords(keep):
73717371
da.coords["valid"] = (("init", "tau"), np.array([[8, 6, 6], [7, 7, 7]]))
73727372

73737373
if keep == "first":
7374-
data = [[1, 2], [4, np.nan]]
7375-
init = [0, 1]
7376-
tau = [1, 2]
7377-
valid = [[8.0, 6.0], [7.0, np.nan]]
7374+
data = [1, 2, 4]
7375+
init = [0, 0, 1]
7376+
tau = [1, 2, 1]
7377+
valid = [8, 6, 7]
73787378
elif keep == "last":
7379-
data = [[1, 3], [np.nan, 6]]
7380-
init = [0, 1]
7381-
tau = [1, 3]
7382-
valid = [[8.0, 6.0], [np.nan, 7]]
7379+
data = [1, 3, 6]
7380+
init = [0, 0, 1]
7381+
tau = [1, 3, 3]
7382+
valid = [8, 6, 7]
73837383
else:
7384-
data = [[1]]
7384+
data = [1]
73857385
init = [0]
73867386
tau = [1]
7387-
valid = [[8]]
7387+
valid = [8]
73887388

73897389
result = da.drop_duplicate_coords("valid", keep=keep)
73907390
expected = xr.DataArray(
73917391
data,
7392-
dims=["init", "tau"],
7393-
coords={"init": init, "tau": tau, "valid": (("init", "tau"), valid)},
7392+
dims="valid",
7393+
coords={"init": ("valid", init), "tau": ("valid", tau), "valid": valid},
73947394
)
73957395
assert_equal(expected, result)
73967396

@@ -7404,9 +7404,29 @@ def test_drop_duplicate_coords_duplicate_dims(keep):
74047404
)
74057405
da.coords["valid"] = (("init", "tau"), np.array([[8, 6, 6], [7, 7, 7]]))
74067406

7407-
with pytest.raises(ValueError):
7408-
da.drop_duplicate_coords("valid", keep=keep)
7407+
if keep == "first":
7408+
data = [1, 2, 4]
7409+
init = [0, 0, 0]
7410+
tau = [1, 2, 1]
7411+
valid = [8, 6, 7]
7412+
elif keep == "last":
7413+
data = [1, 3, 6]
7414+
init = [0, 0, 0]
7415+
tau = [1, 3, 3]
7416+
valid = [8, 6, 7]
7417+
else:
7418+
data = [1]
7419+
init = [0]
7420+
tau = [1]
7421+
valid = [8]
74097422

7423+
result = da.drop_duplicate_coords("valid", keep=keep)
7424+
expected = xr.DataArray(
7425+
data,
7426+
dims="valid",
7427+
coords={"init": ("valid", init), "tau": ("valid", tau), "valid": valid},
7428+
)
7429+
assert_equal(expected, result)
74107430

74117431
@pytest.mark.parametrize("keep", ["first", "last", False])
74127432
def test_drop_duplicate_coords_missing(keep):

xarray/tests/test_dataset.py

Lines changed: 38 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6910,26 +6910,26 @@ def test_drop_duplicate_coords(keep):
69106910
ds.coords["valid"] = (("init", "tau"), np.array([[8, 6, 6], [7, 7, 7]]))
69116911

69126912
if keep == "first":
6913-
data = [[1, 2], [4, np.nan]]
6914-
init = [0, 1]
6915-
tau = [1, 2]
6916-
valid = [[8.0, 6.0], [7.0, np.nan]]
6913+
data = [1, 2, 4]
6914+
init = [0, 0, 1]
6915+
tau = [1, 2, 1]
6916+
valid = [8, 6, 7]
69176917
elif keep == "last":
6918-
data = [[1, 3], [np.nan, 6]]
6919-
init = [0, 1]
6920-
tau = [1, 3]
6921-
valid = [[8.0, 6.0], [np.nan, 7]]
6918+
data = [1, 3, 6]
6919+
init = [0, 0, 1]
6920+
tau = [1, 3, 3]
6921+
valid = [8, 6, 7]
69226922
else:
6923-
data = [[1]]
6923+
data = [1]
69246924
init = [0]
69256925
tau = [1]
6926-
valid = [[8]]
6926+
valid = [8]
69276927

69286928
result = ds.drop_duplicate_coords("valid", keep=keep)
69296929
expected = xr.DataArray(
69306930
data,
6931-
dims=["init", "tau"],
6932-
coords={"init": init, "tau": tau, "valid": (("init", "tau"), valid)},
6931+
dims="valid",
6932+
coords={"init": ("valid", init), "tau": ("valid", tau), "valid": valid},
69336933
name="test",
69346934
).to_dataset()
69356935
assert_equal(expected, result)
@@ -6938,14 +6938,36 @@ def test_drop_duplicate_coords(keep):
69386938
@pytest.mark.parametrize("keep", ["first", "last", False])
69396939
def test_drop_duplicate_coords_duplicate_dims(keep):
69406940
ds = xr.DataArray(
6941-
[["a", "b", "c"], ["d", "e", "f"]],
6942-
coords={"init": [0, 0], "tau": [1, 2, 3]}, # duplicate inits
6941+
[[1, 2, 3], [4, 5, 6]],
6942+
coords={"init": [0, 0], "tau": [1, 2, 3]},
69436943
dims=["init", "tau"],
69446944
).to_dataset(name="test")
69456945
ds.coords["valid"] = (("init", "tau"), np.array([[8, 6, 6], [7, 7, 7]]))
69466946

6947-
with pytest.raises(ValueError):
6948-
ds.drop_duplicate_coords("valid", keep=keep)
6947+
if keep == "first":
6948+
data = [1, 2, 4]
6949+
init = [0, 0, 0]
6950+
tau = [1, 2, 1]
6951+
valid = [8, 6, 7]
6952+
elif keep == "last":
6953+
data = [1, 3, 6]
6954+
init = [0, 0, 0]
6955+
tau = [1, 3, 3]
6956+
valid = [8, 6, 7]
6957+
else:
6958+
data = [1]
6959+
init = [0]
6960+
tau = [1]
6961+
valid = [8]
6962+
6963+
result = ds.drop_duplicate_coords("valid", keep=keep)
6964+
expected = xr.DataArray(
6965+
data,
6966+
dims="valid",
6967+
coords={"init": ("valid", init), "tau": ("valid", tau), "valid": valid},
6968+
name="test",
6969+
).to_dataset()
6970+
assert_equal(expected, result)
69496971

69506972

69516973
@pytest.mark.parametrize("keep", ["first", "last", False])

0 commit comments

Comments
 (0)