Skip to content

Commit d1ceaf9

Browse files
committed
Pass unit test
1 parent 95f53d1 commit d1ceaf9

File tree

5 files changed

+149
-16
lines changed

5 files changed

+149
-16
lines changed

src/pandas_openscm/db/loading.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def load_data( # noqa: PLR0913
3333
backend_data: OpenSCMDBDataBackend,
3434
db_index: pd.DataFrame,
3535
db_file_map: pd.Series[Path], # type: ignore # pandas type hints confused about what they support
36+
db_dir: Path,
3637
selector: pd.Index[Any] | pd.MultiIndex | pix.selectors.Selector | None = None,
3738
out_columns_type: type | None = None,
3839
parallel_op_config: ParallelOpConfig | None = None,
@@ -53,6 +54,9 @@ def load_data( # noqa: PLR0913
5354
db_file_map
5455
File map of the database from which to load
5556
57+
db_dir
58+
The directory in which the database lives
59+
5660
selector
5761
Selector to use to choose the data to load
5862
@@ -97,7 +101,7 @@ def load_data( # noqa: PLR0913
97101
else:
98102
index_to_load = mi_loc(db_index, selector)
99103

100-
files_to_load = (Path(v) for v in db_file_map[index_to_load["file_id"].unique()])
104+
files_to_load = (db_dir / v for v in db_file_map[index_to_load["file_id"].unique()])
101105
loaded_l = load_data_files(
102106
files_to_load=files_to_load,
103107
backend_data=backend_data,

src/pandas_openscm/db/openscm_db.py

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
load_db_index,
2222
load_db_metadata,
2323
)
24+
from pandas_openscm.db.path_handling import DBPath
2425
from pandas_openscm.db.reader import OpenSCMDBReader
2526
from pandas_openscm.db.rewriting import make_move_plan, rewrite_files
2627
from pandas_openscm.db.saving import save_data
@@ -297,6 +298,30 @@ def from_gzipped_tar_archive(
297298
backend_data: OpenSCMDBDataBackend | None = None,
298299
backend_index: OpenSCMDBIndexBackend | None = None,
299300
) -> OpenSCMDB:
301+
"""
302+
Initialise from a gzipped tar archive
303+
304+
This also unpacks the files to disk
305+
306+
Parameters
307+
----------
308+
tar_archive
309+
Tar archive from which to initialise
310+
311+
db_dir
312+
Directory in which to unpack the database
313+
314+
backend_data
315+
Backend to use for handling the data
316+
317+
backend_index
318+
Backend to use for handling the index
319+
320+
Returns
321+
-------
322+
:
323+
Initialised database
324+
"""
300325
with tarfile.open(tar_archive, "r") as tar:
301326
for member in tar.getmembers():
302327
if not member.isreg():
@@ -317,7 +342,7 @@ def from_gzipped_tar_archive(
317342

318343
return res
319344

320-
def get_new_data_file_path(self, file_id: int) -> Path:
345+
def get_new_data_file_path(self, file_id: int) -> DBPath:
321346
"""
322347
Get the path in which to write a new data file
323348
@@ -329,7 +354,7 @@ def get_new_data_file_path(self, file_id: int) -> Path:
329354
Returns
330355
-------
331356
:
332-
File in which to write the new data
357+
Information about the path in which to write the new data
333358
334359
Raises
335360
------
@@ -341,7 +366,7 @@ def get_new_data_file_path(self, file_id: int) -> Path:
341366
if file_path.exists():
342367
raise FileExistsError(file_path)
343368

344-
return file_path
369+
return DBPath.from_abs_path_and_db_dir(abs=file_path, db_dir=self.db_dir)
345370

346371
def load( # noqa: PLR0913
347372
self,
@@ -421,6 +446,7 @@ def load( # noqa: PLR0913
421446
backend_data=self.backend_data,
422447
db_index=index,
423448
db_file_map=file_map,
449+
db_dir=self.db_dir,
424450
selector=selector,
425451
out_columns_type=out_columns_type,
426452
parallel_op_config=parallel_op_config,
@@ -738,6 +764,24 @@ def save( # noqa: PLR0913
738764
)
739765

740766
def to_gzipped_tar_archive(self, out_file: Path, mode: str = "w:gz") -> Path:
767+
"""
768+
Convert to a gzipped tar archive
769+
770+
Parameters
771+
----------
772+
out_file
773+
File in which to write the output
774+
775+
mode
776+
Mode to use to open `out_file`
777+
778+
Returns
779+
-------
780+
:
781+
Path to the gzipped tar archive
782+
783+
This is the same as `out_file`, but is returned for convenience.
784+
"""
741785
with tarfile.open(out_file, mode) as tar:
742786
tar.add(self.db_dir, arcname="db")
743787

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
"""
2+
Functionality for handling paths
3+
4+
In order to make our databases portable,
5+
we need to be a bit smarter than just using raw paths.
6+
"""
7+
8+
from __future__ import annotations
9+
10+
from pathlib import Path
11+
from typing import Any
12+
13+
import attr
14+
from attrs import define, field
15+
16+
17+
@define
class DBPath:
    """
    Database-related path

    Carries the information required to write paths with certainty
    and keep the database portable.
    """

    abs: Path
    """The absolute path for the file"""

    rel_db: Path = field()
    """The path relative to the database's directory"""

    @rel_db.validator
    def rel_db_validator(self, attribute: attr.Attribute[Any], value: Path) -> None:
        """
        Validate the value of `rel_db`

        Parameters
        ----------
        attribute
            Attribute being set

        value
            Value to use

        Raises
        ------
        AssertionError
            `value` is not within `self.abs`
        """
        # Compare path components rather than raw strings:
        # a plain string-suffix check (`str(self.abs).endswith(str(value))`)
        # would wrongly accept partial-segment matches,
        # e.g. abs=/data/abc.txt and value=bc.txt.
        n_parts = len(value.parts)
        if n_parts == 0 or self.abs.parts[-n_parts:] != value.parts:
            msg = f"{value} for {attribute.name} is not within {self.abs=}"
            raise AssertionError(msg)

    @classmethod
    def from_abs_path_and_db_dir(cls, abs: Path, db_dir: Path) -> DBPath:
        """
        Initialise from an absolute path and a database directory

        Parameters
        ----------
        abs
            Absolute path

        db_dir
            Database directory

        Returns
        -------
        :
            Initialised `DBPath`

        Raises
        ------
        ValueError
            `abs` is not inside `db_dir`
            (raised by :meth:`pathlib.PurePath.relative_to`)
        """
        # relative_to raises ValueError if abs is not under db_dir,
        # so the resulting rel_db always satisfies the validator.
        return cls(abs=abs, rel_db=abs.relative_to(db_dir))

src/pandas_openscm/db/saving.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from attrs import define
1616

1717
from pandas_openscm.db.interfaces import OpenSCMDBDataBackend, OpenSCMDBIndexBackend
18+
from pandas_openscm.db.path_handling import DBPath
1819
from pandas_openscm.index_manipulation import (
1920
unify_index_levels_check_index_types,
2021
)
@@ -63,7 +64,7 @@ def save_data( # noqa: PLR0913
6364
data: pd.DataFrame,
6465
*,
6566
backend_data: OpenSCMDBDataBackend,
66-
get_new_data_file_path: Callable[[int], Path],
67+
get_new_data_file_path: Callable[[int], DBPath],
6768
backend_index: OpenSCMDBIndexBackend,
6869
index_file: Path,
6970
file_map_file: Path,
@@ -84,8 +85,20 @@ def save_data( # noqa: PLR0913
8485
data
8586
Data to save
8687
87-
db
88-
Database in which to save the data
88+
backend_data
89+
Backend to use to save the data
90+
91+
get_new_data_file_path
92+
Callable which, given an integer, returns the path info for the new data file
93+
94+
backend_index
95+
Backend to use to save the index
96+
97+
index_file
98+
File in which to save the index
99+
100+
file_map_file
101+
File in which to save the file map
89102
90103
index_non_data
91104
Index that is already in the database but isn't related to data.
@@ -94,7 +107,7 @@ def save_data( # noqa: PLR0913
94107
before we write the database's index.
95108
96109
file_map_non_data
97-
File map that is already in the database but isn't related to data.
110+
File map that is already in the database but isn't related to `data`.
98111
99112
If supplied, this is combined with the file map generated for `data`
100113
before we write the database's file map.
@@ -179,9 +192,9 @@ def save_data( # noqa: PLR0913
179192
for increment, (_, df) in enumerate(grouper):
180193
file_id = min_file_id + increment
181194

182-
new_file_path = get_new_data_file_path(file_id)
195+
new_db_path = get_new_data_file_path(file_id)
183196

184-
file_map_out.loc[file_id] = new_file_path # type: ignore # pandas types confused about what they support
197+
file_map_out.loc[file_id] = new_db_path.rel_db # type: ignore # pandas types confused about what they support
185198
if index_non_data_unified_index is None:
186199
df_index_unified = df.index
187200
else:
@@ -202,7 +215,7 @@ def save_data( # noqa: PLR0913
202215
info=df,
203216
info_kind=DBFileType.DATA,
204217
backend=backend_data,
205-
save_path=new_file_path,
218+
save_path=new_db_path.abs,
206219
)
207220
)
208221

tests/integration/database/test_integration_database_portability.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,16 @@
88
import pandas as pd
99
import pytest
1010

11-
from pandas_openscm.db import CSVDataBackend, CSVIndexBackend, OpenSCMDB
11+
from pandas_openscm.db import FeatherDataBackend, FeatherIndexBackend, OpenSCMDB
1212
from pandas_openscm.testing import assert_frame_alike
1313

1414

1515
@pytest.mark.parametrize(
1616
"backend_data_for_class_method, backend_index_for_class_method",
1717
(
1818
pytest.param(
19-
CSVDataBackend(),
20-
CSVIndexBackend(),
19+
FeatherDataBackend(),
20+
FeatherIndexBackend(),
2121
id="provided",
2222
),
2323
pytest.param(
@@ -39,8 +39,8 @@ def test_move_db(
3939

4040
db = OpenSCMDB(
4141
db_dir=initial_db_dir,
42-
backend_data=CSVDataBackend(),
43-
backend_index=CSVIndexBackend(),
42+
backend_data=FeatherDataBackend(),
43+
backend_index=FeatherIndexBackend(),
4444
)
4545

4646
df_timeseries_like = pd.DataFrame(

0 commit comments

Comments
 (0)