Skip to content

Commit 489035a

Browse files
committed
v2: Startpoint sampling
* Add `Parameter.prior_dist` * Update `v1.distributions.__all__` * Implement startpoint sampling for `v2.Problem` supporting all new prior distributions
1 parent 3de2ce9 commit 489035a

File tree

4 files changed

+95
-12
lines changed

4 files changed

+95
-12
lines changed

petab/v1/distributions.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,14 @@
1919

2020
__all__ = [
2121
"Distribution",
22+
"Cauchy",
23+
"ChiSquare",
24+
"Exponential",
25+
"Gamma",
26+
"Laplace",
2227
"Normal",
28+
"Rayleigh",
2329
"Uniform",
24-
"Laplace",
2530
]
2631

2732

petab/v2/core.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
)
2626
from typing_extensions import Self
2727

28+
from ..v1.distributions import *
2829
from ..v1.lint import is_valid_identifier
2930
from ..v1.math import petab_math_str, sympify_petab
3031
from . import C, get_observable_df
@@ -150,6 +151,26 @@ class PriorDistribution(str, Enum):
150151
f"{set(C.PRIOR_DISTRIBUTIONS)} vs { {e.value for e in PriorDistribution} }"
151152
)
152153

154+
_prior_to_cls = {
155+
PriorDistribution.CAUCHY: Cauchy,
156+
PriorDistribution.CHI_SQUARED: ChiSquare,
157+
PriorDistribution.EXPONENTIAL: Exponential,
158+
PriorDistribution.GAMMA: Gamma,
159+
PriorDistribution.LAPLACE: Laplace,
160+
PriorDistribution.LOG10_NORMAL: Normal,
161+
PriorDistribution.LOG_LAPLACE: Laplace,
162+
PriorDistribution.LOG_NORMAL: Normal,
163+
PriorDistribution.LOG_UNIFORM: Uniform,
164+
PriorDistribution.NORMAL: Normal,
165+
PriorDistribution.RAYLEIGH: Rayleigh,
166+
PriorDistribution.UNIFORM: Uniform,
167+
}
168+
169+
assert not (_mismatch := set(PriorDistribution) ^ set(_prior_to_cls)), (
170+
"PriorDistribution enum does not match _prior_to_cls. "
171+
f"Mismatches: {_mismatch}"
172+
)
173+
153174

154175
class Observable(BaseModel):
155176
"""Observable definition."""
@@ -929,6 +950,37 @@ def _validate(self) -> Self:
929950

930951
return self
931952

953+
@property
954+
def prior_dist(self) -> Distribution:
955+
"""Get the pior distribution of the parameter."""
956+
if self.estimate is False:
957+
raise ValueError(f"Parameter `{self.id}' is not estimated.")
958+
959+
if self.prior_distribution is None:
960+
return Uniform(self.lb, self.ub)
961+
962+
if not (cls := _prior_to_cls.get(self.prior_distribution)):
963+
raise ValueError(
964+
f"Prior distribution `{self.prior_distribution}' not "
965+
"supported."
966+
)
967+
968+
if str(self.prior_distribution).startswith("log-"):
969+
log = True
970+
elif str(self.prior_distribution).startswith("log10-"):
971+
log = 10
972+
else:
973+
log = False
974+
975+
if cls == Exponential:
976+
if log is not False:
977+
raise ValueError(
978+
"Exponential distribution does not support log "
979+
"transformation."
980+
)
981+
return cls(*self.prior_parameters, trunc=[self.lb, self.ub])
982+
return cls(*self.prior_parameters, log=log, trunc=[self.lb, self.ub])
983+
932984

933985
class ParameterTable(BaseModel):
934986
"""PEtab parameter table."""

petab/v2/problem.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from pathlib import Path
1313
from typing import TYPE_CHECKING, Any
1414

15+
import numpy as np
1516
import pandas as pd
1617
import sympy as sp
1718
from pydantic import AnyUrl, BaseModel, Field
@@ -22,10 +23,10 @@
2223
observables,
2324
parameter_mapping,
2425
parameters,
25-
sampling,
2626
yaml,
2727
)
2828
from ..v1.core import concat_tables, get_visualization_df
29+
from ..v1.distributions import Distribution
2930
from ..v1.models.model import Model, model_factory
3031
from ..v1.yaml import get_path_prefix
3132
from ..v2.C import * # noqa: F403
@@ -726,24 +727,29 @@ def get_optimization_to_simulation_parameter_mapping(self, **kwargs):
726727
)
727728
)
728729

729-
def sample_parameter_startpoints(self, n_starts: int = 100, **kwargs):
730-
"""Create 2D array with starting points for optimization
730+
def get_priors(self) -> dict[str, Distribution]:
731+
"""Get prior distributions.
731732
732-
See :py:func:`petab.sample_parameter_startpoints`.
733+
:returns: The prior distributions for the estimated parameters.
733734
"""
734-
return sampling.sample_parameter_startpoints(
735-
self.parameter_df, n_starts=n_starts, **kwargs
736-
)
735+
return {
736+
p.id: p.prior_dist
737+
for p in self.parameter_table.parameters
738+
if p.estimate
739+
}
740+
741+
def sample_parameter_startpoints(self, n_starts: int = 100, **kwargs):
742+
"""Create 2D array with starting points for optimization"""
743+
priors = self.get_priors()
744+
return np.vstack([p.sample(n_starts) for p in priors.values()]).T
737745

738746
def sample_parameter_startpoints_dict(
739747
self, n_starts: int = 100
740748
) -> list[dict[str, float]]:
741749
"""Create dictionaries with starting points for optimization
742750
743-
See also :py:func:`petab.sample_parameter_startpoints`.
744-
745-
Returns:
746-
A list of dictionaries with parameter IDs mapping to samples
751+
:returns:
752+
A list of dictionaries with parameter IDs mapping to sampled
747753
parameter values.
748754
"""
749755
return [

tests/v2/test_problem.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
TARGET_VALUE,
2424
UPPER_BOUND,
2525
)
26+
from petab.v2.core import *
2627

2728

2829
def test_load_remote():
@@ -170,3 +171,22 @@ def test_modify_problem():
170171
}
171172
).set_index([PETAB_ENTITY_ID])
172173
assert_frame_equal(problem.mapping_df, exp_mapping_df, check_dtype=False)
174+
175+
176+
def test_sample_startpoint():
177+
"""Test startpoint sampling."""
178+
problem = Problem()
179+
problem += Parameter(id="p1", estimate=True, lb=1, ub=2)
180+
problem += Parameter(
181+
id="p2",
182+
estimate=True,
183+
lb=2,
184+
ub=3,
185+
prior_distribution="normal",
186+
prior_parameters=[2.5, 0.5],
187+
)
188+
problem += Parameter(id="p3", estimate=False, nominal_value=1)
189+
190+
n_starts = 10
191+
sp = problem.sample_parameter_startpoints(n_starts=n_starts)
192+
assert sp.shape == (n_starts, 2)

0 commit comments

Comments
 (0)