2323from pymc .initial_point import make_initial_point_fn
2424from pymc .model .transform .conditioning import remove_value_transforms
2525from pymc .model .transform .optimization import freeze_dims_and_data
26+ from pymc .pytensorf import join_nonshared_inputs
2627from pymc .sampling .jax import get_jaxified_graph
2728from pymc .util import get_default_varnames
2829from pytensor .tensor import TensorVariable
3233_log = logging .getLogger (__name__ )
3334
3435
35- def get_near_psd (A : np .ndarray ) -> np .ndarray :
36+ def get_nearest_psd (A : np .ndarray ) -> np .ndarray :
3637 """
3738 Compute the nearest positive semi-definite matrix to a given matrix.
3839
39- This function takes a square matrix and returns the nearest positive
40- semi-definite matrix using eigenvalue decomposition. It ensures all
41- eigenvalues are non-negative. The "nearest" matrix is defined in terms
40+ This function takes a square matrix and returns the nearest positive semi-definite matrix using
41+ eigenvalue decomposition. It ensures all eigenvalues are non-negative. The "nearest" matrix is defined in terms
4242 of the Frobenius norm.
4343
4444 Parameters
@@ -58,23 +58,13 @@ def get_near_psd(A: np.ndarray) -> np.ndarray:
5858 return eigvec @ np .diag (eigval ) @ eigvec .T
5959
6060
def _unconstrained_vector_to_constrained_rvs(model):
    """
    Re-express the model's unobserved value variables in terms of one joined input vector.

    Parameters
    ----------
    model : Model
        A PyMC model. Its ``initial_point()``, ``value_vars`` and ``unobserved_value_vars`` are
        handed to ``join_nonshared_inputs``.

    Returns
    -------
    outputs
        First item returned by ``join_nonshared_inputs``: the replacements for
        ``model.unobserved_value_vars``.
    joined_input
        Second item returned by ``join_nonshared_inputs``: the single variable standing in for
        ``model.value_vars``, renamed to ``"unconstrained_vector"``.
    """
    start_point = model.initial_point()
    outputs, joined_input = join_nonshared_inputs(
        point=start_point,
        inputs=model.value_vars,
        outputs=model.unobserved_value_vars,
    )

    # Give the joined input a stable, recognisable name for downstream graph compilation.
    joined_input.name = "unconstrained_vector"

    return outputs, joined_input
7868
7969
8070def _create_transformed_draws (H_inv , slices , out_shapes , posterior_draws , model , chains , draws ):
@@ -94,37 +84,24 @@ def _create_transformed_draws(H_inv, slices, out_shapes, posterior_draws, model,
9484 return f_untransform (posterior_draws )
9585
9686
97- def fit_laplace (
87+ def jax_fit_mvn_to_MAP (
9888 optimized_point : dict [str , np .ndarray ],
9989 model : pm .Model ,
100- chains : int = 2 ,
101- draws : int = 500 ,
10290 on_bad_cov : Literal ["warn" , "error" , "ignore" ] = "ignore" ,
10391 transform_samples : bool = True ,
10492 zero_tol : float = 1e-8 ,
10593 diag_jitter : float | None = 1e-8 ,
106- progressbar : bool = True ,
107- mode : str = "JAX" ,
108- ) -> az .InferenceData :
94+ ) -> tuple [RaveledVars , np .ndarray ]:
10995 """
110- Compute the Laplace approximation of the posterior distribution.
111-
112- The posterior distribution will be approximated as a Gaussian
113- distribution centered at the posterior mode.
114- The covariance is the inverse of the negative Hessian matrix of
115- the log-posterior evaluated at the mode.
96+ Create a multivariate normal distribution using the inverse of the negative Hessian matrix of the log-posterior
97+ evaluated at the MAP estimate. This is the basis of the Laplace approximation.
11698
11799 Parameters
118100 ----------
119101 optimized_point : dict[str, np.ndarray]
120- Local maximum a posteriori (MAP) point returned from pymc.find_MAP
121- or jax_tools.fit_map
102+ Local maximum a posteriori (MAP) point returned from pymc.find_MAP or jax_tools.fit_map
122103 model : Model
123104 A PyMC model
124- chains : int
125- The number of sampling chains running in parallel. Default is 2.
126- draws : int
127- The number of samples to draw from the approximated posterior. Default is 500.
128105 on_bad_cov : str, one of 'ignore', 'warn', or 'error', default: 'ignore'
129106 What to do when ``H_inv`` (inverse Hessian) is not positive semi-definite.
129106 If 'ignore' or 'warn', the closest positive-semi-definite matrix to ``H_inv`` (in Frobenius norm) will be returned.
@@ -137,18 +114,17 @@ def fit_laplace(
137114 diag_jitter: float | None
138115 A small value added to the diagonal of the inverse Hessian matrix to ensure it is positive semi-definite.
139116 If None, no jitter is added. Default is 1e-8.
140- progressbar : bool
141- Whether or not to display progress bar. Default is True.
142- mode : str
143- Computation backend mode. Default is "JAX".
144117
145118 Returns
146119 -------
147- InferenceData
148- arviz.InferenceData object storing posterior, observed_data, and constant_data groups .
120+ map_estimate: RaveledVars
121+ The MAP estimate of the model parameters, raveled into a 1D array .
149122
123+ inverse_hessian: np.ndarray
124+ The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate.
150125 """
151126 frozen_model = freeze_dims_and_data (model )
127+
152128 if not transform_samples :
153129 untransformed_model = remove_value_transforms (frozen_model )
154130 logp = untransformed_model .logp (jacobian = False )
@@ -157,19 +133,17 @@ def fit_laplace(
157133 logp = frozen_model .logp (jacobian = True )
158134 variables = frozen_model .continuous_value_vars
159135
160- mu = np .concatenate (
161- [np .atleast_1d (optimized_point [var .name ]).ravel () for var in variables ], axis = 0
136+ mu = DictToArrayBijection .map (optimized_point )
137+
138+ [neg_logp ], flat_inputs = join_nonshared_inputs (
139+ point = frozen_model .initial_point (), outputs = [- logp ], inputs = variables
162140 )
163141
164142 f_logp , f_grad , f_hess , f_hessp = make_jax_funcs_from_graph (
165- cast (TensorVariable , logp ),
166- use_grad = True ,
167- use_hess = True ,
168- use_hessp = False ,
169- inputs = variables ,
143+ neg_logp , use_grad = True , use_hess = True , use_hessp = False , inputs = [flat_inputs ]
170144 )
171145
172- H = f_hess (mu )
146+ H = - f_hess (mu . data )
173147 H_inv = np .linalg .pinv (np .where (np .abs (H ) < zero_tol , 0 , - H ))
174148
175149 def stabilize (x , jitter ):
@@ -184,73 +158,111 @@ def stabilize(x, jitter):
184158 raise np .linalg .LinAlgError (
185159 "Inverse Hessian not positive-semi definite at the provided point"
186160 )
187- H_inv = get_near_psd (H_inv )
161+ H_inv = get_nearest_psd (H_inv )
188162 if on_bad_cov == "warn" :
189163 _log .warning (
190164 "Inverse Hessian is not positive semi-definite at the provided point, using the closest PSD "
191165 "matrix in L1-norm instead"
192166 )
193167
194- posterior_dist = stats .multivariate_normal (mean = mu , cov = H_inv , allow_singular = True )
168+ return mu , H_inv
169+
170+
171+ def jax_laplace (
172+ mu : RaveledVars ,
173+ H_inv : np .ndarray ,
174+ model : pm .Model ,
175+ chains : int = 2 ,
176+ draws : int = 500 ,
177+ transform_samples : bool = True ,
178+ progressbar : bool = True ,
179+ ) -> az .InferenceData :
180+ """
181+
182+ Parameters
183+ ----------
184+ mu
185+ H_inv
186+ model : Model
187+ A PyMC model
188+ chains : int
189+ The number of sampling chains running in parallel. Default is 2.
190+ draws : int
191+ The number of samples to draw from the approximated posterior. Default is 500.
192+ transform_samples : bool
193+ Whether to transform the samples back to the original parameter space. Default is True.
194+
195+ Returns
196+ -------
197+ idata: az.InferenceData
198+ An InferenceData object containing the approximated posterior samples.
199+ """
200+ posterior_dist = stats .multivariate_normal (mean = mu .data , cov = H_inv , allow_singular = True )
195201 posterior_draws = posterior_dist .rvs (size = (chains , draws ))
196- slices , out_shapes = _get_unravel_rv_info (optimized_point , variables , frozen_model )
197202
198203 if transform_samples :
199- posterior_draws = _create_transformed_draws (
200- H_inv , slices , out_shapes , posterior_draws , frozen_model , chains , draws
201- )
204+ constrained_rvs , unconstrained_vector = _unconstrained_vector_to_constrained_rvs (model )
205+ f_constrain = get_jaxified_graph (inputs = [unconstrained_vector ], outputs = constrained_rvs )
206+
207+ posterior_draws = jax .jit (jax .vmap (jax .vmap (f_constrain )))(posterior_draws )
208+
202209 else :
210+ info = mu .point_map_info
211+ flat_shapes = [np .prod (shape ).astype (int ) for _ , shape , _ in info ]
212+ slices = [
213+ slice (sum (flat_shapes [:i ]), sum (flat_shapes [: i + 1 ])) for i in range (len (flat_shapes ))
214+ ]
215+
203216 posterior_draws = [
204- posterior_draws [..., idx ].reshape ((chains , draws , * out_shapes . get ( rv , ())) )
205- for rv , idx in slices . items ( )
217+ posterior_draws [..., idx ].reshape ((chains , draws , * shape )). astype ( dtype )
218+ for idx , ( name , shape , dtype ) in zip ( slices , info )
206219 ]
207220
208- def make_rv_coords (rv ):
221+ def make_rv_coords (name ):
209222 coords = {"chain" : range (chains ), "draw" : range (draws )}
210- extra_dims = frozen_model .named_vars_to_dims .get (rv . name )
223+ extra_dims = model .named_vars_to_dims .get (name )
211224 if extra_dims is None :
212225 return coords
213- return coords | {dim : list (frozen_model .coords [dim ]) for dim in extra_dims }
226+ return coords | {dim : list (model .coords [dim ]) for dim in extra_dims }
214227
215- def make_rv_dims (rv ):
228+ def make_rv_dims (name ):
216229 dims = ["chain" , "draw" ]
217- extra_dims = frozen_model .named_vars_to_dims .get (rv . name )
230+ extra_dims = model .named_vars_to_dims .get (name )
218231 if extra_dims is None :
219232 return dims
220233 return dims + list (extra_dims )
221234
222235 idata = {
223- rv . name : xr .DataArray (
236+ name : xr .DataArray (
224237 data = draws .squeeze (),
225- coords = make_rv_coords (rv ),
226- dims = make_rv_dims (rv ),
227- name = rv . name ,
238+ coords = make_rv_coords (name ),
239+ dims = make_rv_dims (name ),
240+ name = name ,
228241 )
229- for rv , draws in zip (slices . keys () , posterior_draws )
242+ for ( name , _ , _ ), draws in zip (mu . point_map_info , posterior_draws )
230243 }
231244
232- coords , dims = coords_and_dims_for_inferencedata (frozen_model )
245+ coords , dims = coords_and_dims_for_inferencedata (model )
233246 idata = az .convert_to_inference_data (idata , coords = coords , dims = dims )
234247
235- if frozen_model .deterministics :
248+ if model .deterministics :
236249 idata .posterior = pm .compute_deterministics (
237250 idata .posterior ,
238- model = frozen_model ,
251+ model = model ,
239252 merge_dataset = True ,
240253 progressbar = progressbar ,
241- compile_kwargs = {"mode" : mode },
242254 )
243255
244256 observed_data = dict_to_dataset (
245- find_observations (frozen_model ),
257+ find_observations (model ),
246258 library = pm ,
247259 coords = coords ,
248260 dims = dims ,
249261 default_dims = [],
250262 )
251263
252264 constant_data = dict_to_dataset (
253- find_constants (frozen_model ),
265+ find_constants (model ),
254266 library = pm ,
255267 coords = coords ,
256268 dims = dims ,
@@ -266,6 +278,29 @@ def make_rv_dims(rv):
266278 return idata
267279
268280
def fit_laplace(
    optimized_point: dict[str, np.ndarray],
    model: pm.Model,
    chains: int = 2,
    draws: int = 500,
    on_bad_cov: Literal["warn", "error", "ignore"] = "ignore",
    transform_samples: bool = True,
    zero_tol: float = 1e-8,
    diag_jitter: float | None = 1e-8,
    progressbar: bool = True,
) -> az.InferenceData:
    """
    Compute the Laplace approximation of the posterior distribution.

    This is a two-step convenience wrapper: ``jax_fit_mvn_to_MAP`` builds the mean and inverse
    Hessian of a multivariate normal at the MAP estimate, and ``jax_laplace`` draws samples from
    that normal and packs them into an InferenceData object.

    Parameters
    ----------
    optimized_point : dict[str, np.ndarray]
        Local maximum a posteriori (MAP) point returned from pymc.find_MAP or jax_tools.fit_map
    model : Model
        A PyMC model
    chains : int
        The number of sampling chains running in parallel. Default is 2.
    draws : int
        The number of samples to draw from the approximated posterior. Default is 500.
    on_bad_cov : str, one of 'ignore', 'warn', or 'error', default: 'ignore'
        What to do when the inverse Hessian is not positive semi-definite. Forwarded unchanged to
        ``jax_fit_mvn_to_MAP``.
    transform_samples : bool
        Whether to transform the samples back to the original parameter space. Default is True.
        Forwarded to both ``jax_fit_mvn_to_MAP`` and ``jax_laplace``.
    zero_tol : float
        Entries of the Hessian whose absolute value is below this threshold are set to zero before
        it is inverted. Default is 1e-8.
    diag_jitter : float | None
        A small value added to the diagonal of the inverse Hessian matrix to ensure it is positive
        semi-definite. If None, no jitter is added. Default is 1e-8.
    progressbar : bool
        Whether or not to display progress bar. Default is True.

    Returns
    -------
    idata : az.InferenceData
        The InferenceData object returned by ``jax_laplace``, containing the approximated
        posterior samples.
    """
    # Step 1: Gaussian centred on the MAP point, with covariance taken from the inverse Hessian.
    map_estimate, inverse_hessian = jax_fit_mvn_to_MAP(
        optimized_point=optimized_point,
        model=model,
        on_bad_cov=on_bad_cov,
        transform_samples=transform_samples,
        zero_tol=zero_tol,
        diag_jitter=diag_jitter,
    )

    # Step 2: sample from that Gaussian and assemble the InferenceData.
    return jax_laplace(
        mu=map_estimate,
        H_inv=inverse_hessian,
        model=model,
        chains=chains,
        draws=draws,
        transform_samples=transform_samples,
        progressbar=progressbar,
    )
302+
303+
269304def make_jax_funcs_from_graph (
270305 graph : TensorVariable ,
271306 use_grad : bool ,
@@ -280,34 +315,19 @@ def make_jax_funcs_from_graph(
280315 if not isinstance (inputs , list ):
281316 inputs = [inputs ]
282317
283- f = cast (Callable , get_jaxified_graph (inputs = inputs , outputs = [graph ]))
284- input_shapes = [x .type .shape for x in inputs ]
285-
286- def at_least_tuple (x ):
287- if isinstance (x , tuple | list ):
288- return x
289- return (x ,)
318+ f_tuple = cast (Callable , get_jaxified_graph (inputs = inputs , outputs = [graph ]))
290319
291- assert all ([xi is not None for x in input_shapes for xi in at_least_tuple (x )])
320+ def f (* args , ** kwargs ):
321+ return f_tuple (* args , ** kwargs )[0 ]
292322
293- def f_jax (x ):
294- args = []
295- cursor = 0
296- for shape in input_shapes :
297- n_elements = int (np .prod (shape ))
298- s = slice (cursor , cursor + n_elements )
299- args .append (x [s ].reshape (shape ))
300- cursor += n_elements
301- return f (* args )[0 ]
302-
303- f_logp = jax .jit (f_jax )
323+ f_logp = jax .jit (f )
304324
305325 f_grad = None
306326 f_hess = None
307327 f_hessp = None
308328
309329 if use_grad :
310- _f_grad_jax = jax .grad (f_jax )
330+ _f_grad_jax = jax .grad (f )
311331
312332 def f_grad_jax (x ):
313333 return jax .numpy .stack (_f_grad_jax (x ))
@@ -411,14 +431,12 @@ def find_MAP(
411431 {var_name : value for var_name , value in start_dict .items () if var_name in vars_dict }
412432 )
413433
414- inputs = [frozen_model .values_to_rvs [vars_dict [x ]] for x in start_dict .keys ()]
415- inputs = [frozen_model .rvs_to_values [x ] for x in inputs ]
416-
417- logp_factors = frozen_model .logp (sum = False , jacobian = False )
418- neg_logp = - pt .sum ([pt .sum (factor ) for factor in logp_factors ])
434+ [neg_logp ], inputs = join_nonshared_inputs (
435+ point = start_dict , outputs = [- frozen_model .logp ()], inputs = frozen_model .continuous_value_vars
436+ )
419437
420438 f_logp , f_grad , f_hess , f_hessp = make_jax_funcs_from_graph (
421- neg_logp , use_grad , use_hess , use_hessp , inputs = inputs
439+ neg_logp , use_grad , use_hess , use_hessp , inputs = [ inputs ]
422440 )
423441
424442 args = optimizer_kwargs .pop ("args" , None )
@@ -435,11 +453,12 @@ def find_MAP(
435453 ** optimizer_kwargs ,
436454 )
437455
438- initial_point = RaveledVars (optimizer_result .x , initial_params .point_map_info )
456+ raveled_optimized = RaveledVars (optimizer_result .x , initial_params .point_map_info )
439457 unobserved_vars = get_default_varnames (model .unobserved_value_vars , include_transformed )
440458 unobserved_vars_values = model .compile_fn (unobserved_vars )(
441- DictToArrayBijection .rmap (initial_point , start_dict )
459+ DictToArrayBijection .rmap (raveled_optimized )
442460 )
461+
443462 optimized_point = {
444463 var .name : value for var , value in zip (unobserved_vars , unobserved_vars_values )
445464 }
0 commit comments