Skip to content

Commit 1dd3b7b

Browse files
authored
Implement sampling for v2 prior distributions (#461)
1 parent 041246e commit 1dd3b7b

File tree

2 files changed

+27
-5
lines changed

2 files changed

+27
-5
lines changed

petab/v1/distributions.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,9 @@ def _cdf_untransformed_untruncated(self, x) -> np.ndarray | float:
508508
def _ppf_untransformed_untruncated(self, q) -> np.ndarray | float:
509509
return cauchy.ppf(q, loc=self._loc, scale=self._scale)
510510

511+
def _sample(self, shape=None) -> np.ndarray | float:
512+
return cauchy.rvs(loc=self._loc, scale=self._scale, size=shape)
513+
511514
@property
512515
def loc(self) -> float:
513516
"""The location parameter of the underlying distribution."""
@@ -541,14 +544,16 @@ class ChiSquare(Distribution):
541544

542545
def __init__(
543546
self,
544-
dof: int,
547+
dof: int | float,
545548
trunc: tuple[float, float] | None = None,
546549
log: bool | float = False,
547550
):
548-
if not dof.is_integer() or dof < 1:
549-
raise ValueError(
550-
f"`dof' must be a positive integer, but was `{dof}'."
551-
)
551+
if isinstance(dof, float):
552+
if not dof.is_integer() or dof < 1:
553+
raise ValueError(
554+
f"`dof' must be a positive integer, but was `{dof}'."
555+
)
556+
dof = int(dof)
552557

553558
self._dof = dof
554559
super().__init__(log=log, trunc=trunc)
@@ -565,6 +570,9 @@ def _cdf_untransformed_untruncated(self, x) -> np.ndarray | float:
565570
def _ppf_untransformed_untruncated(self, q) -> np.ndarray | float:
566571
return chi2.ppf(q, df=self._dof)
567572

573+
def _sample(self, shape=None) -> np.ndarray | float:
574+
return chi2.rvs(df=self._dof, size=shape)
575+
568576
@property
569577
def dof(self) -> int:
570578
"""The degrees of freedom parameter."""
@@ -602,6 +610,9 @@ def _cdf_untransformed_untruncated(self, x) -> np.ndarray | float:
602610
def _ppf_untransformed_untruncated(self, q) -> np.ndarray | float:
603611
return expon.ppf(q, scale=self._scale)
604612

613+
def _sample(self, shape=None) -> np.ndarray | float:
614+
return expon.rvs(scale=self._scale, size=shape)
615+
605616
@property
606617
def scale(self) -> float:
607618
"""The scale parameter of the underlying distribution."""
@@ -650,6 +661,9 @@ def _cdf_untransformed_untruncated(self, x) -> np.ndarray | float:
650661
def _ppf_untransformed_untruncated(self, q) -> np.ndarray | float:
651662
return gamma.ppf(q, a=self._shape, scale=self._scale)
652663

664+
def _sample(self, shape=None) -> np.ndarray | float:
665+
return gamma.rvs(a=self._shape, scale=self._scale, size=shape)
666+
653667
@property
654668
def shape(self) -> float:
655669
"""The shape parameter of the underlying distribution."""
@@ -700,6 +714,9 @@ def _cdf_untransformed_untruncated(self, x) -> np.ndarray | float:
700714
def _ppf_untransformed_untruncated(self, q) -> np.ndarray | float:
701715
return rayleigh.ppf(q, scale=self._scale)
702716

717+
def _sample(self, shape=None) -> np.ndarray | float:
718+
return rayleigh.rvs(scale=self._scale, size=shape)
719+
703720
@property
704721
def scale(self) -> float:
705722
"""The scale parameter of the underlying distribution."""

tests/v1/test_distributions.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@
3434
Normal(2, 1, log=10),
3535
Laplace(1, 2, trunc=(1, 2)),
3636
Laplace(1, 0.5, log=True, trunc=(0.5, 8)),
37+
Cauchy(2, 1),
38+
ChiSquare(4),
39+
Exponential(1),
40+
Gamma(3, 5),
41+
Rayleigh(3),
3742
],
3843
)
3944
def test_sample_matches_pdf(distribution):

0 commit comments

Comments
 (0)