|
15 | 15 | import collections |
16 | 16 | import logging |
17 | 17 | import time |
18 | | -import warnings as _warnings |
19 | 18 |
|
20 | 19 | from collections import Counter |
21 | 20 | from collections.abc import Callable, Iterator |
|
40 | 39 | from pymc.model import modelcontext |
41 | 40 | from pymc.model.core import Point |
42 | 41 | from pymc.pytensorf import ( |
43 | | - compile_pymc, |
| 42 | + compile, |
44 | 43 | find_rng_nodes, |
45 | 44 | reseed_rngs, |
46 | 45 | ) |
|
76 | 75 | ) |
77 | 76 |
|
78 | 77 | logger = logging.getLogger(__name__) |
79 | | -_warnings.filterwarnings( |
80 | | - "ignore", category=FutureWarning, message="compile_pymc was renamed to compile" |
81 | | -) |
82 | 78 |
|
83 | 79 | REGULARISATION_TERM = 1e-8 |
84 | 80 | DEFAULT_LINKER = "cvm_nogc" |
@@ -142,7 +138,7 @@ def get_logp_dlogp_of_ravel_inputs( |
142 | 138 | [model.logp(jacobian=jacobian), model.dlogp(jacobian=jacobian)], |
143 | 139 | model.value_vars, |
144 | 140 | ) |
145 | | - logp_dlogp_fn = compile_pymc([inputs], (logP, dlogP), **compile_kwargs) |
| 141 | + logp_dlogp_fn = compile([inputs], (logP, dlogP), **compile_kwargs) |
146 | 142 | logp_dlogp_fn.trust_input = True |
147 | 143 |
|
148 | 144 | return logp_dlogp_fn |
@@ -502,9 +498,10 @@ def bfgs_sample_dense( |
502 | 498 |
|
503 | 499 | logdet = 2.0 * pt.sum(pt.log(pt.abs(pt.diagonal(Lchol, axis1=-2, axis2=-1))), axis=-1) |
504 | 500 |
|
505 | | - with _warnings.catch_warnings(): |
506 | | - _warnings.simplefilter("ignore", category=FutureWarning) |
507 | | - mu = x - pt.batched_dot(H_inv, g) |
| 501 | + # mu = x - pt.einsum("ijk,ik->ij", H_inv, g) # causes error: Multiple destroyers of g |
| 502 | + |
| 503 | + batched_dot = pt.vectorize(pt.dot, signature="(ijk),(ilk)->(ij)") |
| 504 | + mu = x - batched_dot(H_inv, pt.matrix_transpose(g[..., None])) |
508 | 505 |
|
509 | 506 | phi = pt.matrix_transpose( |
510 | 507 | # (L, N, 1) |
@@ -573,17 +570,16 @@ def bfgs_sample_sparse( |
573 | 570 | logdet = 2.0 * pt.sum(pt.log(pt.abs(pt.diagonal(Lchol, axis1=-2, axis2=-1))), axis=-1) |
574 | 571 | logdet += pt.sum(pt.log(alpha), axis=-1) |
575 | 572 |
|
| 573 | + # inverse Hessian |
| 574 | + # (L, N, N) + (L, N, 2J), (L, 2J, 2J), (L, 2J, N) -> (L, N, N) |
| 575 | + H_inv = alpha_diag + (beta @ gamma @ pt.matrix_transpose(beta)) |
| 576 | + |
576 | 577 | # NOTE: changed the sign from "x + " to "x -" of the expression to match Stan which differs from Zhang et al., (2022). same for dense version. |
577 | | - with _warnings.catch_warnings(): |
578 | | - _warnings.simplefilter("ignore", category=FutureWarning) |
579 | | - mu = x - ( |
580 | | - # (L, N), (L, N) -> (L, N) |
581 | | - pt.batched_dot(alpha_diag, g) |
582 | | - # beta @ gamma @ beta.T |
583 | | - # (L, N, 2J), (L, 2J, 2J), (L, 2J, N) -> (L, N, N) |
584 | | - # (L, N, N), (L, N) -> (L, N) |
585 | | - + pt.batched_dot((beta @ gamma @ pt.matrix_transpose(beta)), g) |
586 | | - ) |
| 578 | + |
| 579 | + # mu = x - pt.einsum("ijk,ik->ij", H_inv, g) # causes error: Multiple destroyers of g |
| 580 | + |
| 581 | + batched_dot = pt.vectorize(pt.dot, signature="(ijk),(ilk)->(ij)") |
| 582 | + mu = x - batched_dot(H_inv, pt.matrix_transpose(g[..., None])) |
587 | 583 |
|
588 | 584 | phi = pt.matrix_transpose( |
589 | 585 | # (L, N, 1) |
@@ -857,7 +853,7 @@ def make_pathfinder_body( |
857 | 853 |
|
858 | 854 | # return psi, logP_psi, logQ_psi, elbo_argmax |
859 | 855 |
|
860 | | - pathfinder_body_fn = compile_pymc( |
| 856 | + pathfinder_body_fn = compile( |
861 | 857 | [x_full, g_full], |
862 | 858 | [psi, logP_psi, logQ_psi, elbo_argmax], |
863 | 859 | **compile_kwargs, |
|
0 commit comments