 import logging

 from collections.abc import Callable
+from importlib.util import find_spec
 from typing import Literal, cast, get_args

-import jax
 import numpy as np
 import pymc as pm
 import pytensor
 def set_optimizer_function_defaults(method, use_grad, use_hess, use_hessp):
     method_info = MINIMIZE_MODE_KWARGS[method].copy()

-    use_grad = use_grad if use_grad is not None else method_info["uses_grad"]
-    use_hess = use_hess if use_hess is not None else method_info["uses_hess"]
-    use_hessp = use_hessp if use_hessp is not None else method_info["uses_hessp"]
-
     if use_hess and use_hessp:
+        _log.warning(
+            'Both "use_hess" and "use_hessp" are set to True, but scipy.optimize.minimize never uses both at the '
+            'same time. When possible "use_hessp" is preferred because it is computationally more efficient. '
+            'Setting "use_hess" to False.'
+        )
         use_hess = False

+    use_grad = use_grad if use_grad is not None else method_info["uses_grad"]
+
+    if use_hessp is not None and use_hess is None:
+        use_hess = not use_hessp
+
+    elif use_hess is not None and use_hessp is None:
+        use_hessp = not use_hess
+
+    elif use_hessp is None and use_hess is None:
+        use_hessp = method_info["uses_hessp"]
+        use_hess = method_info["uses_hess"]
+        if use_hessp and use_hess:
+            # If a method could use either hess or hessp, we default to using hessp
+            use_hess = False
+
     return use_grad, use_hess, use_hessp


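For reference, a minimal sketch of the default-resolution logic added above, with a stand-in `method_info` dict (the real values come from `MINIMIZE_MODE_KWARGS`, which is not shown in this diff). It illustrates that a method supporting both `hess` and `hessp` falls back to `hessp` when the user specifies neither, and respects an explicit choice otherwise:

```python
# Hypothetical entry; the real table lives in MINIMIZE_MODE_KWARGS.
method_info = {"uses_grad": True, "uses_hess": True, "uses_hessp": True}


def resolve_defaults(use_grad=None, use_hess=None, use_hessp=None):
    # User explicitly asked for both -> drop hess, keep the cheaper hessp.
    if use_hess and use_hessp:
        use_hess = False

    use_grad = use_grad if use_grad is not None else method_info["uses_grad"]

    if use_hessp is not None and use_hess is None:
        use_hess = not use_hessp
    elif use_hess is not None and use_hessp is None:
        use_hessp = not use_hess
    elif use_hessp is None and use_hess is None:
        use_hessp = method_info["uses_hessp"]
        use_hess = method_info["uses_hess"]
        if use_hessp and use_hess:
            use_hess = False  # prefer the hessian-vector product

    return use_grad, use_hess, use_hessp


print(resolve_defaults())               # (True, False, True)
print(resolve_defaults(use_hess=True))  # (True, True, False)
```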
@@ -59,7 +75,7 @@ def get_nearest_psd(A: np.ndarray) -> np.ndarray:
         The nearest positive semi-definite matrix to the input matrix.
     """
     C = (A + A.T) / 2
-    eigval, eigvec = np.linalg.eig(C)
+    eigval, eigvec = np.linalg.eigh(C)
     eigval[eigval < 0] = 0

     return eigvec @ np.diag(eigval) @ eigvec.T
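As a standalone illustration of the `eig` → `eigh` change above: for the symmetrized matrix `C`, `np.linalg.eigh` returns real eigenvalues and eigenvectors (and is generally faster for symmetric input), whereas `np.linalg.eig` can return complex output with spurious imaginary parts. A self-contained sketch of the same projection:

```python
import numpy as np


def nearest_psd(A: np.ndarray) -> np.ndarray:
    # Symmetrize, then clip negative eigenvalues of the symmetric part to zero
    C = (A + A.T) / 2
    eigval, eigvec = np.linalg.eigh(C)  # real-valued output for symmetric input
    eigval[eigval < 0] = 0
    return eigvec @ np.diag(eigval) @ eigvec.T


A = np.array([[1.0, 2.0], [0.0, -1.0]])
print(np.linalg.eigvalsh(nearest_psd(A)))  # no negative eigenvalues beyond float noise
```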
@@ -97,7 +113,7 @@ def _create_transformed_draws(H_inv, slices, out_shapes, posterior_draws, model,
     return f_untransform(posterior_draws)


-def _compile_jax_gradients(
+def _compile_grad_and_hess_to_jax(
     f_loss: Function, use_hess: bool, use_hessp: bool
 ) -> tuple[Callable | None, Callable | None]:
     """
@@ -122,6 +138,8 @@ def _compile_jax_gradients(
     f_hessp: Callable | None
         The compiled hessian-vector product function, or None if use_hessp is False.
     """
+    import jax
+
     f_hess = None
     f_hessp = None

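With `import jax` moved inside the function, the module imports cleanly when JAX is not installed; JAX is only required when the "jax" gradient backend is actually requested. A rough sketch of the wrapping idea, with a plain Python callable standing in for the compiled loss (the function name and the hessian-vector-product formulation here are assumptions for illustration, not necessarily what the module does):

```python
def wrap_loss_with_jax(f_loss, use_hessp: bool):
    import jax  # deferred import: only needed for the "jax" gradient backend

    # Single function returning (loss, gradient), as scipy expects with jac=True
    f_loss_and_grad = jax.value_and_grad(f_loss)

    f_hessp = None
    if use_hessp:

        def f_hessp(x, p):
            # Hessian-vector product as the directional derivative of the gradient
            _, hvp = jax.jvp(jax.grad(f_loss), (x,), (p,))
            return hvp

    return f_loss_and_grad, f_hessp
```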
@@ -152,7 +170,7 @@ def f_hess_jax(x):
     return f_loss_and_grad, f_hess, f_hessp


-def _compile_functions(
+def _compile_functions_for_scipy_optimize(
     loss: TensorVariable,
     inputs: list[TensorVariable],
     compute_grad: bool,
@@ -177,7 +195,7 @@ def _compile_functions(
     compute_hessp: bool
         Whether to compile a function that computes the Hessian-vector product of the loss function.
     compile_kwargs: dict, optional
-        Additional keyword arguments to pass to the ``pm.compile_pymc`` function.
+        Additional keyword arguments to pass to the ``pm.compile`` function.

     Returns
     -------
@@ -193,19 +211,19 @@ def _compile_functions(
     if compute_grad:
         grads = pytensor.gradient.grad(loss, inputs)
         grad = pt.concatenate([grad.ravel() for grad in grads])
-        f_loss_and_grad = pm.compile_pymc(inputs, [loss, grad], **compile_kwargs)
+        f_loss_and_grad = pm.compile(inputs, [loss, grad], **compile_kwargs)
     else:
-        f_loss = pm.compile_pymc(inputs, loss, **compile_kwargs)
+        f_loss = pm.compile(inputs, loss, **compile_kwargs)
         return [f_loss]

     if compute_hess:
         hess = pytensor.gradient.jacobian(grad, inputs)[0]
-        f_hess = pm.compile_pymc(inputs, hess, **compile_kwargs)
+        f_hess = pm.compile(inputs, hess, **compile_kwargs)

     if compute_hessp:
         p = pt.tensor("p", shape=inputs[0].type.shape)
         hessp = pytensor.gradient.hessian_vector_product(loss, inputs, p)
-        f_hessp = pm.compile_pymc([*inputs, p], hessp[0], **compile_kwargs)
+        f_hessp = pm.compile([*inputs, p], hessp[0], **compile_kwargs)

     return [f_loss_and_grad, f_hess, f_hessp]

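A toy, self-contained version of the flat-gradient pattern in the hunk above, using `pytensor.function` directly instead of `pm.compile` so it runs without a PyMC model (an assumed simplification):

```python
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
loss = (x**2).sum()

# Gradient w.r.t. each input, flattened and concatenated into one vector for scipy
grads = pytensor.gradient.grad(loss, [x])
grad = pt.concatenate([g.ravel() for g in grads])

f_loss_and_grad = pytensor.function([x], [loss, grad])
print(f_loss_and_grad([1.0, 2.0]))  # [array(5.), array([2., 4.])]
```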
@@ -240,7 +258,7 @@ def scipy_optimize_funcs_from_loss(
     gradient_backend: str, default "pytensor"
         Which backend to use to compute gradients. Must be one of "jax" or "pytensor"
     compile_kwargs:
-        Additional keyword arguments to pass to the ``pm.compile_pymc`` function.
+        Additional keyword arguments to pass to the ``pm.compile`` function.

     Returns
     -------
@@ -265,6 +283,8 @@ def scipy_optimize_funcs_from_loss(
         )

     use_jax_gradients = (gradient_backend == "jax") and use_grad
+    if use_jax_gradients and not find_spec("jax"):
+        raise ImportError("JAX must be installed to use JAX gradients")

     mode = compile_kwargs.get("mode", None)
     if mode is None and use_jax_gradients:
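The new guard uses `importlib.util.find_spec` to fail early with a clear error rather than hitting a `ModuleNotFoundError` later inside compilation. The same pattern in isolation (the function name here is illustrative only):

```python
from importlib.util import find_spec


def check_jax_available(gradient_backend: str, use_grad: bool) -> bool:
    use_jax_gradients = (gradient_backend == "jax") and use_grad
    if use_jax_gradients and not find_spec("jax"):
        raise ImportError("JAX must be installed to use JAX gradients")
    return use_jax_gradients
```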
@@ -285,7 +305,7 @@ def scipy_optimize_funcs_from_loss(
     compute_hess = use_hess and not use_jax_gradients
     compute_hessp = use_hessp and not use_jax_gradients

-    funcs = _compile_functions(
+    funcs = _compile_functions_for_scipy_optimize(
         loss=loss,
         inputs=[flat_input],
         compute_grad=compute_grad,
@@ -301,7 +321,7 @@ def scipy_optimize_funcs_from_loss(

     if use_jax_gradients:
         # f_loss here is f_loss_and_grad; the name is unchanged to simplify the return values
-        f_loss, f_hess, f_hessp = _compile_jax_gradients(f_loss, use_hess, use_hessp)
+        f_loss, f_hess, f_hessp = _compile_grad_and_hess_to_jax(f_loss, use_hess, use_hessp)

     return f_loss, f_hess, f_hessp

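For context, a hedged sketch of how callables of this shape are consumed by `scipy.optimize.minimize`: when the compiled function returns a `(loss, gradient)` pair, it is passed with `jac=True` (the wiring below is a toy stand-in, not the module's actual call site):

```python
import numpy as np
from scipy.optimize import minimize


def f_loss_and_grad(x):
    # Toy stand-in for the compiled (loss, gradient) function
    return float(np.sum(x**2)), 2 * x


res = minimize(f_loss_and_grad, x0=np.ones(3), jac=True, method="L-BFGS-B")
print(res.x)  # converges to ~[0, 0, 0]
```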