Skip to content

Commit 45c29b3

Browse files
committed
Store problem configuration in Problem
Introduces `Problem.config`, which holds the information from the PEtab YAML file, because it is sometimes convenient to have the original filenames around. Closes #324.
1 parent 0b77d7f commit 45c29b3

File tree

3 files changed

+98
-35
lines changed

3 files changed

+98
-35
lines changed

petab/v1/problem.py

Lines changed: 88 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from warnings import warn
1111

1212
import pandas as pd
13+
from pydantic import AnyUrl, BaseModel, Field, RootModel
1314

1415
from . import (
1516
conditions,
@@ -78,6 +79,7 @@ def __init__(
7879
observable_df: pd.DataFrame = None,
7980
mapping_df: pd.DataFrame = None,
8081
extensions_config: dict = None,
82+
config: ProblemConfig = None,
8183
):
8284
self.condition_df: pd.DataFrame | None = condition_df
8385
self.measurement_df: pd.DataFrame | None = measurement_df
@@ -112,6 +114,7 @@ def __init__(
112114

113115
self.model: Model | None = model
114116
self.extensions_config = extensions_config or {}
117+
self.config = config
115118

116119
def __getattr__(self, name):
117120
# For backward-compatibility, allow access to SBML model related
@@ -251,21 +254,32 @@ def from_files(
251254
)
252255

253256
@staticmethod
254-
def from_yaml(yaml_config: dict | Path | str) -> Problem:
257+
def from_yaml(
258+
yaml_config: dict | Path | str, base_path: str | Path = None
259+
) -> Problem:
255260
"""
256261
Factory method to load model and tables as specified by YAML file.
257262
258263
Arguments:
259264
yaml_config: PEtab configuration as dictionary or YAML file name
265+
base_path: Base directory or URL to resolve relative paths
260266
"""
267+
# path to the yaml file
268+
filepath = None
269+
261270
if isinstance(yaml_config, Path):
262271
yaml_config = str(yaml_config)
263272

264-
get_path = lambda filename: filename # noqa: E731
265273
if isinstance(yaml_config, str):
266-
path_prefix = get_path_prefix(yaml_config)
274+
filepath = yaml_config
275+
if base_path is None:
276+
base_path = get_path_prefix(yaml_config)
267277
yaml_config = yaml.load_yaml(yaml_config)
268-
get_path = lambda filename: f"{path_prefix}/{filename}" # noqa: E731
278+
279+
def get_path(filename):
280+
if base_path is None:
281+
return filename
282+
return f"{base_path}/{filename}"
269283

270284
if yaml.is_composite_problem(yaml_config):
271285
raise ValueError(
@@ -289,59 +303,58 @@ def from_yaml(yaml_config: dict | Path | str) -> Problem:
289303
DeprecationWarning,
290304
stacklevel=2,
291305
)
306+
config = ProblemConfig(
307+
**yaml_config, base_path=base_path, filepath=filepath
308+
)
309+
problem0 = config.problems[0]
310+
# currently required for handling PEtab v2 in here
311+
problem0_ = yaml_config["problems"][0]
292312

293-
problem0 = yaml_config["problems"][0]
294-
295-
if isinstance(yaml_config[PARAMETER_FILE], list):
313+
if isinstance(config.parameter_file, list):
296314
parameter_df = parameters.get_parameter_df(
297-
[get_path(f) for f in yaml_config[PARAMETER_FILE]]
315+
[get_path(f) for f in config.parameter_file]
298316
)
299317
else:
300318
parameter_df = (
301-
parameters.get_parameter_df(
302-
get_path(yaml_config[PARAMETER_FILE])
303-
)
304-
if yaml_config[PARAMETER_FILE]
319+
parameters.get_parameter_df(get_path(config.parameter_file))
320+
if config.parameter_file
305321
else None
306322
)
307-
308-
if yaml_config[FORMAT_VERSION] in [1, "1", "1.0.0"]:
309-
if len(problem0[SBML_FILES]) > 1:
323+
if config.format_version.root in [1, "1", "1.0.0"]:
324+
if len(problem0.sbml_files) > 1:
310325
# TODO https://github.com/PEtab-dev/libpetab-python/issues/6
311326
raise NotImplementedError(
312327
"Support for multiple models is not yet implemented."
313328
)
314329

315330
model = (
316331
model_factory(
317-
get_path(problem0[SBML_FILES][0]),
332+
get_path(problem0.sbml_files[0]),
318333
MODEL_TYPE_SBML,
319334
model_id=None,
320335
)
321-
if problem0[SBML_FILES]
336+
if problem0.sbml_files
322337
else None
323338
)
324339
else:
325-
if len(problem0[MODEL_FILES]) > 1:
340+
if len(problem0_[MODEL_FILES]) > 1:
326341
# TODO https://github.com/PEtab-dev/libpetab-python/issues/6
327342
raise NotImplementedError(
328343
"Support for multiple models is not yet implemented."
329344
)
330-
if not problem0[MODEL_FILES]:
345+
if not problem0_[MODEL_FILES]:
331346
model = None
332347
else:
333348
model_id, model_info = next(
334-
iter(problem0[MODEL_FILES].items())
349+
iter(problem0_[MODEL_FILES].items())
335350
)
336351
model = model_factory(
337352
get_path(model_info[MODEL_LOCATION]),
338353
model_info[MODEL_LANGUAGE],
339354
model_id=model_id,
340355
)
341356

342-
measurement_files = [
343-
get_path(f) for f in problem0.get(MEASUREMENT_FILES, [])
344-
]
357+
measurement_files = [get_path(f) for f in problem0.measurement_files]
345358
# If there are multiple tables, we will merge them
346359
measurement_df = (
347360
core.concat_tables(
@@ -351,9 +364,7 @@ def from_yaml(yaml_config: dict | Path | str) -> Problem:
351364
else None
352365
)
353366

354-
condition_files = [
355-
get_path(f) for f in problem0.get(CONDITION_FILES, [])
356-
]
367+
condition_files = [get_path(f) for f in problem0.condition_files]
357368
# If there are multiple tables, we will merge them
358369
condition_df = (
359370
core.concat_tables(condition_files, conditions.get_condition_df)
@@ -362,7 +373,7 @@ def from_yaml(yaml_config: dict | Path | str) -> Problem:
362373
)
363374

364375
visualization_files = [
365-
get_path(f) for f in problem0.get(VISUALIZATION_FILES, [])
376+
get_path(f) for f in problem0.visualization_files
366377
]
367378
# If there are multiple tables, we will merge them
368379
visualization_df = (
@@ -371,17 +382,15 @@ def from_yaml(yaml_config: dict | Path | str) -> Problem:
371382
else None
372383
)
373384

374-
observable_files = [
375-
get_path(f) for f in problem0.get(OBSERVABLE_FILES, [])
376-
]
385+
observable_files = [get_path(f) for f in problem0.observable_files]
377386
# If there are multiple tables, we will merge them
378387
observable_df = (
379388
core.concat_tables(observable_files, observables.get_observable_df)
380389
if observable_files
381390
else None
382391
)
383392

384-
mapping_files = [get_path(f) for f in problem0.get(MAPPING_FILES, [])]
393+
mapping_files = [get_path(f) for f in problem0_.get(MAPPING_FILES, [])]
385394
# If there are multiple tables, we will merge them
386395
mapping_df = (
387396
core.concat_tables(mapping_files, mapping.get_mapping_df)
@@ -398,6 +407,7 @@ def from_yaml(yaml_config: dict | Path | str) -> Problem:
398407
visualization_df=visualization_df,
399408
mapping_df=mapping_df,
400409
extensions_config=yaml_config.get(EXTENSIONS, {}),
410+
config=config,
401411
)
402412

403413
@staticmethod
@@ -998,3 +1008,50 @@ def n_priors(self) -> int:
9981008
return 0
9991009

10001010
return self.parameter_df[OBJECTIVE_PRIOR_PARAMETERS].notna().sum()
1011+
1012+
1013+
class VersionNumber(RootModel):
1014+
root: str | int
1015+
1016+
1017+
class ListOfFiles(RootModel):
1018+
"""List of files."""
1019+
1020+
root: list[str | AnyUrl] = Field(..., description="List of files.")
1021+
1022+
def __iter__(self):
1023+
return iter(self.root)
1024+
1025+
def __len__(self):
1026+
return len(self.root)
1027+
1028+
def __getitem__(self, index):
1029+
return self.root[index]
1030+
1031+
1032+
class SubProblem(BaseModel):
1033+
"""A `problems` object in the PEtab problem configuration."""
1034+
1035+
sbml_files: ListOfFiles = []
1036+
measurement_files: ListOfFiles = []
1037+
condition_files: ListOfFiles = []
1038+
observable_files: ListOfFiles = []
1039+
visualization_files: ListOfFiles = []
1040+
1041+
1042+
class ProblemConfig(BaseModel):
1043+
"""The PEtab problem configuration."""
1044+
1045+
filepath: str | AnyUrl | None = Field(
1046+
None,
1047+
description="The path to the PEtab problem configuration.",
1048+
exclude=True,
1049+
)
1050+
base_path: str | AnyUrl | None = Field(
1051+
None,
1052+
description="The base path to resolve relative paths.",
1053+
exclude=True,
1054+
)
1055+
format_version: VersionNumber = 1
1056+
parameter_file: str | AnyUrl | None = None
1057+
problems: list[SubProblem] = []

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ dependencies = [
2222
"pyyaml",
2323
"jsonschema",
2424
"antlr4-python3-runtime==4.13.1",
25+
"pydantic>=2.10",
2526
]
2627
license = {text = "MIT License"}
2728
authors = [

tests/v1/test_petab.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -862,11 +862,16 @@ def test_problem_from_yaml_v1_multiple_files():
862862
observables_df, Path(tmpdir, f"observables{i}.tsv")
863863
)
864864

865-
petab_problem = petab.Problem.from_yaml(yaml_path)
865+
petab_problem1 = petab.Problem.from_yaml(yaml_path)
866866

867-
assert petab_problem.measurement_df.shape[0] == 2
868-
assert petab_problem.observable_df.shape[0] == 2
869-
assert petab_problem.condition_df.shape[0] == 2
867+
# test that we can load the problem from a dict with a custom base path
868+
yaml_config = petab.v1.load_yaml(yaml_path)
869+
petab_problem2 = petab.Problem.from_yaml(yaml_config, base_path=tmpdir)
870+
871+
for petab_problem in (petab_problem1, petab_problem2):
872+
assert petab_problem.measurement_df.shape[0] == 2
873+
assert petab_problem.observable_df.shape[0] == 2
874+
assert petab_problem.condition_df.shape[0] == 2
870875

871876

872877
def test_get_required_parameters_for_parameter_table(petab_problem):

0 commit comments

Comments
 (0)