Skip to content

Commit 56c2fe5

Browse files
committed
Add LogUniform distribution
Add an explicit LogUniform distribution class. The interpretation of the distribution parameter is different from the existing `Uniform(a, b, log=True)`. In PEtab v2, X ~ LogUniform(a, b) <=> ln(X) ~ Uniform(ln(a), ln(b)). However, in PEtab v1, a `parameterScaleUniform` prior for a parameterScale=log parameter is interpreted as ln(X) ~ Uniform(a, b).
1 parent a165108 commit 56c2fe5

File tree

2 files changed

+52
-2
lines changed

2 files changed

+52
-2
lines changed

petab/v1/distributions.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
"Normal",
3737
"Rayleigh",
3838
"Uniform",
39+
"LogUniform",
3940
]
4041

4142

@@ -382,6 +383,10 @@ class Uniform(Distribution):
382383
If ``False``, no transformation is applied.
383384
If a transformation is applied, the lower and upper bounds are the
384385
lower and upper bounds of the underlying uniform distribution.
386+
Note that this differs from the usual definition of a log-uniform
387+
distribution, where the logarithm of the variable is uniformly
388+
distributed between the logarithms of the bounds (see also
389+
:class:`LogUniform`).
385390
"""
386391

387392
def __init__(
@@ -411,6 +416,43 @@ def _ppf_untransformed_untruncated(self, q) -> np.ndarray | float:
411416
return uniform.ppf(q, loc=self._low, scale=self._high - self._low)
412417

413418

419+
class LogUniform(Distribution):
420+
"""A log-uniform or reciprocal distribution.
421+
422+
A random variable is log-uniformly distributed between ``low`` and ``high``
423+
if its logarithm is uniformly distributed between ``log(low)`` and
424+
``log(high)``.
425+
426+
:param low: The lower bound of the distribution.
427+
:param high: The upper bound of the distribution.
428+
"""
429+
430+
def __init__(
431+
self,
432+
low: float,
433+
high: float,
434+
):
435+
self._logbase = np.exp(1)
436+
self._low = self._log(low)
437+
self._high = self._log(high)
438+
super().__init__(log=self._logbase)
439+
440+
def __repr__(self):
441+
return self._repr({"low": self._low, "high": self._high})
442+
443+
def _sample(self, shape=None) -> np.ndarray | float:
444+
return np.random.uniform(low=self._low, high=self._high, size=shape)
445+
446+
def _pdf_untransformed_untruncated(self, x) -> np.ndarray | float:
447+
return uniform.pdf(x, loc=self._low, scale=self._high - self._low)
448+
449+
def _cdf_untransformed_untruncated(self, x) -> np.ndarray | float:
450+
return uniform.cdf(x, loc=self._low, scale=self._high - self._low)
451+
452+
def _ppf_untransformed_untruncated(self, q) -> np.ndarray | float:
453+
return uniform.ppf(q, loc=self._low, scale=self._high - self._low)
454+
455+
414456
class Laplace(Distribution):
415457
"""A (log-)Laplace distribution.
416458

petab/v2/core.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ class PriorDistribution(str, Enum):
201201
PriorDistribution.LAPLACE: Laplace,
202202
PriorDistribution.LOG_LAPLACE: Laplace,
203203
PriorDistribution.LOG_NORMAL: Normal,
204-
PriorDistribution.LOG_UNIFORM: Uniform,
204+
PriorDistribution.LOG_UNIFORM: LogUniform,
205205
PriorDistribution.NORMAL: Normal,
206206
PriorDistribution.RAYLEIGH: Rayleigh,
207207
PriorDistribution.UNIFORM: Uniform,
@@ -1060,7 +1060,15 @@ def prior_dist(self) -> Distribution:
10601060
# `Uniform.__init__` does not accept the `trunc` parameter
10611061
low = max(self.prior_parameters[0], self.lb)
10621062
high = min(self.prior_parameters[1], self.ub)
1063-
return cls(low, high, log=log)
1063+
return cls(low, high)
1064+
1065+
if cls == LogUniform:
1066+
# Mind the different interpretation of distribution parameters for
1067+
# Uniform(..., log=True) and LogUniform!!
1068+
# `LogUniform.__init__` does not accept the `trunc` parameter
1069+
low = max(self.prior_parameters[0], self.lb)
1070+
high = min(self.prior_parameters[1], self.ub)
1071+
return cls(low, high)
10641072

10651073
return cls(*self.prior_parameters, log=log, trunc=[self.lb, self.ub])
10661074

0 commit comments

Comments
 (0)