@@ -96,8 +96,8 @@ def _warn_rvs_in_inferred_graph(graph: Union[TensorVariable, Sequence[TensorVari
96
96
warnings .warn (
97
97
f"RandomVariables { rvs_in_graph } were found in the derived graph. "
98
98
"These variables are a clone and do not match the original ones on identity.\n "
99
- "If you are deriving a quantity that depends on model RVs, use `model.replace_rvs_by_values` first. For example: "
100
- "`logp(model.replace_rvs_by_values([rv])[0], value)`" ,
99
+ "If you are deriving a quantity that depends on model RVs, use `model.replace_rvs_by_values` first. "
100
+ "For example: `logp(model.replace_rvs_by_values([rv])[0], value)`" ,
101
101
stacklevel = 3 ,
102
102
)
103
103
@@ -119,7 +119,89 @@ def _deprecate_warn_missing_rvs(warn_rvs, kwargs):
119
119
120
120
121
121
def logp (rv : TensorVariable , value : TensorLike , warn_rvs = None , ** kwargs ) -> TensorVariable :
122
- """Return the log-probability graph of a Random Variable"""
122
+ """Create a graph for the log-probability of a random variable.
123
+
124
+ Parameters
125
+ ----------
126
+ rv : TensorVariable
127
+ value : tensor_like
128
+ Should be the same type (shape and dtype) as the rv.
129
+ warn_rvs : bool, default True
130
+ Warn if RVs were found in the logp graph.
131
+ This can happen when a variable has other other random variables as inputs.
132
+ In that case, those random variables should be replaced by their respective values.
133
+ `pymc.logprob.conditional_logp` can also be used as an alternative.
134
+
135
+ Returns
136
+ -------
137
+ logp : TensorVariable
138
+
139
+ Raises
140
+ ------
141
+ RuntimeError
142
+ If the logp cannot be derived.
143
+
144
+ Examples
145
+ --------
146
+ Create a compiled function that evaluates the logp of a variable
147
+
148
+ .. code-block:: python
149
+
150
+ import pymc as pm
151
+ import pytensor.tensor as pt
152
+
153
+ mu = pt.scalar("mu")
154
+ rv = pm.Normal.dist(mu, 1.0)
155
+
156
+ value = pt.scalar("value")
157
+ rv_logp = pm.logp(rv, value)
158
+
159
+ # Use .eval() for debugging
160
+ print(rv_logp.eval({value: 0.9, mu: 0.0})) # -1.32393853
161
+
162
+ # Compile a function for repeated evaluations
163
+ rv_logp_fn = pm.compile_pymc([value, mu], rv_logp)
164
+ print(rv_logp_fn(value=0.9, mu=0.0)) # -1.32393853
165
+
166
+
167
+ Derive the graph for a transformation of a RandomVariable
168
+
169
+ .. code-block:: python
170
+
171
+ import pymc as pm
172
+ import pytensor.tensor as pt
173
+
174
+ mu = pt.scalar("mu")
175
+ rv = pm.Normal.dist(mu, 1.0)
176
+ exp_rv = pt.exp(rv)
177
+
178
+ value = pt.scalar("value")
179
+ exp_rv_logp = pm.logp(exp_rv, value)
180
+
181
+ # Use .eval() for debugging
182
+ print(exp_rv_logp.eval({value: 0.9, mu: 0.0})) # -0.81912844
183
+
184
+ # Compile a function for repeated evaluations
185
+ exp_rv_logp_fn = pm.compile_pymc([value, mu], exp_rv_logp)
186
+ print(exp_rv_logp_fn(value=0.9, mu=0.0)) # -0.81912844
187
+
188
+
189
+ Define a CustomDist logp
190
+
191
+ .. code-block:: python
192
+
193
+ import pymc as pm
194
+ import pytensor.tensor as pt
195
+
196
+ def normal_logp(value, mu, sigma):
197
+ return pm.logp(pm.Normal.dist(mu, sigma), value)
198
+
199
+ with pm.Model() as model:
200
+ mu = pm.Normal("mu")
201
+ sigma = pm.HalfNormal("sigma")
202
+ pm.CustomDist("x", mu, sigma, logp=normal_logp)
203
+
204
+ """
123
205
warn_rvs , kwargs = _deprecate_warn_missing_rvs (warn_rvs , kwargs )
124
206
125
207
value = pt .as_tensor_variable (value , dtype = rv .dtype )
@@ -136,7 +218,88 @@ def logp(rv: TensorVariable, value: TensorLike, warn_rvs=None, **kwargs) -> Tens
136
218
137
219
138
220
def logcdf (rv : TensorVariable , value : TensorLike , warn_rvs = None , ** kwargs ) -> TensorVariable :
139
- """Create a graph for the log-CDF of a Random Variable."""
221
+ """Create a graph for the log-CDF of a random variable.
222
+
223
+ Parameters
224
+ ----------
225
+ rv : TensorVariable
226
+ value : tensor_like
227
+ Should be the same type (shape and dtype) as the rv.
228
+ warn_rvs : bool, default True
229
+ Warn if RVs were found in the logcdf graph.
230
+ This can happen when a variable has other random variables as inputs.
231
+ In that case, those random variables should be replaced by their respective values.
232
+
233
+ Returns
234
+ -------
235
+ logp : TensorVariable
236
+
237
+ Raises
238
+ ------
239
+ RuntimeError
240
+ If the logcdf cannot be derived.
241
+
242
+ Examples
243
+ --------
244
+ Create a compiled function that evaluates the logcdf of a variable
245
+
246
+ .. code-block:: python
247
+
248
+ import pymc as pm
249
+ import pytensor.tensor as pt
250
+
251
+ mu = pt.scalar("mu")
252
+ rv = pm.Normal.dist(mu, 1.0)
253
+
254
+ value = pt.scalar("value")
255
+ rv_logcdf = pm.logcdf(rv, value)
256
+
257
+ # Use .eval() for debugging
258
+ print(rv_logcdf.eval({value: 0.9, mu: 0.0})) # -0.2034146
259
+
260
+ # Compile a function for repeated evaluations
261
+ rv_logcdf_fn = pm.compile_pymc([value, mu], rv_logcdf)
262
+ print(rv_logcdf_fn(value=0.9, mu=0.0)) # -0.2034146
263
+
264
+
265
+ Derive the graph for a transformation of a RandomVariable
266
+
267
+ .. code-block:: python
268
+
269
+ import pymc as pm
270
+ import pytensor.tensor as pt
271
+
272
+ mu = pt.scalar("mu")
273
+ rv = pm.Normal.dist(mu, 1.0)
274
+ exp_rv = pt.exp(rv)
275
+
276
+ value = pt.scalar("value")
277
+ exp_rv_logcdf = pm.logcdf(exp_rv, value)
278
+
279
+ # Use .eval() for debugging
280
+ print(exp_rv_logcdf.eval({value: 0.9, mu: 0.0})) # -0.78078813
281
+
282
+ # Compile a function for repeated evaluations
283
+ exp_rv_logcdf_fn = pm.compile_pymc([value, mu], exp_rv_logcdf)
284
+ print(exp_rv_logcdf_fn(value=0.9, mu=0.0)) # -0.78078813
285
+
286
+
287
+ Define a CustomDist logcdf
288
+
289
+ .. code-block:: python
290
+
291
+ import pymc as pm
292
+ import pytensor.tensor as pt
293
+
294
+ def normal_logcdf(value, mu, sigma):
295
+ return pm.logp(pm.Normal.dist(mu, sigma), value)
296
+
297
+ with pm.Model() as model:
298
+ mu = pm.Normal("mu")
299
+ sigma = pm.HalfNormal("sigma")
300
+ pm.CustomDist("x", mu, sigma, logcdf=normal_logcdf)
301
+
302
+ """
140
303
warn_rvs , kwargs = _deprecate_warn_missing_rvs (warn_rvs , kwargs )
141
304
value = pt .as_tensor_variable (value , dtype = rv .dtype )
142
305
try :
@@ -153,7 +316,72 @@ def logcdf(rv: TensorVariable, value: TensorLike, warn_rvs=None, **kwargs) -> Te
153
316
154
317
155
318
def icdf (rv : TensorVariable , value : TensorLike , warn_rvs = None , ** kwargs ) -> TensorVariable :
156
- """Create a graph for the inverse CDF of a Random Variable."""
319
+ """Create a graph for the inverse CDF of a random variable.
320
+
321
+ Parameters
322
+ ----------
323
+ rv : TensorVariable
324
+ value : tensor_like
325
+ Should be the same type (shape and dtype) as the rv.
326
+ warn_rvs : bool, default True
327
+ Warn if RVs were found in the icdf graph.
328
+ This can happen when a variable has other random variables as inputs.
329
+ In that case, those random variables should be replaced by their respective values.
330
+
331
+ Returns
332
+ -------
333
+ icdf : TensorVariable
334
+
335
+ Raises
336
+ ------
337
+ RuntimeError
338
+ If the icdf cannot be derived.
339
+
340
+ Examples
341
+ --------
342
+ Create a compiled function that evaluates the icdf of a variable
343
+
344
+ .. code-block:: python
345
+
346
+ import pymc as pm
347
+ import pytensor.tensor as pt
348
+
349
+ mu = pt.scalar("mu")
350
+ rv = pm.Normal.dist(mu, 1.0)
351
+
352
+ value = pt.scalar("value")
353
+ rv_icdf = pm.icdf(rv, value)
354
+
355
+ # Use .eval() for debugging
356
+ print(rv_icdf.eval({value: 0.9, mu: 0.0})) # 1.28155157
357
+
358
+ # Compile a function for repeated evaluations
359
+ rv_icdf_fn = pm.compile_pymc([value, mu], rv_icdf)
360
+ print(rv_icdf_fn(value=0.9, mu=0.0)) # 1.28155157
361
+
362
+
363
+ Derive the graph for a transformation of a RandomVariable
364
+
365
+ .. code-block:: python
366
+
367
+ import pymc as pm
368
+ import pytensor.tensor as pt
369
+
370
+ mu = pt.scalar("mu")
371
+ rv = pm.Normal.dist(mu, 1.0)
372
+ exp_rv = pt.exp(rv)
373
+
374
+ value = pt.scalar("value")
375
+ exp_rv_icdf = pm.icdf(exp_rv, value)
376
+
377
+ # Use .eval() for debugging
378
+ print(exp_rv_icdf.eval({value: 0.9, mu: 0.0})) # 3.60222448
379
+
380
+ # Compile a function for repeated evaluations
381
+ exp_rv_icdf_fn = pm.compile_pymc([value, mu], exp_rv_icdf)
382
+ print(exp_rv_icdf_fn(value=0.9, mu=0.0)) # 3.60222448
383
+
384
+ """
157
385
warn_rvs , kwargs = _deprecate_warn_missing_rvs (warn_rvs , kwargs )
158
386
value = pt .as_tensor_variable (value , dtype = "floatX" )
159
387
try :
@@ -208,7 +436,9 @@ def conditional_logp(
208
436
If we create a value variable for ``Y_rv``, i.e. ``y_vv = pt.scalar("y")``,
209
437
the graph of ``conditional_logp({Y_rv: y_vv})`` is equivalent to the
210
438
conditional log-probability :math:`\log p(Y = y \mid \Sigma^2)`, with a stochastic
211
- ``sigma2_rv``. If we specify a value variable for ``sigma2_rv``, i.e.
439
+ ``sigma2_rv``.
440
+
441
+ If we specify a value variable for ``sigma2_rv``, i.e.
212
442
``s_vv = pt.scalar("s2")``, then ``conditional_logp({Y_rv: y_vv, sigma2_rv: s_vv})``
213
443
yields the conditional log-probabilities of the two variables.
214
444
The sum of the two terms gives their joint log-probability.
@@ -221,11 +451,11 @@ def conditional_logp(
221
451
222
452
Parameters
223
453
----------
224
- rv_values
454
+ rv_values: dict
225
455
A ``dict`` of variables that maps stochastic elements
226
456
(e.g. `RandomVariable`\s) to symbolic `Variable`\s representing their
227
457
values in a log-probability.
228
- warn_rvs
458
+ warn_rvs : bool, default True
229
459
When ``True``, issue a warning when a `RandomVariable` is found in
230
460
the logp graph and doesn't have a corresponding value variable specified in
231
461
`rv_values`.
@@ -237,8 +467,9 @@ def conditional_logp(
237
467
238
468
Returns
239
469
-------
240
- A ``dict`` that maps each value variable to the conditional log-probability term derived
241
- from the respective `RandomVariable`.
470
+ values_to_logps: dict
471
+ A ``dict`` that maps each value variable to the conditional log-probability term derived
472
+ from the respective `RandomVariable`.
242
473
243
474
"""
244
475
warn_rvs , kwargs = _deprecate_warn_missing_rvs (warn_rvs , kwargs )
0 commit comments