Skip to content

Commit e2d1e76

Browse files
committed
Clean up
1 parent 559d759 commit e2d1e76

File tree

2 files changed

+103
-73
lines changed

2 files changed

+103
-73
lines changed

src/pandas_openscm/index_manipulation.py

Lines changed: 49 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,55 @@ def convert_index_to_category_index(pandas_obj: P) -> P:
5454
)
5555

5656

57+
def ensure_is_multiindex(index: pd.Index[Any] | pd.MultiIndex) -> pd.MultiIndex:
58+
"""
59+
Ensure that an index is a [pd.MultiIndex][pandas.MultiIndex]
60+
61+
Parameters
62+
----------
63+
index
64+
Index to check
65+
66+
Returns
67+
-------
68+
:
69+
Index, cast to [pd.MultiIndex][pandas.MultiIndex] if needed
70+
"""
71+
if isinstance(index, pd.MultiIndex):
72+
return index
73+
74+
return pd.MultiIndex.from_arrays([index.values], names=[index.name])
75+
76+
77+
def ensure_index_is_multiindex(pandas_obj: P, copy: bool = True) -> P:
78+
"""
79+
Ensure that the index of a pandas object is a [pd.MultiIndex][pandas.MultiIndex]
80+
81+
Parameters
82+
----------
83+
pandas_obj
84+
Object whose index we want to ensure is a [pd.MultiIndex][pandas.MultiIndex]
85+
86+
copy
87+
Should we copy `pandas_obj` before modifying the index?
88+
89+
Returns
90+
-------
91+
:
92+
`pandas_obj` with a [pd.MultiIndex][pandas.MultiIndex]
93+
"""
94+
# TODOO: accessor and tests
95+
if isinstance(pandas_obj.index, pd.MultiIndex):
96+
return pandas_obj
97+
98+
if copy:
99+
pandas_obj = pandas_obj.copy()
100+
101+
pandas_obj.index = ensure_is_multiindex(pandas_obj.index)
102+
103+
return pandas_obj
104+
105+
57106
def unify_index_levels(
58107
left: pd.MultiIndex, right: pd.MultiIndex
59108
) -> tuple[pd.MultiIndex, pd.MultiIndex]:
@@ -897,24 +946,3 @@ def set_index_levels_func(
897946
df.index = set_levels(df.index, levels_to_set=levels_to_set) # type: ignore
898947

899948
return df
900-
901-
902-
def ensure_is_multiindex(index: pd.Index | pd.MultiIndex) -> pd.MultiIndex:
903-
"""
904-
Ensure that an index is a [pd.MultiIndex][pandas.MultiIndex]
905-
906-
Parameters
907-
----------
908-
index
909-
Index to check
910-
911-
912-
Returns
913-
-------
914-
:
915-
Index, cast to [pd.MultiIndex][pandas.MultiIndex] if needed
916-
"""
917-
if isinstance(index, pd.MultiIndex):
918-
return index
919-
920-
return pd.MultiIndex.from_arrays([index.values], names=[index.name])

src/pandas_openscm/unit_conversion.py

Lines changed: 54 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from pandas_openscm.exceptions import MissingOptionalDependencyError
1313
from pandas_openscm.index_manipulation import (
14+
ensure_index_is_multiindex,
1415
ensure_is_multiindex,
1516
set_index_levels_func,
1617
)
@@ -120,53 +121,44 @@ def convert_unit_from_target_series(
120121
sb temperature mK 1100.000 1200.000 1300.000
121122
body temperature degF 98.600 100.580 100.220
122123
"""
123-
if not isinstance(desired_units.index, pd.MultiIndex):
124-
desired_units.index = pd.MultiIndex.from_arrays(
125-
[desired_units.index.values], names=[desired_units.index.name]
126-
)
127-
128-
checker = df.index.droplevel(unit_level)
129-
if not isinstance(checker, pd.MultiIndex):
130-
checker = pd.MultiIndex.from_arrays([checker.values], names=[checker.name])
124+
desired_units = ensure_index_is_multiindex(desired_units)
131125

132-
missing_rows = checker.difference(desired_units.index)
126+
df_rows_checker = ensure_is_multiindex(df.index.droplevel(unit_level))
127+
missing_rows = df_rows_checker.difference( # type: ignore # pandas-stubs missing API
128+
desired_units.index.reorder_levels(df_rows_checker.names) # type: ignore # pandas-stubs missing API
129+
)
133130
if not missing_rows.empty:
134131
raise MissingDesiredUnitError(missing_rows)
135132

136-
df_reset_unit = df.reset_index(unit_level)
137-
if not isinstance(df_reset_unit.index, pd.MultiIndex):
138-
df_reset_unit.index = pd.MultiIndex.from_arrays(
139-
[df_reset_unit.index.values], names=[df_reset_unit.index.name]
140-
)
133+
df_reset_unit = ensure_index_is_multiindex(df.reset_index(unit_level), copy=False)
141134

142-
df_units = df_reset_unit[unit_level].rename("df_unit")
135+
df_units = df_reset_unit[unit_level]
143136

144-
desired_units_in_df = multi_index_lookup(desired_units, df_units.index).rename(
145-
"target_unit"
146-
)
137+
desired_units_in_df = multi_index_lookup(desired_units, df_units.index) # type: ignore # already checked that df_units.index is MultiIndex
147138

148139
# Don't need to align, pandas does that for us.
149140
# If you want to check, compare the below with
150-
# unit_map = pd.DataFrame([df_units_s, target_units_s.sample(frac=1)]).T
151-
unit_map = pd.DataFrame([df_units, desired_units_in_df]).T
152-
153-
unit_map_no_change = unit_map["df_unit"] == unit_map["target_unit"]
154-
if unit_map_no_change.all():
141+
# unit_map = pd.DataFrame([df_units, desired_units_in_df.sample(frac=1)]).T
142+
unit_map = pd.DataFrame(
143+
[df_units.rename("df_unit"), desired_units_in_df.rename("target_unit")]
144+
).T
145+
unit_changes = unit_map["df_unit"] != unit_map["target_unit"]
146+
if not unit_changes.any():
155147
# Already all in desired unit
156148
return df
157149

158150
if ur is None:
159151
try:
160152
import pint
161153

162-
ur = pint.get_application_registry()
154+
ur = pint.get_application_registry() # type: ignore # pint typing limited
163155
except ImportError:
164156
raise MissingOptionalDependencyError( # noqa: TRY003
165157
"convert_unit_from_target_series(..., ur=None, ...)", "pint"
166158
)
167159

168160
df_no_unit = df_reset_unit.drop(unit_level, axis="columns")
169-
for (df_unit, target_unit), conversion_df in unit_map[~unit_map_no_change].groupby(
161+
for (df_unit, target_unit), conversion_df in unit_map[unit_changes].groupby(
170162
["df_unit", "target_unit"]
171163
):
172164
to_alter_loc = multi_index_match(df_no_unit.index, conversion_df.index) # type: ignore
@@ -179,7 +171,6 @@ def convert_unit_from_target_series(
179171
new_units = (unit_map.reorder_levels(df_no_unit.index.names).loc[df_no_unit.index])[
180172
"target_unit"
181173
]
182-
183174
res = set_index_levels_func(df_no_unit, {unit_level: new_units}).reorder_levels(
184175
df.index.names
185176
)
@@ -304,10 +295,11 @@ def convert_unit(
304295
sb temperature K 1.100 1.200 1.300
305296
body temperature degF 98.600 100.580 100.220
306297
"""
307-
df_units_s = df.index.get_level_values(unit_level).to_series(
308-
index=df.index.droplevel(unit_level), name="df_unit"
298+
df_units_s = ensure_index_is_multiindex(
299+
df.index.get_level_values(unit_level).to_series(
300+
index=df.index.droplevel(unit_level), name="df_unit"
301+
)
309302
)
310-
df_units_s.index = ensure_is_multiindex(df_units_s.index)
311303

312304
# I don't love creating target_units_s in this function,
313305
# but it's basically a convenience function
@@ -320,10 +312,10 @@ def convert_unit(
320312
)
321313

322314
elif isinstance(desired_units, Mapping):
323-
desired_units_s = df_units_s.replace(desired_units)
315+
desired_units_s = df_units_s.replace(desired_units) # type: ignore # pandas-stubs missing Mapping option
324316

325-
elif isinstance(desired_units, pd.Series):
326-
desired_units.index = ensure_is_multiindex(desired_units.index)
317+
elif isinstance(desired_units, pd.Series): # type: ignore # isinstance confused by pd.Series without generic type annotation
318+
desired_units = ensure_index_is_multiindex(desired_units) # type: ignore # as above
327319

328320
missing = df_units_s.index.difference(desired_units.index)
329321
if missing.empty:
@@ -365,7 +357,7 @@ def convert_unit_like(
365357
target: pd.DataFrame,
366358
df_unit_level: str = "unit",
367359
target_unit_level: str | None = None,
368-
ur: pint.UnitRegistry | None = None,
360+
ur: pint.facets.PlainRegistry | None = None,
369361
) -> pd.DataFrame:
370362
"""
371363
Convert units to match another [pd.DataFrame][pandas.DataFrame]
@@ -452,34 +444,44 @@ def convert_unit_like(
452444
else:
453445
target_unit_level_use = target_unit_level
454446

455-
df_units_s = df.index.get_level_values(df_unit_level).to_series(
456-
index=df.index.droplevel(df_unit_level)
447+
df_units_s = ensure_index_is_multiindex(
448+
df.index.get_level_values(df_unit_level).to_series(
449+
index=df.index.droplevel(df_unit_level)
450+
)
457451
)
458452

459-
extra_index_levels_target = target.index.names.difference(
453+
extra_index_levels_target = target.index.names.difference( # type: ignore # pandas-stubs API out of date
460454
[*df.index.names, target_unit_level_use]
461-
) # type: ignore # pandas-stubs confused
455+
)
462456
if extra_index_levels_target:
463-
# Drop out the extra levels and see if the intended unit is unambiguous
464-
tmp = target.index.droplevel(extra_index_levels_target).drop_duplicates()
465-
target_units_s = tmp.get_level_values(target_unit_level_use).to_series(
466-
index=tmp.droplevel(target_unit_level_use)
457+
# Drop out the extra levels and duplicates,
458+
# then create the target units Series
459+
# (ambiguity in the result is handled later)
460+
target_index_without_extra_levels_and_dups = target.index.droplevel(
461+
extra_index_levels_target
462+
).drop_duplicates()
463+
target_units_s = target_index_without_extra_levels_and_dups.get_level_values(
464+
target_unit_level_use
465+
).to_series(
466+
index=target_index_without_extra_levels_and_dups.droplevel(
467+
target_unit_level_use
468+
)
467469
)
468470

469471
else:
470472
target_units_s = target.index.get_level_values(target_unit_level_use).to_series(
471473
index=target.index.droplevel(target_unit_level_use)
472474
)
473475

476+
target_units_s = ensure_index_is_multiindex(target_units_s)
477+
474478
ambiguous = target_units_s.index.duplicated(keep=False)
475479
if ambiguous.any():
476-
ambiguous_idx = target_units_s[ambiguous].index
477-
if not isinstance(ambiguous_idx, pd.MultiIndex):
478-
ambiguous_idx = pd.MultiIndex.from_arrays(
479-
[ambiguous_idx.values], names=[ambiguous_idx.name]
480-
)
480+
ambiguous_idx = target_units_s[ambiguous].index.remove_unused_levels()
481+
if not isinstance(target.index, pd.MultiIndex): # pragma: no cover
482+
# Should be unreachable, but just in case
483+
raise TypeError(type(target.index))
481484

482-
ambiguous_idx = ambiguous_idx.remove_unused_levels()
483485
ambiguous_drivers = target.index[multi_index_match(target.index, ambiguous_idx)]
484486

485487
msg = (
@@ -498,13 +500,13 @@ def convert_unit_like(
498500
raise AmbiguousTargetUnitError(msg)
499501

500502
target_units_s, _ = target_units_s.align(df_units_s)
501-
target_units_s.index = ensure_is_multiindex(target_units_s.index)
502-
df_units_s.index = ensure_is_multiindex(df_units_s.index)
503+
target_units_s = target_units_s.reorder_levels(df_units_s.index.names)
503504
if target_units_s.isnull().any():
504505
# Fill rows that don't get a spec with their existing units
505-
target_units_s = multi_index_lookup(target_units_s, df_units_s.index).fillna(
506-
df_units_s
507-
)
506+
target_units_s = multi_index_lookup(
507+
target_units_s,
508+
df_units_s.index, # type: ignore # checked that index is MultiIndex above
509+
).fillna(df_units_s)
508510

509511
res = convert_unit_from_target_series(
510512
df=df, desired_units=target_units_s, unit_level=df_unit_level, ur=ur

0 commit comments

Comments
 (0)