pytensor/link/jax/dispatch/slinalg.py — 4 additions & 10 deletions

@@ -92,7 +92,6 @@ def solve(a, b):
 def jax_funcify_SolveTriangular(op, **kwargs):
     lower = op.lower
     unit_diagonal = op.unit_diagonal
-    check_finite = op.check_finite
 
     def solve_triangular(A, b):
         return jax.scipy.linalg.solve_triangular(
@@ -101,7 +100,7 @@ def solve_triangular(A, b):
             lower=lower,
             trans=0,  # this is handled by explicitly transposing A, so it will always be 0 when we get to here.
             unit_diagonal=unit_diagonal,
-            check_finite=check_finite,
+            check_finite=False,
         )
 
     return solve_triangular
@@ -132,27 +131,23 @@ def pivot_to_permutations(pivots):
 def jax_funcify_LU(op, **kwargs):
     permute_l = op.permute_l
     p_indices = op.p_indices
-    check_finite = op.check_finite
 
     if p_indices:
         raise ValueError("JAX does not support the p_indices argument")
 
     def lu(*inputs):
-        return jax.scipy.linalg.lu(
-            *inputs, permute_l=permute_l, check_finite=check_finite
-        )
+        return jax.scipy.linalg.lu(*inputs, permute_l=permute_l, check_finite=False)
 
     return lu
 
 
 @jax_funcify.register(LUFactor)
 def jax_funcify_LUFactor(op, **kwargs):
-    check_finite = op.check_finite
     overwrite_a = op.overwrite_a
 
     def lu_factor(a):
         return jax.scipy.linalg.lu_factor(
-            a, check_finite=check_finite, overwrite_a=overwrite_a
+            a, check_finite=False, overwrite_a=overwrite_a
         )
 
     return lu_factor
@@ -161,12 +156,11 @@ def lu_factor(a):
 @jax_funcify.register(CholeskySolve)
 def jax_funcify_ChoSolve(op, **kwargs):
     lower = op.lower
-    check_finite = op.check_finite
     overwrite_b = op.overwrite_b
 
     def cho_solve(c, b):
         return jax.scipy.linalg.cho_solve(
-            (c, lower), b, check_finite=check_finite, overwrite_b=overwrite_b
+            (c, lower), b, check_finite=False, overwrite_b=overwrite_b
         )
 
     return cho_solve
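Note on the change above: with check_finite dropped from the Ops, the JAX dispatchers now hard-code check_finite=False in the underlying jax.scipy.linalg calls. A minimal usage sketch of what that means (assumes PyTensor's JAX mode is available; the 2x2 inputs are illustrative only): non-finite values are no longer rejected eagerly, they simply propagate through the solve.

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.slinalg import solve_triangular

A = pt.matrix("A")
b = pt.vector("b")
x = solve_triangular(A, b, lower=True)

# Compile the graph with the JAX backend.
fn = pytensor.function([A, b], x, mode="JAX")

A_val = np.array([[2.0, 0.0], [1.0, 3.0]])
b_val = np.array([np.nan, 1.0])  # non-finite entry is not checked up front
print(fn(A_val, b_val))          # the NaN propagates into the solution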
pytensor/link/numba/dispatch/linalg/_LAPACK.py — 0 additions & 201 deletions

@@ -263,122 +263,6 @@ def potrs(UPLO, N, NRHS, A, LDA, B, LDB, INFO):
 
         return potrs
 
-    @classmethod
-    def numba_xlange(cls, dtype) -> CPUDispatcher:
-        """
-        Compute the value of the 1-norm, Frobenius norm, infinity-norm, or the largest absolute value of any element of
-        a general M-by-N matrix A.
-
-        Called by scipy.linalg.solve, but doesn't correspond to any Op in pytensor.
-        """
-        kind = get_blas_kind(dtype)
-        float_type = _get_nb_float_from_dtype(kind, return_pointer=False)
-        float_pointer = _get_nb_float_from_dtype(kind, return_pointer=True)
-        unique_func_name = f"scipy.lapack.{kind}lange"
-
-        @numba_basic.numba_njit
-        def get_lange_pointer():
-            with numba.objmode(ptr=types.intp):
-                ptr = get_lapack_ptr(dtype, "lange")
-            return ptr
-
-        lange_function_type = types.FunctionType(
-            float_type(
-                nb_i32p,  # NORM
-                nb_i32p,  # M
-                nb_i32p,  # N
-                float_pointer,  # A
-                nb_i32p,  # LDA
-                float_pointer,  # WORK
-            )
-        )
-
-        @numba_basic.numba_njit
-        def lange(NORM, M, N, A, LDA, WORK):
-            fn = _call_cached_ptr(
-                get_ptr_func=get_lange_pointer,
-                func_type_ref=lange_function_type,
-                unique_func_name_lit=unique_func_name,
-            )
-            return fn(NORM, M, N, A, LDA, WORK)
-
-        return lange
-
-    @classmethod
-    def numba_xlamch(cls, dtype) -> CPUDispatcher:
-        """
-        Determine machine precision for floating point arithmetic.
-        """
-        kind = get_blas_kind(dtype)
-        float_type = _get_nb_float_from_dtype(kind, return_pointer=False)
-        unique_func_name = f"scipy.lapack.{kind}lamch"
-
-        @numba_basic.numba_njit
-        def get_lamch_pointer():
-            with numba.objmode(ptr=types.intp):
-                ptr = get_lapack_ptr(dtype, "lamch")
-            return ptr
-
-        lamch_function_type = types.FunctionType(
-            float_type(  # Return type
-                nb_i32p,  # CMACH
-            )
-        )
-
-        @numba_basic.numba_njit
-        def lamch(CMACH):
-            fn = _call_cached_ptr(
-                get_ptr_func=get_lamch_pointer,
-                func_type_ref=lamch_function_type,
-                unique_func_name_lit=unique_func_name,
-            )
-            res = fn(CMACH)
-            return res
-
-        return lamch
-
-    @classmethod
-    def numba_xgecon(cls, dtype) -> CPUDispatcher:
-        """
-        Estimates the condition number of a matrix A, using the LU factorization computed by numba_getrf.
-
-        Called by scipy.linalg.solve when assume_a == "gen"
-        """
-        kind = get_blas_kind(dtype)
-        float_pointer = _get_nb_float_from_dtype(kind)
-        unique_func_name = f"scipy.lapack.{kind}gecon"
-
-        @numba_basic.numba_njit
-        def get_gecon_pointer():
-            with numba.objmode(ptr=types.intp):
-                ptr = get_lapack_ptr(dtype, "gecon")
-            return ptr
-
-        gecon_function_type = types.FunctionType(
-            types.void(
-                nb_i32p,  # NORM
-                nb_i32p,  # N
-                float_pointer,  # A
-                nb_i32p,  # LDA
-                float_pointer,  # ANORM
-                float_pointer,  # RCOND
-                float_pointer,  # WORK
-                nb_i32p,  # IWORK
-                nb_i32p,  # INFO
-            )
-        )
-
-        @numba_basic.numba_njit
-        def gecon(NORM, N, A, LDA, ANORM, RCOND, WORK, IWORK, INFO):
-            fn = _call_cached_ptr(
-                get_ptr_func=get_gecon_pointer,
-                func_type_ref=gecon_function_type,
-                unique_func_name_lit=unique_func_name,
-            )
-            fn(NORM, N, A, LDA, ANORM, RCOND, WORK, IWORK, INFO)
-
-        return gecon
-
     @classmethod
     def numba_xgetrf(cls, dtype) -> CPUDispatcher:
         """
@@ -506,91 +390,6 @@ def sysv(UPLO, N, NRHS, A, LDA, IPIV, B, LDB, WORK, LWORK, INFO):
 
         return sysv
 
-    @classmethod
-    def numba_xsycon(cls, dtype) -> CPUDispatcher:
-        """
-        Estimate the reciprocal of the condition number of a symmetric matrix A using the UDU or LDL factorization
-        computed by xSYTRF.
-        """
-        kind = get_blas_kind(dtype)
-        float_pointer = _get_nb_float_from_dtype(kind)
-        unique_func_name = f"scipy.lapack.{kind}sycon"
-
-        @numba_basic.numba_njit
-        def get_sycon_pointer():
-            with numba.objmode(ptr=types.intp):
-                ptr = get_lapack_ptr(dtype, "sycon")
-            return ptr
-
-        sycon_function_type = types.FunctionType(
-            types.void(
-                nb_i32p,  # UPLO
-                nb_i32p,  # N
-                float_pointer,  # A
-                nb_i32p,  # LDA
-                nb_i32p,  # IPIV
-                float_pointer,  # ANORM
-                float_pointer,  # RCOND
-                float_pointer,  # WORK
-                nb_i32p,  # IWORK
-                nb_i32p,  # INFO
-            )
-        )
-
-        @numba_basic.numba_njit
-        def sycon(UPLO, N, A, LDA, IPIV, ANORM, RCOND, WORK, IWORK, INFO):
-            fn = _call_cached_ptr(
-                get_ptr_func=get_sycon_pointer,
-                func_type_ref=sycon_function_type,
-                unique_func_name_lit=unique_func_name,
-            )
-            fn(UPLO, N, A, LDA, IPIV, ANORM, RCOND, WORK, IWORK, INFO)
-
-        return sycon
-
-    @classmethod
-    def numba_xpocon(cls, dtype) -> CPUDispatcher:
-        """
-        Estimates the reciprocal of the condition number of a positive definite matrix A using the Cholesky factorization
-        computed by potrf.
-
-        Called by scipy.linalg.solve when assume_a == "pos"
-        """
-        kind = get_blas_kind(dtype)
-        float_pointer = _get_nb_float_from_dtype(kind)
-        unique_func_name = f"scipy.lapack.{kind}pocon"
-
-        @numba_basic.numba_njit
-        def get_pocon_pointer():
-            with numba.objmode(ptr=types.intp):
-                ptr = get_lapack_ptr(dtype, "pocon")
-            return ptr
-
-        pocon_function_type = types.FunctionType(
-            types.void(
-                nb_i32p,  # UPLO
-                nb_i32p,  # N
-                float_pointer,  # A
-                nb_i32p,  # LDA
-                float_pointer,  # ANORM
-                float_pointer,  # RCOND
-                float_pointer,  # WORK
-                nb_i32p,  # IWORK
-                nb_i32p,  # INFO
-            )
-        )
-
-        @numba_basic.numba_njit
-        def pocon(UPLO, N, A, LDA, ANORM, RCOND, WORK, IWORK, INFO):
-            fn = _call_cached_ptr(
-                get_ptr_func=get_pocon_pointer,
-                func_type_ref=pocon_function_type,
-                unique_func_name_lit=unique_func_name,
-            )
-            fn(UPLO, N, A, LDA, ANORM, RCOND, WORK, IWORK, INFO)
-
-        return pocon
-
     @classmethod
     def numba_xposv(cls, dtype) -> CPUDispatcher:
         """
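For context, the deleted classmethods (numba_xlange, numba_xlamch, numba_xgecon, numba_xsycon, numba_xpocon) wrapped the LAPACK norm, machine-precision, and reciprocal-condition-number routines that scipy.linalg.solve uses for its ill-conditioning check. A rough sketch of the same quantities computed through SciPy's public LAPACK bindings (illustrative only; this is not the Numba code path being removed, and the random matrix is just an example):

import numpy as np
from scipy.linalg import lapack

rng = np.random.default_rng(0)
A = np.asfortranarray(rng.normal(size=(4, 4)))

anorm = np.linalg.norm(A, 1)            # the 1-norm a *lange('1', ...) call returns
lu, piv, info = lapack.dgetrf(A)        # LU factorization, as required by *gecon
rcond, info = lapack.dgecon(lu, anorm)  # reciprocal condition number estimate
eps = np.finfo(np.float64).eps          # machine precision, the role *lamch played
print(anorm, rcond, rcond < eps)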
pytensor/link/numba/dispatch/linalg/decomposition/cholesky.py — 11 additions & 13 deletions

@@ -12,24 +12,19 @@
 from pytensor.link.numba.dispatch.linalg.utils import _check_linalg_matrix
 
 
-def _cholesky(a, lower=False, overwrite_a=False, check_finite=True):
-    return (
-        linalg.cholesky(
-            a, lower=lower, overwrite_a=overwrite_a, check_finite=check_finite
-        ),
-        0,
-    )
+def _cholesky(a, lower=False, overwrite_a=False):
+    return linalg.cholesky(a, lower=lower, overwrite_a=overwrite_a, check_finite=False)
 
 
 @overload(_cholesky)
-def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True):
+def cholesky_impl(A, lower=0, overwrite_a=False):
     ensure_lapack()
     _check_linalg_matrix(A, ndim=2, dtype=Float, func_name="cholesky")
     dtype = A.dtype
 
     numba_potrf = _LAPACK().numba_xpotrf(dtype)
 
-    def impl(A, lower=False, overwrite_a=False, check_finite=True):
+    def impl(A, lower=False, overwrite_a=False):
         _N = np.int32(A.shape[-1])
         if A.shape[-2] != _N:
             raise linalg.LinAlgError("Last 2 dimensions of A must be square")
@@ -58,6 +53,10 @@ def impl(A, lower=False, overwrite_a=False, check_finite=True):
             INFO,
         )
 
+        if int_ptr_to_val(INFO) != 0:
+            A_copy = np.full_like(A_copy, np.nan)
+            return A_copy
+
         if lower:
             for j in range(1, _N):
                 for i in range(j):
@@ -67,10 +66,9 @@ def impl(A, lower=False, overwrite_a=False, check_finite=True):
                 for i in range(j + 1, _N):
                     A_copy[i, j] = 0.0
 
-        info_int = int_ptr_to_val(INFO)
-
         if transposed:
-            return A_copy.T, info_int
-        return A_copy, info_int
+            return A_copy.T
+        else:
+            return A_copy
 
     return impl
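Behavioural note on the hunks above: a failed dpotrf now surfaces as an all-NaN factor from the Numba kernel instead of a nonzero info code alongside the result. A usage sketch (assumes PyTensor's NUMBA mode and the on_error="nan" option of cholesky; that the dispatch layer returns the NaN factor unchanged in that case is an assumption, and the 2x2 matrix is illustrative):

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.slinalg import cholesky

A = pt.matrix("A")
L = cholesky(A, lower=True, on_error="nan")
fn = pytensor.function([A], L, mode="NUMBA")

not_pd = np.array([[1.0, 2.0], [2.0, 1.0]])  # indefinite, so dpotrf fails
print(fn(not_pd))  # expected: a 2x2 matrix of NaN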