Skip to content

Commit 559d759

Browse files
committed
Add tests of annoying pandas MultiIndex casting
1 parent 0582d74 commit 559d759

File tree

4 files changed

+284
-127
lines changed

4 files changed

+284
-127
lines changed

docs/tutorials/unit-conversion.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -445,12 +445,10 @@
445445
desired_units
446446

447447
# %%
448-
# TODO: fix this
449448
# TODO: use accessor
450449
convert_unit(df_history, desired_units)
451450

452451
# %%
453-
# TODO: fix this
454452
# TODO: use accessor
455453
convert_unit_from_target_series(df_history, desired_units)
456454

src/pandas_openscm/index_manipulation.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -897,3 +897,24 @@ def set_index_levels_func(
897897
df.index = set_levels(df.index, levels_to_set=levels_to_set) # type: ignore
898898

899899
return df
900+
901+
902+
def ensure_is_multiindex(index: pd.Index | pd.MultiIndex) -> pd.MultiIndex:
    """
    Ensure that an index is a [pd.MultiIndex][pandas.MultiIndex]

    Parameters
    ----------
    index
        Index to check

    Returns
    -------
    :
        `index` itself if it is already a [pd.MultiIndex][pandas.MultiIndex],
        otherwise a new single-level [pd.MultiIndex][pandas.MultiIndex]
        whose only level carries the name and values of `index`
    """
    if isinstance(index, pd.MultiIndex):
        return index

    # Pass the index itself rather than `index.values`:
    # `.values` drops down to a raw numpy array,
    # which loses extension dtype information
    # (e.g. the timezone of a tz-aware DatetimeIndex).
    return pd.MultiIndex.from_arrays([index], names=[index.name])

src/pandas_openscm/unit_conversion.py

Lines changed: 72 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@
1010
import pandas as pd
1111

1212
from pandas_openscm.exceptions import MissingOptionalDependencyError
13-
from pandas_openscm.index_manipulation import set_index_levels_func
13+
from pandas_openscm.index_manipulation import (
14+
ensure_is_multiindex,
15+
set_index_levels_func,
16+
)
1417
from pandas_openscm.indexing import multi_index_lookup, multi_index_match
1518

1619
if TYPE_CHECKING:
@@ -37,7 +40,7 @@ def __init__(self, missing_ts: pd.MultiIndex) -> None:
3740

3841
def convert_unit_from_target_series(
3942
df: pd.DataFrame,
40-
desired_unit: pd.Series[str],
43+
desired_units: pd.Series[str],
4144
unit_level: str = "unit",
4245
ur: pint.facets.PlainRegistry | None = None,
4346
) -> pd.DataFrame:
@@ -53,7 +56,7 @@ def convert_unit_from_target_series(
5356
df
5457
[pd.DataFrame][pandas.DataFrame] whose units should be converted
5558
56-
desired_unit
59+
desired_units
5760
Desired unit(s) for `df`
5861
5962
This must be a [pd.Series][pandas.Series]
@@ -75,7 +78,7 @@ def convert_unit_from_target_series(
7578
Raises
7679
------
7780
AssertionError
78-
`desired_unit`'s index does not contain all the rows in `df`
81+
`desired_units`'s index does not contain all the rows in `df`
7982
8083
MissingOptionalDependencyError
8184
`ur` is `None` and [pint](https://pint.readthedocs.io/) is not available.
@@ -99,7 +102,7 @@ def convert_unit_from_target_series(
99102
>>>
100103
>>> convert_unit_from_target_series(
101104
... start,
102-
... desired_unit=pd.Series(
105+
... desired_units=pd.Series(
103106
... ["K", "mK", "degF"],
104107
... index=pd.MultiIndex.from_tuples(
105108
... (
@@ -117,21 +120,35 @@ def convert_unit_from_target_series(
117120
sb temperature mK 1100.000 1200.000 1300.000
118121
body temperature degF 98.600 100.580 100.220
119122
"""
120-
missing_rows = df.index.difference(desired_unit.index)
123+
if not isinstance(desired_units.index, pd.MultiIndex):
124+
desired_units.index = pd.MultiIndex.from_arrays(
125+
[desired_units.index.values], names=[desired_units.index.name]
126+
)
127+
128+
checker = df.index.droplevel(unit_level)
129+
if not isinstance(checker, pd.MultiIndex):
130+
checker = pd.MultiIndex.from_arrays([checker.values], names=[checker.name])
131+
132+
missing_rows = checker.difference(desired_units.index)
121133
if not missing_rows.empty:
122134
raise MissingDesiredUnitError(missing_rows)
123135

124136
df_reset_unit = df.reset_index(unit_level)
137+
if not isinstance(df_reset_unit.index, pd.MultiIndex):
138+
df_reset_unit.index = pd.MultiIndex.from_arrays(
139+
[df_reset_unit.index.values], names=[df_reset_unit.index.name]
140+
)
141+
125142
df_units = df_reset_unit[unit_level].rename("df_unit")
126143

127-
desired_unit_in_df = multi_index_lookup(desired_unit, df_units.index).rename(
144+
desired_units_in_df = multi_index_lookup(desired_units, df_units.index).rename(
128145
"target_unit"
129146
)
130147

131148
# Don't need to align, pandas does that for us.
132149
# If you want to check, compare the below with
133150
# unit_map = pd.DataFrame([df_units, desired_units_in_df.sample(frac=1)]).T
134-
unit_map = pd.DataFrame([df_units, desired_unit_in_df]).T
151+
unit_map = pd.DataFrame([df_units, desired_units_in_df]).T
135152

136153
unit_map_no_change = unit_map["df_unit"] == unit_map["target_unit"]
137154
if unit_map_no_change.all():
@@ -172,7 +189,7 @@ def convert_unit_from_target_series(
172189

173190
def convert_unit(
174191
df: pd.DataFrame,
175-
desired_unit: str | Mapping[str, str] | pd.Series[str],
192+
desired_units: str | Mapping[str, str] | pd.Series[str],
176193
unit_level: str = "unit",
177194
ur: pint.facets.PlainRegistry | None = None,
178195
) -> pd.DataFrame:
@@ -188,7 +205,7 @@ def convert_unit(
188205
df
189206
[pd.DataFrame][pandas.DataFrame] whose units should be converted
190207
191-
desired_unit
208+
desired_units
192209
Desired unit(s) for `df`
193210
194211
If this is a string,
@@ -203,7 +220,7 @@ def convert_unit(
203220
204221
If this is a [pd.Series][pandas.Series],
205222
then it will be passed to [convert_unit_from_target_series][(m).]
206-
after filling any rows in `df` that are not in `desired_unit`
223+
after filling any rows in `df` that are not in `desired_units`
207224
with the unit from `df` (i.e. unspecified rows are not converted).
208225
209226
For further details, see examples
@@ -290,34 +307,37 @@ def convert_unit(
290307
df_units_s = df.index.get_level_values(unit_level).to_series(
291308
index=df.index.droplevel(unit_level), name="df_unit"
292309
)
310+
df_units_s.index = ensure_is_multiindex(df_units_s.index)
293311

294312
# I don't love creating desired_units_s in this function,
295313
# but it's basically a convenience function
296314
# and the creation is the only thing that this function does,
297315
# hence I am ok with it.
298-
if isinstance(desired_unit, str):
299-
desired_unit_s = pd.Series(
300-
[desired_unit] * df.shape[0],
316+
if isinstance(desired_units, str):
317+
desired_units_s = pd.Series(
318+
[desired_units] * df.shape[0],
301319
index=df_units_s.index,
302320
)
303321

304-
elif isinstance(desired_unit, Mapping):
305-
desired_unit_s = df_units_s.replace(desired_unit)
322+
elif isinstance(desired_units, Mapping):
323+
desired_units_s = df_units_s.replace(desired_units)
324+
325+
elif isinstance(desired_units, pd.Series):
326+
desired_units.index = ensure_is_multiindex(desired_units.index)
306327

307-
elif isinstance(desired_unit, pd.Series):
308-
missing = df_units_s.index.difference(desired_unit.index)
328+
missing = df_units_s.index.difference(desired_units.index)
309329
if missing.empty:
310-
desired_unit_s = desired_unit
330+
desired_units_s = desired_units
311331
else:
312-
desired_unit_s = pd.concat(
313-
[desired_unit, multi_index_lookup(df_units_s, missing)]
332+
desired_units_s = pd.concat(
333+
[desired_units, multi_index_lookup(df_units_s, missing)]
314334
)
315335

316336
else:
317-
raise NotImplementedError(type(desired_unit))
337+
raise NotImplementedError(type(desired_units))
318338

319339
res = convert_unit_from_target_series(
320-
df=df, desired_unit=desired_unit_s, unit_level=unit_level, ur=ur
340+
df=df, desired_units=desired_units_s, unit_level=unit_level, ur=ur
321341
)
322342

323343
return res
@@ -445,48 +465,49 @@ def convert_unit_like(
445465
target_units_s = tmp.get_level_values(target_unit_level_use).to_series(
446466
index=tmp.droplevel(target_unit_level_use)
447467
)
448-
ambiguous = target_units_s.index.duplicated(keep=False)
449-
if ambiguous.any():
450-
ambiguous_idx = target_units_s[ambiguous].index
451-
if not isinstance(ambiguous_idx, pd.MultiIndex):
452-
ambiguous_idx = pd.MultiIndex.from_arrays(
453-
[ambiguous_idx.values], names=[ambiguous_idx.name]
454-
)
455-
456-
ambiguous_idx = ambiguous_idx.remove_unused_levels()
457-
ambiguous_drivers = target.index[
458-
multi_index_match(target.index, ambiguous_idx)
459-
]
460-
461-
msg = (
462-
f"`df` has {df.index.names=}. "
463-
f"`target` has {target.index.names=}. "
464-
"The index levels in `target` that are also in `df` are "
465-
f"{target_units_s.index.names}. "
466-
"When we only look at these levels, the desired unit looks like:\n"
467-
f"{target_units_s}\n"
468-
"The unit to use isn't unambiguous for the following metadata:\n"
469-
f"{target_units_s[ambiguous]}\n"
470-
"The drivers of this ambiguity "
471-
"are the following metadata levels in `target`\n"
472-
f"{ambiguous_drivers}"
473-
)
474-
raise AmbiguousTargetUnitError(msg)
475468

476469
else:
477470
target_units_s = target.index.get_level_values(target_unit_level_use).to_series(
478471
index=target.index.droplevel(target_unit_level_use)
479472
)
480473

474+
ambiguous = target_units_s.index.duplicated(keep=False)
475+
if ambiguous.any():
476+
ambiguous_idx = target_units_s[ambiguous].index
477+
if not isinstance(ambiguous_idx, pd.MultiIndex):
478+
ambiguous_idx = pd.MultiIndex.from_arrays(
479+
[ambiguous_idx.values], names=[ambiguous_idx.name]
480+
)
481+
482+
ambiguous_idx = ambiguous_idx.remove_unused_levels()
483+
ambiguous_drivers = target.index[multi_index_match(target.index, ambiguous_idx)]
484+
485+
msg = (
486+
f"`df` has {df.index.names=}. "
487+
f"`target` has {target.index.names=}. "
488+
"The index levels in `target` that are also in `df` are "
489+
f"{target_units_s.index.names}. "
490+
"When we only look at these levels, the desired unit looks like:\n"
491+
f"{target_units_s}\n"
492+
"The unit to use isn't unambiguous for the following metadata:\n"
493+
f"{target_units_s[ambiguous]}\n"
494+
"The drivers of this ambiguity "
495+
"are the following metadata levels in `target`\n"
496+
f"{ambiguous_drivers}"
497+
)
498+
raise AmbiguousTargetUnitError(msg)
499+
481500
target_units_s, _ = target_units_s.align(df_units_s)
501+
target_units_s.index = ensure_is_multiindex(target_units_s.index)
502+
df_units_s.index = ensure_is_multiindex(df_units_s.index)
482503
if target_units_s.isnull().any():
483504
# Fill rows that don't get a spec with their existing units
484505
target_units_s = multi_index_lookup(target_units_s, df_units_s.index).fillna(
485506
df_units_s
486507
)
487508

488509
res = convert_unit_from_target_series(
489-
df=df, desired_unit=target_units_s, unit_level=df_unit_level, ur=ur
510+
df=df, desired_units=target_units_s, unit_level=df_unit_level, ur=ur
490511
)
491512

492513
return res

0 commit comments

Comments
 (0)