Skip to content

Commit 154f5b0

Browse files
committed
Add logprob submodule to API
1 parent 951fe52 commit 154f5b0

File tree

3 files changed

+265
-10
lines changed

3 files changed

+265
-10
lines changed

docs/source/api.rst

+1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ API
1515
api/smc
1616
api/data
1717
api/ode
18+
api/logprob
1819
api/tuning
1920
api/math
2021
api/pytensorf

docs/source/api/logprob.rst

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
***********
2+
Probability
3+
***********
4+
5+
.. currentmodule:: pymc
6+
7+
.. autosummary::
8+
:toctree: generated/
9+
10+
logp
11+
logcdf
12+
icdf
13+
14+
Conditional probability
15+
-----------------------
16+
17+
.. currentmodule:: pymc.logprob
18+
19+
.. autosummary::
20+
:toctree: generated/
21+
22+
conditional_logp
23+
transformed_conditional_logp

pymc/logprob/basic.py

+241-10
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,8 @@ def _warn_rvs_in_inferred_graph(graph: Union[TensorVariable, Sequence[TensorVari
9696
warnings.warn(
9797
f"RandomVariables {rvs_in_graph} were found in the derived graph. "
9898
"These variables are a clone and do not match the original ones on identity.\n"
99-
"If you are deriving a quantity that depends on model RVs, use `model.replace_rvs_by_values` first. For example: "
100-
"`logp(model.replace_rvs_by_values([rv])[0], value)`",
99+
"If you are deriving a quantity that depends on model RVs, use `model.replace_rvs_by_values` first. "
100+
"For example: `logp(model.replace_rvs_by_values([rv])[0], value)`",
101101
stacklevel=3,
102102
)
103103

@@ -119,7 +119,89 @@ def _deprecate_warn_missing_rvs(warn_rvs, kwargs):
119119

120120

121121
def logp(rv: TensorVariable, value: TensorLike, warn_rvs=None, **kwargs) -> TensorVariable:
122-
"""Return the log-probability graph of a Random Variable"""
122+
"""Create a graph for the log-probability of a random variable.
123+
124+
Parameters
125+
----------
126+
rv : TensorVariable
127+
value : tensor_like
128+
Should be the same type (shape and dtype) as the rv.
129+
warn_rvs : bool, default True
130+
Warn if RVs were found in the logp graph.
131+
This can happen when a variable has other other random variables as inputs.
132+
In that case, those random variables should be replaced by their respective values.
133+
`pymc.logprob.conditional_logp` can also be used as an alternative.
134+
135+
Returns
136+
-------
137+
logp : TensorVariable
138+
139+
Raises
140+
------
141+
RuntimeError
142+
If the logp cannot be derived.
143+
144+
Examples
145+
--------
146+
Create a compiled function that evaluates the logp of a variable
147+
148+
.. code-block:: python
149+
150+
import pymc as pm
151+
import pytensor.tensor as pt
152+
153+
mu = pt.scalar("mu")
154+
rv = pm.Normal.dist(mu, 1.0)
155+
156+
value = pt.scalar("value")
157+
rv_logp = pm.logp(rv, value)
158+
159+
# Use .eval() for debugging
160+
print(rv_logp.eval({value: 0.9, mu: 0.0})) # -1.32393853
161+
162+
# Compile a function for repeated evaluations
163+
rv_logp_fn = pm.compile_pymc([value, mu], rv_logp)
164+
print(rv_logp_fn(value=0.9, mu=0.0)) # -1.32393853
165+
166+
167+
Derive the graph for a transformation of a RandomVariable
168+
169+
.. code-block:: python
170+
171+
import pymc as pm
172+
import pytensor.tensor as pt
173+
174+
mu = pt.scalar("mu")
175+
rv = pm.Normal.dist(mu, 1.0)
176+
exp_rv = pt.exp(rv)
177+
178+
value = pt.scalar("value")
179+
exp_rv_logp = pm.logp(exp_rv, value)
180+
181+
# Use .eval() for debugging
182+
print(exp_rv_logp.eval({value: 0.9, mu: 0.0})) # -0.81912844
183+
184+
# Compile a function for repeated evaluations
185+
exp_rv_logp_fn = pm.compile_pymc([value, mu], exp_rv_logp)
186+
print(exp_rv_logp_fn(value=0.9, mu=0.0)) # -0.81912844
187+
188+
189+
Define a CustomDist logp
190+
191+
.. code-block:: python
192+
193+
import pymc as pm
194+
import pytensor.tensor as pt
195+
196+
def normal_logp(value, mu, sigma):
197+
return pm.logp(pm.Normal.dist(mu, sigma), value)
198+
199+
with pm.Model() as model:
200+
mu = pm.Normal("mu")
201+
sigma = pm.HalfNormal("sigma")
202+
pm.CustomDist("x", mu, sigma, logp=normal_logp)
203+
204+
"""
123205
warn_rvs, kwargs = _deprecate_warn_missing_rvs(warn_rvs, kwargs)
124206

125207
value = pt.as_tensor_variable(value, dtype=rv.dtype)
@@ -136,7 +218,88 @@ def logp(rv: TensorVariable, value: TensorLike, warn_rvs=None, **kwargs) -> Tens
136218

137219

138220
def logcdf(rv: TensorVariable, value: TensorLike, warn_rvs=None, **kwargs) -> TensorVariable:
139-
"""Create a graph for the log-CDF of a Random Variable."""
221+
"""Create a graph for the log-CDF of a random variable.
222+
223+
Parameters
224+
----------
225+
rv : TensorVariable
226+
value : tensor_like
227+
Should be the same type (shape and dtype) as the rv.
228+
warn_rvs : bool, default True
229+
Warn if RVs were found in the logcdf graph.
230+
This can happen when a variable has other random variables as inputs.
231+
In that case, those random variables should be replaced by their respective values.
232+
233+
Returns
234+
-------
235+
logp : TensorVariable
236+
237+
Raises
238+
------
239+
RuntimeError
240+
If the logcdf cannot be derived.
241+
242+
Examples
243+
--------
244+
Create a compiled function that evaluates the logcdf of a variable
245+
246+
.. code-block:: python
247+
248+
import pymc as pm
249+
import pytensor.tensor as pt
250+
251+
mu = pt.scalar("mu")
252+
rv = pm.Normal.dist(mu, 1.0)
253+
254+
value = pt.scalar("value")
255+
rv_logcdf = pm.logcdf(rv, value)
256+
257+
# Use .eval() for debugging
258+
print(rv_logcdf.eval({value: 0.9, mu: 0.0})) # -0.2034146
259+
260+
# Compile a function for repeated evaluations
261+
rv_logcdf_fn = pm.compile_pymc([value, mu], rv_logcdf)
262+
print(rv_logcdf_fn(value=0.9, mu=0.0)) # -0.2034146
263+
264+
265+
Derive the graph for a transformation of a RandomVariable
266+
267+
.. code-block:: python
268+
269+
import pymc as pm
270+
import pytensor.tensor as pt
271+
272+
mu = pt.scalar("mu")
273+
rv = pm.Normal.dist(mu, 1.0)
274+
exp_rv = pt.exp(rv)
275+
276+
value = pt.scalar("value")
277+
exp_rv_logcdf = pm.logcdf(exp_rv, value)
278+
279+
# Use .eval() for debugging
280+
print(exp_rv_logcdf.eval({value: 0.9, mu: 0.0})) # -0.78078813
281+
282+
# Compile a function for repeated evaluations
283+
exp_rv_logcdf_fn = pm.compile_pymc([value, mu], exp_rv_logcdf)
284+
print(exp_rv_logcdf_fn(value=0.9, mu=0.0)) # -0.78078813
285+
286+
287+
Define a CustomDist logcdf
288+
289+
.. code-block:: python
290+
291+
import pymc as pm
292+
import pytensor.tensor as pt
293+
294+
def normal_logcdf(value, mu, sigma):
295+
return pm.logp(pm.Normal.dist(mu, sigma), value)
296+
297+
with pm.Model() as model:
298+
mu = pm.Normal("mu")
299+
sigma = pm.HalfNormal("sigma")
300+
pm.CustomDist("x", mu, sigma, logcdf=normal_logcdf)
301+
302+
"""
140303
warn_rvs, kwargs = _deprecate_warn_missing_rvs(warn_rvs, kwargs)
141304
value = pt.as_tensor_variable(value, dtype=rv.dtype)
142305
try:
@@ -153,7 +316,72 @@ def logcdf(rv: TensorVariable, value: TensorLike, warn_rvs=None, **kwargs) -> Te
153316

154317

155318
def icdf(rv: TensorVariable, value: TensorLike, warn_rvs=None, **kwargs) -> TensorVariable:
156-
"""Create a graph for the inverse CDF of a Random Variable."""
319+
"""Create a graph for the inverse CDF of a random variable.
320+
321+
Parameters
322+
----------
323+
rv : TensorVariable
324+
value : tensor_like
325+
Should be the same type (shape and dtype) as the rv.
326+
warn_rvs : bool, default True
327+
Warn if RVs were found in the icdf graph.
328+
This can happen when a variable has other random variables as inputs.
329+
In that case, those random variables should be replaced by their respective values.
330+
331+
Returns
332+
-------
333+
icdf : TensorVariable
334+
335+
Raises
336+
------
337+
RuntimeError
338+
If the icdf cannot be derived.
339+
340+
Examples
341+
--------
342+
Create a compiled function that evaluates the icdf of a variable
343+
344+
.. code-block:: python
345+
346+
import pymc as pm
347+
import pytensor.tensor as pt
348+
349+
mu = pt.scalar("mu")
350+
rv = pm.Normal.dist(mu, 1.0)
351+
352+
value = pt.scalar("value")
353+
rv_icdf = pm.icdf(rv, value)
354+
355+
# Use .eval() for debugging
356+
print(rv_icdf.eval({value: 0.9, mu: 0.0})) # 1.28155157
357+
358+
# Compile a function for repeated evaluations
359+
rv_icdf_fn = pm.compile_pymc([value, mu], rv_icdf)
360+
print(rv_icdf_fn(value=0.9, mu=0.0)) # 1.28155157
361+
362+
363+
Derive the graph for a transformation of a RandomVariable
364+
365+
.. code-block:: python
366+
367+
import pymc as pm
368+
import pytensor.tensor as pt
369+
370+
mu = pt.scalar("mu")
371+
rv = pm.Normal.dist(mu, 1.0)
372+
exp_rv = pt.exp(rv)
373+
374+
value = pt.scalar("value")
375+
exp_rv_icdf = pm.icdf(exp_rv, value)
376+
377+
# Use .eval() for debugging
378+
print(exp_rv_icdf.eval({value: 0.9, mu: 0.0})) # 3.60222448
379+
380+
# Compile a function for repeated evaluations
381+
exp_rv_icdf_fn = pm.compile_pymc([value, mu], exp_rv_icdf)
382+
print(exp_rv_icdf_fn(value=0.9, mu=0.0)) # 3.60222448
383+
384+
"""
157385
warn_rvs, kwargs = _deprecate_warn_missing_rvs(warn_rvs, kwargs)
158386
value = pt.as_tensor_variable(value, dtype="floatX")
159387
try:
@@ -208,7 +436,9 @@ def conditional_logp(
208436
If we create a value variable for ``Y_rv``, i.e. ``y_vv = pt.scalar("y")``,
209437
the graph of ``conditional_logp({Y_rv: y_vv})`` is equivalent to the
210438
conditional log-probability :math:`\log p(Y = y \mid \Sigma^2)`, with a stochastic
211-
``sigma2_rv``. If we specify a value variable for ``sigma2_rv``, i.e.
439+
``sigma2_rv``.
440+
441+
If we specify a value variable for ``sigma2_rv``, i.e.
212442
``s_vv = pt.scalar("s2")``, then ``conditional_logp({Y_rv: y_vv, sigma2_rv: s_vv})``
213443
yields the conditional log-probabilities of the two variables.
214444
The sum of the two terms gives their joint log-probability.
@@ -221,11 +451,11 @@ def conditional_logp(
221451
222452
Parameters
223453
----------
224-
rv_values
454+
rv_values: dict
225455
A ``dict`` of variables that maps stochastic elements
226456
(e.g. `RandomVariable`\s) to symbolic `Variable`\s representing their
227457
values in a log-probability.
228-
warn_rvs
458+
warn_rvs : bool, default True
229459
When ``True``, issue a warning when a `RandomVariable` is found in
230460
the logp graph and doesn't have a corresponding value variable specified in
231461
`rv_values`.
@@ -237,8 +467,9 @@ def conditional_logp(
237467
238468
Returns
239469
-------
240-
A ``dict`` that maps each value variable to the conditional log-probability term derived
241-
from the respective `RandomVariable`.
470+
values_to_logps: dict
471+
A ``dict`` that maps each value variable to the conditional log-probability term derived
472+
from the respective `RandomVariable`.
242473
243474
"""
244475
warn_rvs, kwargs = _deprecate_warn_missing_rvs(warn_rvs, kwargs)

0 commit comments

Comments
 (0)