1010import pandas as pd
1111
1212from pandas_openscm .exceptions import MissingOptionalDependencyError
13- from pandas_openscm .index_manipulation import set_index_levels_func
13+ from pandas_openscm .index_manipulation import (
14+ ensure_is_multiindex ,
15+ set_index_levels_func ,
16+ )
1417from pandas_openscm .indexing import multi_index_lookup , multi_index_match
1518
1619if TYPE_CHECKING :
@@ -37,7 +40,7 @@ def __init__(self, missing_ts: pd.MultiIndex) -> None:
3740
3841def convert_unit_from_target_series (
3942 df : pd .DataFrame ,
40- desired_unit : pd .Series [str ],
43+ desired_units : pd .Series [str ],
4144 unit_level : str = "unit" ,
4245 ur : pint .facets .PlainRegistry | None = None ,
4346) -> pd .DataFrame :
@@ -53,7 +56,7 @@ def convert_unit_from_target_series(
5356 df
5457 [pd.DataFrame][pandas.DataFrame] whose units should be converted
5558
56- desired_unit
59+ desired_units
5760 Desired unit(s) for `df`
5861
5962 This must be a [pd.Series][pandas.Series]
@@ -75,7 +78,7 @@ def convert_unit_from_target_series(
7578 Raises
7679 ------
7780 MissingDesiredUnitError
78- `desired_unit `'s index does not contain all the rows in `df`
81+ `desired_units `'s index does not contain all the rows in `df`
7982
8083 MissingOptionalDependencyError
8184 `ur` is `None` and [pint](https://pint.readthedocs.io/) is not available.
@@ -99,7 +102,7 @@ def convert_unit_from_target_series(
99102 >>>
100103 >>> convert_unit_from_target_series(
101104 ... start,
102- ... desired_unit =pd.Series(
105+ ... desired_units =pd.Series(
103106 ... ["K", "mK", "degF"],
104107 ... index=pd.MultiIndex.from_tuples(
105108 ... (
@@ -117,21 +120,35 @@ def convert_unit_from_target_series(
117120 sb temperature mK 1100.000 1200.000 1300.000
118121 body temperature degF 98.600 100.580 100.220
119122 """
120- missing_rows = df .index .difference (desired_unit .index )
123+ if not isinstance (desired_units .index , pd .MultiIndex ):
124+ desired_units .index = pd .MultiIndex .from_arrays (
125+ [desired_units .index .values ], names = [desired_units .index .name ]
126+ )
127+
128+ checker = df .index .droplevel (unit_level )
129+ if not isinstance (checker , pd .MultiIndex ):
130+ checker = pd .MultiIndex .from_arrays ([checker .values ], names = [checker .name ])
131+
132+ missing_rows = checker .difference (desired_units .index )
121133 if not missing_rows .empty :
122134 raise MissingDesiredUnitError (missing_rows )
123135
124136 df_reset_unit = df .reset_index (unit_level )
137+ if not isinstance (df_reset_unit .index , pd .MultiIndex ):
138+ df_reset_unit .index = pd .MultiIndex .from_arrays (
139+ [df_reset_unit .index .values ], names = [df_reset_unit .index .name ]
140+ )
141+
125142 df_units = df_reset_unit [unit_level ].rename ("df_unit" )
126143
127- desired_unit_in_df = multi_index_lookup (desired_unit , df_units .index ).rename (
144+ desired_units_in_df = multi_index_lookup (desired_units , df_units .index ).rename (
128145 "target_unit"
129146 )
130147
131148 # Don't need to align, pandas does that for us.
132149 # If you want to check, compare the below with
133150 # unit_map = pd.DataFrame([df_units, desired_units_in_df.sample(frac=1)]).T
134- unit_map = pd .DataFrame ([df_units , desired_unit_in_df ]).T
151+ unit_map = pd .DataFrame ([df_units , desired_units_in_df ]).T
135152
136153 unit_map_no_change = unit_map ["df_unit" ] == unit_map ["target_unit" ]
137154 if unit_map_no_change .all ():
@@ -172,7 +189,7 @@ def convert_unit_from_target_series(
172189
173190def convert_unit (
174191 df : pd .DataFrame ,
175- desired_unit : str | Mapping [str , str ] | pd .Series [str ],
192+ desired_units : str | Mapping [str , str ] | pd .Series [str ],
176193 unit_level : str = "unit" ,
177194 ur : pint .facets .PlainRegistry | None = None ,
178195) -> pd .DataFrame :
@@ -188,7 +205,7 @@ def convert_unit(
188205 df
189206 [pd.DataFrame][pandas.DataFrame] whose units should be converted
190207
191- desired_unit
208+ desired_units
192209 Desired unit(s) for `df`
193210
194211 If this is a string,
@@ -203,7 +220,7 @@ def convert_unit(
203220
204221 If this is a [pd.Series][pandas.Series],
205222 then it will be passed to [convert_unit_from_target_series][(m).]
206- after filling any rows in `df` that are not in `desired_unit `
223+ after filling any rows in `df` that are not in `desired_units `
207224 with the unit from `df` (i.e. unspecified rows are not converted).
208225
209226 For further details, see examples
@@ -290,34 +307,37 @@ def convert_unit(
290307 df_units_s = df .index .get_level_values (unit_level ).to_series (
291308 index = df .index .droplevel (unit_level ), name = "df_unit"
292309 )
310+ df_units_s .index = ensure_is_multiindex (df_units_s .index )
293311
294312 # I don't love creating desired_units_s in this function,
295313 # but it's basically a convenience function
296314 # and the creation is the only thing that this function does,
297315 # hence I am ok with it.
298- if isinstance (desired_unit , str ):
299- desired_unit_s = pd .Series (
300- [desired_unit ] * df .shape [0 ],
316+ if isinstance (desired_units , str ):
317+ desired_units_s = pd .Series (
318+ [desired_units ] * df .shape [0 ],
301319 index = df_units_s .index ,
302320 )
303321
304- elif isinstance (desired_unit , Mapping ):
305- desired_unit_s = df_units_s .replace (desired_unit )
322+ elif isinstance (desired_units , Mapping ):
323+ desired_units_s = df_units_s .replace (desired_units )
324+
325+ elif isinstance (desired_units , pd .Series ):
326+ desired_units .index = ensure_is_multiindex (desired_units .index )
306327
307- elif isinstance (desired_unit , pd .Series ):
308- missing = df_units_s .index .difference (desired_unit .index )
328+ missing = df_units_s .index .difference (desired_units .index )
309329 if missing .empty :
310- desired_unit_s = desired_unit
330+ desired_units_s = desired_units
311331 else :
312- desired_unit_s = pd .concat (
313- [desired_unit , multi_index_lookup (df_units_s , missing )]
332+ desired_units_s = pd .concat (
333+ [desired_units , multi_index_lookup (df_units_s , missing )]
314334 )
315335
316336 else :
317- raise NotImplementedError (type (desired_unit ))
337+ raise NotImplementedError (type (desired_units ))
318338
319339 res = convert_unit_from_target_series (
320- df = df , desired_unit = desired_unit_s , unit_level = unit_level , ur = ur
340+ df = df , desired_units = desired_units_s , unit_level = unit_level , ur = ur
321341 )
322342
323343 return res
@@ -445,48 +465,49 @@ def convert_unit_like(
445465 target_units_s = tmp .get_level_values (target_unit_level_use ).to_series (
446466 index = tmp .droplevel (target_unit_level_use )
447467 )
448- ambiguous = target_units_s .index .duplicated (keep = False )
449- if ambiguous .any ():
450- ambiguous_idx = target_units_s [ambiguous ].index
451- if not isinstance (ambiguous_idx , pd .MultiIndex ):
452- ambiguous_idx = pd .MultiIndex .from_arrays (
453- [ambiguous_idx .values ], names = [ambiguous_idx .name ]
454- )
455-
456- ambiguous_idx = ambiguous_idx .remove_unused_levels ()
457- ambiguous_drivers = target .index [
458- multi_index_match (target .index , ambiguous_idx )
459- ]
460-
461- msg = (
462- f"`df` has { df .index .names = } . "
463- f"`target` has { target .index .names = } . "
464- "The index levels in `target` that are also in `df` are "
465- f"{ target_units_s .index .names } . "
466- "When we only look at these levels, the desired unit looks like:\n "
467- f"{ target_units_s } \n "
468- "The unit to use isn't unambiguous for the following metadata:\n "
469- f"{ target_units_s [ambiguous ]} \n "
470- "The drivers of this ambiguity "
471- "are the following metadata levels in `target`\n "
472- f"{ ambiguous_drivers } "
473- )
474- raise AmbiguousTargetUnitError (msg )
475468
476469 else :
477470 target_units_s = target .index .get_level_values (target_unit_level_use ).to_series (
478471 index = target .index .droplevel (target_unit_level_use )
479472 )
480473
474+ ambiguous = target_units_s .index .duplicated (keep = False )
475+ if ambiguous .any ():
476+ ambiguous_idx = target_units_s [ambiguous ].index
477+ if not isinstance (ambiguous_idx , pd .MultiIndex ):
478+ ambiguous_idx = pd .MultiIndex .from_arrays (
479+ [ambiguous_idx .values ], names = [ambiguous_idx .name ]
480+ )
481+
482+ ambiguous_idx = ambiguous_idx .remove_unused_levels ()
483+ ambiguous_drivers = target .index [multi_index_match (target .index , ambiguous_idx )]
484+
485+ msg = (
486+ f"`df` has { df .index .names = } . "
487+ f"`target` has { target .index .names = } . "
488+ "The index levels in `target` that are also in `df` are "
489+ f"{ target_units_s .index .names } . "
490+ "When we only look at these levels, the desired unit looks like:\n "
491+ f"{ target_units_s } \n "
492+ "The unit to use isn't unambiguous for the following metadata:\n "
493+ f"{ target_units_s [ambiguous ]} \n "
494+ "The drivers of this ambiguity "
495+ "are the following metadata levels in `target`\n "
496+ f"{ ambiguous_drivers } "
497+ )
498+ raise AmbiguousTargetUnitError (msg )
499+
481500 target_units_s , _ = target_units_s .align (df_units_s )
501+ target_units_s .index = ensure_is_multiindex (target_units_s .index )
502+ df_units_s .index = ensure_is_multiindex (df_units_s .index )
482503 if target_units_s .isnull ().any ():
483504 # Fill rows that don't get a spec with their existing units
484505 target_units_s = multi_index_lookup (target_units_s , df_units_s .index ).fillna (
485506 df_units_s
486507 )
487508
488509 res = convert_unit_from_target_series (
489- df = df , desired_unit = target_units_s , unit_level = df_unit_level , ur = ur
510+ df = df , desired_units = target_units_s , unit_level = df_unit_level , ur = ur
490511 )
491512
492513 return res
0 commit comments