Instead of taking an explicit `symmetric` argument from the user, or using some auxiliary information to determine the dtype of the linear operator, the solver uses the RHS's dtype as a proxy for the dtype of the linear operator when `check_symmetric=True`. However, the user may well expect implicit up-casting: the linear operator may be real while the RHS is complex, or the linear operator may be symmetric but complex. In both cases `matvec == vecmat` for positive-definite linear operators, yet a complex RHS makes the code infer `symmetric=False`. This could be a problem for some users of `jsp.solve_cg`.
```python
import numpy as np
from jax import lax
from jax.tree_util import tree_leaves

# real-valued positive-definite linear operators are symmetric
def real_valued(x):
  return not issubclass(x.dtype.type, np.complexfloating)

symmetric = all(map(real_valued, tree_leaves(b))) \
    if check_symmetric else False
x = lax.custom_linear_solve(
    A, b, solve=isolve_solve, transpose_solve=isolve_solve,
    symmetric=symmetric)
```
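A small illustration of the mismatch, assuming JAX's default 32-bit dtypes: `inferred_symmetric` and `vecmat_equals_matvec` are hypothetical helpers (not JAX APIs); the first re-implements the dtype rule from the excerpt above, and the second uses `jax.linear_transpose` to compute `vecmat`. That is the plain transpose, not the conjugate transpose, which is the sense of "symmetric" (`A == A^T`) used by `lax.custom_linear_solve`. The matrices are arbitrary illustrative values.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.tree_util import tree_leaves


def real_valued(x):
    # Same dtype test as in the excerpt: only the RHS dtype is inspected.
    return not issubclass(x.dtype.type, np.complexfloating)


def inferred_symmetric(b, check_symmetric=True):
    # Hypothetical helper reproducing the inference rule; not a JAX API.
    return all(map(real_valued, tree_leaves(b))) if check_symmetric else False


def vecmat_equals_matvec(matrix, v):
    # Hypothetical helper: compare the operator with its linear
    # (non-conjugated) transpose, evaluated on the vector v.
    matvec = lambda u: matrix @ u
    vecmat = jax.linear_transpose(matvec, v)
    return bool(jnp.allclose(vecmat(v)[0], matvec(v)))


A = jnp.array([[4.0, 1.0], [1.0, 3.0]])        # real symmetric positive definite
C = jnp.array([[2.0 + 1.0j, 1.0 - 1.0j],
               [1.0 - 1.0j, 3.0 + 0.5j]])       # complex symmetric: C == C.T
b_real = jnp.array([1.0, 2.0])
b_complex = jnp.array([1.0 + 1.0j, 2.0 - 1.0j])

print(inferred_symmetric(b_real))              # True
print(inferred_symmetric(b_complex))           # False, although A is unchanged
print(vecmat_equals_matvec(A, b_complex))      # True: real A on complex vectors
print(vecmat_equals_matvec(C, b_complex))      # True: complex symmetric C
```

The operator is identical in the first two calls; only the RHS dtype changes, and with it the inferred flag.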
System info (python version, jaxlib version, accelerator, etc.)
jax==0.4.31
jaxlib==0.4.31
Thanks for the report! So the tricky thing here is that jax.scipy.sparse.linalg.cg implements the API of scipy.sparse.linalg.cg, which has no explicit symmetric flag. We could probably add an optional symmetric argument to JAX's version that lets the user override the default. What do you think?
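One way such an override might look, sketched on top of the public `lax.custom_linear_solve` and `jax.scipy.sparse.linalg.cg`: the name `cg_with_symmetric` and its keyword-only `symmetric` flag are hypothetical (nothing like this exists in jax 0.4.31), the operator and right-hand sides are illustrative, and JAX's default 32-bit dtypes are assumed.

```python
import jax
import jax.numpy as jnp
from jax import lax
from jax.scipy.sparse.linalg import cg


def cg_with_symmetric(matvec, b, *, symmetric=False, tol=1e-5):
    """Hypothetical CG wrapper where the caller, not b.dtype, sets `symmetric`."""

    def solve(mv, rhs):
        # Inner solve reuses JAX's existing CG; the info output is discarded.
        x, _ = cg(mv, rhs, tol=tol)
        return x

    # With symmetric=True, custom_linear_solve treats matvec as its own
    # transpose, so reverse-mode AD reuses `solve` for the transposed system.
    return lax.custom_linear_solve(
        matvec, b, solve=solve, transpose_solve=solve, symmetric=symmetric)


A = jnp.array([[4.0, 1.0], [1.0, 3.0]])   # real symmetric positive definite
matvec = lambda v: A @ v

# Real operator, complex right-hand side: the user asserts symmetry explicitly.
b = jnp.array([1.0 + 1.0j, 2.0 - 1.0j])
x = cg_with_symmetric(matvec, b, symmetric=True)
print(jnp.allclose(A @ x, b, atol=1e-4))

# Reverse-mode gradient of sum(x) with respect to a real right-hand side.
b_real = jnp.array([1.0, 2.0])
grad = jax.grad(
    lambda rhs: jnp.sum(cg_with_symmetric(matvec, rhs, symmetric=True)))(b_real)
print(grad)   # mathematically A^{-1} @ ones = [2/11, 3/11]
```

Keeping `symmetric=False` as the default would preserve the current scipy-compatible behaviour, while letting users who know their operator opt in.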