Skip to content

Commit 723c49f

Browse files
committed
Add accessor for set levels
1 parent 786ad53 commit 723c49f

File tree

5 files changed

+181
-80
lines changed

5 files changed

+181
-80
lines changed

src/pandas_openscm/accessors/series.py

Lines changed: 31 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,14 @@
44

55
from __future__ import annotations
66

7-
from collections.abc import Mapping
7+
from collections.abc import Collection, Mapping
88
from typing import TYPE_CHECKING, Any, Generic, TypeVar
99

1010
import pandas as pd
1111

12+
from pandas_openscm.index_manipulation import (
13+
set_index_levels_func,
14+
)
1215
from pandas_openscm.unit_conversion import convert_unit, convert_unit_like
1316

1417
if TYPE_CHECKING:
@@ -305,33 +308,33 @@ def convert_unit_like(
305308
# """
306309
# return mi_loc(self._df, locator)
307310

308-
# def set_index_levels(
309-
# self,
310-
# levels_to_set: dict[str, Any | Collection[Any]],
311-
# copy: bool = True,
312-
# ) -> pd.DataFrame:
313-
# """
314-
# Set the index levels
315-
#
316-
# Parameters
317-
# ----------
318-
# levels_to_set
319-
# Mapping of level names to values to set
320-
#
321-
# copy
322-
# Should the [pd.DataFrame][pandas.DataFrame] be copied before returning?
323-
#
324-
# Returns
325-
# -------
326-
# :
327-
# [pd.DataFrame][pandas.DataFrame] with updates applied to its index
328-
# """
329-
# return set_index_levels_func(
330-
# self._df,
331-
# levels_to_set=levels_to_set,
332-
# copy=copy,
333-
# )
334-
#
311+
def set_index_levels(
    self,
    levels_to_set: dict[str, Any | Collection[Any]],
    copy: bool = True,
) -> S:
    """
    Set levels of the series' index

    This forwards to
    [set_index_levels_func][pandas_openscm.index_manipulation.set_index_levels_func],
    passing the wrapped series as the object to update.

    Parameters
    ----------
    levels_to_set
        Mapping of level names to values to set

    copy
        Should the [pd.Series][pandas.Series] be copied before returning?

    Returns
    -------
    :
        [pd.Series][pandas.Series] with updates applied to its index
    """
    updated = set_index_levels_func(
        self._series, levels_to_set=levels_to_set, copy=copy
    )
    return updated
337+
335338
# def to_category_index(self) -> pd.DataFrame:
336339
# """
337340
# Convert the index's values to categories

src/pandas_openscm/index_manipulation.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -910,7 +910,6 @@ def set_levels(
910910

911911

912912
def set_index_levels_func(
913-
# TODO: check support for series and add accessors
914913
pobj: P,
915914
levels_to_set: dict[str, Any | Collection[Any]],
916915
copy: bool = True,

src/pandas_openscm/testing.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
import itertools
1212
from collections.abc import Collection
13-
from typing import TYPE_CHECKING, Any, cast
13+
from typing import TYPE_CHECKING, Any, Literal, TypeVar, cast, overload
1414

1515
import numpy as np
1616
import pandas as pd
@@ -22,6 +22,8 @@
2222
if TYPE_CHECKING:
2323
import pytest
2424

25+
# Constrained type variable: everything annotated with `P` in one call
# must be the same pandas type (all DataFrame or all Series).
# `typing.TypeVar` does not accept exactly one constraint,
# so the two allowed types are listed separately rather than as a union.
# The Series constraint is a string so it is never subscripted at runtime.
P = TypeVar("P", pd.DataFrame, "pd.Series[Any]")
26+
2527

2628
def get_db_data_backends() -> tuple[type[object], ...]:
2729
return tuple(v[1] for v in DATA_BACKENDS.options)
@@ -97,6 +99,76 @@ def assert_frame_alike(
9799
)
98100

99101

102+
@overload
def convert_to_desired_type(
    df: pd.DataFrame, pobj_type: Literal["DataFrame"]
) -> pd.DataFrame: ...


@overload
def convert_to_desired_type(
    df: pd.DataFrame, pobj_type: Literal["Series"]
) -> pd.Series[Any]: ...


def convert_to_desired_type(
    df: pd.DataFrame, pobj_type: Literal["DataFrame", "Series"]
) -> pd.DataFrame | pd.Series[Any]:
    """
    Convert a `df` to the desired type for testing

    Parameters
    ----------
    df
        [pd.DataFrame][pandas.DataFrame] to convert

    pobj_type
        Type to convert to

        If "DataFrame", then `df` is simply returned.
        If "Series", then the first column of `df` is returned.

    Returns
    -------
    :
        `df` converted to the desired type

    Raises
    ------
    NotImplementedError
        `pobj_type` is not one of the supported values
    """
    if pobj_type == "DataFrame":
        return df

    if pobj_type == "Series":
        # Select by position rather than by label:
        # label-based selection returns a DataFrame
        # if the first column's label is duplicated.
        return df.iloc[:, 0]

    raise NotImplementedError(pobj_type)
144+
145+
146+
def check_result(res: P, exp: P) -> None:
    """
    Compare a result with its expected value, dispatching on the type of `res`

    Handles [pd.DataFrame][pandas.DataFrame]
    and [pd.Series][pandas.Series].

    This only forwards to the underlying comparison function.
    If you want specific functionality, call that function directly.

    Parameters
    ----------
    res
        Result

    exp
        Expected

    Raises
    ------
    NotImplementedError
        `res` is neither a DataFrame nor a Series
    """
    if isinstance(res, pd.DataFrame):
        assert_frame_alike(res, exp)
        return

    if isinstance(res, pd.Series):
        pd.testing.assert_series_equal(res, exp)
        return

    raise NotImplementedError(type(res))
170+
171+
100172
def create_test_df(
101173
*,
102174
variables: Collection[tuple[str, str]],

tests/integration/index_manipulation/test_integration_index_manipulation_set_levels.py

Lines changed: 62 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,15 @@
99
import pytest
1010

1111
from pandas_openscm.index_manipulation import set_index_levels_func, set_levels
12+
from pandas_openscm.testing import convert_to_desired_type
13+
14+
pobj_type = pytest.mark.parametrize(
15+
"pobj_type",
16+
("DataFrame", "Series"),
17+
)
18+
"""
19+
Parameterisation to use to check handling of both DataFrame and Series
20+
"""
1221

1322

1423
@pytest.mark.parametrize(
@@ -160,7 +169,8 @@ def test_set_levels(start, levels_to_set, exp):
160169
pd.testing.assert_index_equal(res, exp)
161170

162171

163-
def test_set_levels_with_a_dataframe():
172+
@pobj_type
173+
def test_set_levels_with_a_dataframe(pobj_type):
164174
start = pd.MultiIndex.from_tuples(
165175
[
166176
("sa", "va", "kg", 0),
@@ -170,11 +180,14 @@ def test_set_levels_with_a_dataframe():
170180
],
171181
names=["scenario", "variable", "unit", "run_id"],
172182
)
173-
start_df = pd.DataFrame(
174-
np.zeros((start.shape[0], 3)), columns=[2010, 2020, 2030], index=start
183+
start_pobj = convert_to_desired_type(
184+
pd.DataFrame(
185+
np.zeros((start.shape[0], 3)), columns=[2010, 2020, 2030], index=start
186+
),
187+
pobj_type,
175188
)
176189

177-
res = set_index_levels_func(start_df, levels_to_set={"new_variable": "test"})
190+
res = set_index_levels_func(start_pobj, levels_to_set={"new_variable": "test"})
178191

179192
exp = pd.MultiIndex.from_tuples(
180193
[
@@ -189,11 +202,13 @@ def test_set_levels_with_a_dataframe():
189202
pd.testing.assert_index_equal(res.index, exp)
190203

191204

192-
def test_set_levels_raises_type_error():
205+
@pobj_type
206+
def test_set_levels_raises_type_error(pobj_type):
193207
start = pd.DataFrame(
194208
np.arange(2 * 4).reshape((4, 2)),
195209
columns=[2010, 2020],
196210
)
211+
start = convert_to_desired_type(start, pobj_type)
197212

198213
levels_to_set = {"new_variable": "test"}
199214

@@ -221,7 +236,7 @@ def test_set_levels_raises_value_error():
221236
set_levels(start, levels_to_set=levels_to_set)
222237

223238

224-
def test_accessor(setup_pandas_accessors):
239+
def test_accessor_df(setup_pandas_accessors):
225240
start = pd.DataFrame(
226241
np.arange(2 * 4).reshape((4, 2)),
227242
columns=[2010, 2020],
@@ -262,3 +277,44 @@ def test_accessor(setup_pandas_accessors):
262277
# Test function too
263278
res = set_index_levels_func(start, levels_to_set=levels_to_set)
264279
pd.testing.assert_frame_equal(res, exp)
280+
281+
282+
def test_accessor_series(setup_pandas_accessors):
    """
    Check setting index levels on a series via the accessor and the function
    """
    start_index = pd.MultiIndex.from_tuples(
        [
            ("sa", "va", "kg", 0),
            ("sb", "vb", "m", -1),
            ("sa", "va", "kg", -2),
            ("sa", "vb", "kg", 2),
        ],
        names=["scenario", "variable", "unit", "run_id"],
    )
    start = pd.Series(np.arange(4), index=start_index)

    # Mix of a new level, a per-row replacement and a scalar replacement
    levels_to_set = {
        "model_id": "674",
        "unit": ["t", "km", "g", "kg"],
        "scenario": 1,
    }

    exp_index = pd.MultiIndex.from_tuples(
        [
            (1, "va", "t", 0, "674"),
            (1, "vb", "km", -1, "674"),
            (1, "va", "g", -2, "674"),
            (1, "vb", "kg", 2, "674"),
        ],
        names=["scenario", "variable", "unit", "run_id", "model_id"],
    )
    exp = pd.Series(start.values, index=exp_index)

    res_accessor = start.openscm.set_index_levels(levels_to_set=levels_to_set)
    pd.testing.assert_series_equal(res_accessor, exp)

    # Test function too
    res_func = set_index_levels_func(start, levels_to_set=levels_to_set)
    pd.testing.assert_series_equal(res_func, exp)

tests/integration/test_unit_conversion.py

Lines changed: 15 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import re
88
import sys
9-
from typing import TYPE_CHECKING, Any, Literal, TypeVar, overload
9+
from typing import TYPE_CHECKING, Any, TypeVar
1010
from unittest.mock import patch
1111

1212
import numpy as np
@@ -17,7 +17,12 @@
1717
from pandas_openscm.index_manipulation import (
1818
set_index_levels_func,
1919
)
20-
from pandas_openscm.testing import assert_frame_alike, create_test_df
20+
from pandas_openscm.testing import (
21+
assert_frame_alike,
22+
check_result,
23+
convert_to_desired_type,
24+
create_test_df,
25+
)
2126
from pandas_openscm.unit_conversion import (
2227
AmbiguousTargetUnitError,
2328
MissingDesiredUnitError,
@@ -29,6 +34,14 @@
2934
if TYPE_CHECKING:
3035
# Constrained type variable: everything annotated with `P` in one call
# must be the same pandas type (all DataFrame or all Series).
# `typing.TypeVar` does not accept exactly one constraint,
# so the two allowed types are listed separately rather than as a union.
P = TypeVar("P", pd.DataFrame, "pd.Series[Any]")
3136

37+
pobj_type = pytest.mark.parametrize(
38+
"pobj_type",
39+
("DataFrame", "Series"),
40+
)
41+
"""
42+
Parameterisation to use to check handling of both DataFrame and Series
43+
"""
44+
3245
check_auto_index_casting_pobj = pytest.mark.parametrize(
3346
"only_two_index_levels_pobj",
3447
(
@@ -43,48 +56,6 @@
4356
This parameterisation ensures that we check this edge case.
4457
"""
4558

46-
pobj_type = pytest.mark.parametrize(
47-
"pobj_type",
48-
("DataFrame", "Series"),
49-
)
50-
"""
51-
Parameterisation to use to check handling of both DataFrame and Series
52-
"""
53-
54-
55-
@overload
56-
def convert_to_desired_type(
57-
pobj: pd.DataFrame, pobj_type: Literal["DataFrame"]
58-
) -> pd.DataFrame: ...
59-
60-
61-
@overload
62-
def convert_to_desired_type(
63-
pobj: pd.DataFrame, pobj_type: Literal["Series"]
64-
) -> pd.Series[Any]: ...
65-
66-
67-
def convert_to_desired_type(
68-
df: pd.DataFrame, pobj_type: Literal["DataFrame", "Series"]
69-
) -> pd.DataFrame | pd.Series[Any]:
70-
if pobj_type == "DataFrame":
71-
return df
72-
73-
if pobj_type == "Series":
74-
res = df[df.columns[0]]
75-
return res
76-
77-
raise NotImplementedError(pobj_type)
78-
79-
80-
def check_result(res: P, exp: P) -> None:
81-
if isinstance(res, pd.DataFrame):
82-
assert_frame_alike(res, exp)
83-
elif isinstance(res, pd.Series):
84-
pd.testing.assert_series_equal(res, exp)
85-
else:
86-
raise NotImplementedError(type(res))
87-
8859

8960
@pobj_type
9061
@check_auto_index_casting_pobj

0 commit comments

Comments
 (0)