Skip to content

Commit

Permalink
* handle another edge case
Browse files Browse the repository at this point in the history
  • Loading branch information
Joshuaalbert committed Sep 18, 2024
1 parent e0c0835 commit 86eb418
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions jax/_src/scipy/sparse/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,17 +289,17 @@ def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None,
# real-valued positive-definite linear operators are symmetric.
def real_valued(x):
return not issubclass(jnp.result_type(x).type, np.complexfloating)
if callable(A) and x0 is not None:
# we use output dtype as the proxy for dtype.
symmetric = all(map(real_valued,
tree_leaves(jax.eval_shape(A, x0))))
else:
try:
# Prefer to use the dtype of the operator, if available.
symmetric = real_valued(A)
except AttributeError:
# fall back to the RHS.
symmetric = all(map(real_valued, tree_leaves(b)))

try:
# Prefer to use the dtype of the operator, if available.
if callable(A) and x0 is not None:
symmetric = all(map(real_valued,
tree_leaves(jax.eval_shape(A, x0))))
else:
symmetric = real_valued(A)
except TypeError:
# fall back to the RHS.
symmetric = all(map(real_valued, tree_leaves(b)))
return _isolve(_cg_solve,
A=A, b=b, x0=x0, tol=tol, atol=atol,
maxiter=maxiter, M=M, symmetric=symmetric)
Expand Down

0 comments on commit 86eb418

Please sign in to comment.