Skip to content

Add Ops for LU Factorization #1218

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions pytensor/link/jax/dispatch/slinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@

from pytensor.link.jax.dispatch.basic import jax_funcify
from pytensor.tensor.slinalg import (
LU,
BlockDiagonal,
Cholesky,
Eigvalsh,
LUFactor,
PivotToPermutations,
Solve,
SolveTriangular,
)
Expand Down Expand Up @@ -93,3 +96,46 @@
return jax.scipy.linalg.block_diag(*inputs)

return block_diag


@jax_funcify.register(PivotToPermutations)
def jax_funcify_PivotToPermutation(op, **kwargs):
    """
    Build the JAX implementation of the ``PivotToPermutations`` Op.

    The returned callable passes the pivot vector to ``jax.lax.linalg.lu_pivots_to_permutation`` and, unless
    ``op.inverse`` is set, applies ``argsort`` to that result before returning it.
    """
    # Read once when the function is built; the inner callable closes over it.
    inverse = op.inverse

    def pivot_to_permutations(pivots):
        n_pivots = pivots.shape[0]
        p_inv = jax.lax.linalg.lu_pivots_to_permutation(pivots, n_pivots)
        return p_inv if inverse else jax.numpy.argsort(p_inv)

    return pivot_to_permutations


@jax_funcify.register(LU)
def jax_funcify_LU(op, **kwargs):
    """
    Build the JAX implementation of the ``LU`` Op.

    ``permute_l`` and ``check_finite`` are read from the Op and forwarded to ``jax.scipy.linalg.lu`` on every call.

    Raises
    ------
    ValueError
        If ``op.p_indices`` is truthy; this dispatch rejects that option up front.
    """
    permute_l, p_indices, check_finite = (
        op.permute_l,
        op.p_indices,
        op.check_finite,
    )

    if p_indices:
        raise ValueError("JAX does not support the p_indices argument")

    lu_kwargs = {"permute_l": permute_l, "check_finite": check_finite}

    def lu(*inputs):
        return jax.scipy.linalg.lu(*inputs, **lu_kwargs)

    return lu


@jax_funcify.register(LUFactor)
def jax_funcify_LUFactor(op, **kwargs):
    """
    Build the JAX implementation of the ``LUFactor`` Op.

    ``check_finite`` and ``overwrite_a`` are read from the Op once and forwarded unchanged to
    ``jax.scipy.linalg.lu_factor`` on every call.
    """
    factor_kwargs = {
        "check_finite": op.check_finite,
        "overwrite_a": op.overwrite_a,
    }

    def lu_factor(a):
        return jax.scipy.linalg.lu_factor(a, **factor_kwargs)

    return lu_factor
2 changes: 1 addition & 1 deletion pytensor/link/numba/dispatch/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def numba_njit(*args, fastmath=None, **kwargs):
message=(
"(\x1b\\[1m)*" # ansi escape code for bold text
"Cannot cache compiled function "
'"(numba_funcified_fgraph|store_core_outputs|cholesky|solve|solve_triangular|cho_solve)" '
'"(numba_funcified_fgraph|store_core_outputs|cholesky|solve|solve_triangular|cho_solve|lu_factor)" '
"as it uses dynamic globals"
),
category=NumbaWarning,
Expand Down
206 changes: 206 additions & 0 deletions pytensor/link/numba/dispatch/linalg/decomposition/lu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
from collections.abc import Callable
from typing import cast as typing_cast

import numpy as np
from numba import njit as numba_njit
from numba.core.extending import overload
from numba.np.linalg import ensure_lapack
from scipy import linalg

from pytensor.link.numba.dispatch.linalg.decomposition.lu_factor import _getrf
from pytensor.link.numba.dispatch.linalg.utils import _check_scipy_linalg_matrix


@numba_njit
def _pivot_to_permutation(p, dtype):
    """
    Expand a sequence of row interchanges into a full permutation vector.

    Starts from ``arange(len(p))`` cast to ``dtype`` and, for each position ``i`` in order, swaps entries ``i`` and
    ``p[i]``. The entries of ``p`` are used directly as zero-based indices.
    """
    n = len(p)
    p_inv = np.arange(n).astype(dtype)
    for row in range(n):
        target = p[row]
        displaced = p_inv[row]
        p_inv[row] = p_inv[target]
        p_inv[target] = displaced
    return p_inv

Check warning on line 19 in pytensor/link/numba/dispatch/linalg/decomposition/lu.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/numba/dispatch/linalg/decomposition/lu.py#L18-L19

Added lines #L18 - L19 were not covered by tests


@numba_njit
def _lu_factor_to_lu(a, dtype, overwrite_a):
    """
    Run ``_getrf`` on ``a`` and unpack the packed result into ``(perm, L, U)``.

    ``L`` is the strictly-lower triangle of the factored matrix with ones on the diagonal, ``U`` is its upper
    triangle, and ``perm`` is the ``argsort`` of the permutation built from the pivots. The info code returned by
    ``_getrf`` is not inspected here.
    """
    lu_packed, IPIV, _info = _getrf(a, overwrite_a=overwrite_a)

    # Unit diagonal plus everything strictly below the diagonal of the packed factors.
    n = lu_packed.shape[-1]
    L = np.eye(n, dtype=dtype)
    L += np.tril(lu_packed, k=-1)
    U = np.triu(lu_packed)

    # Fortran is 1 indexed, so the pivots are shifted down by one before being used as indices.
    zero_based_pivots = IPIV - 1
    p_inv = _pivot_to_permutation(zero_based_pivots, dtype=dtype)
    perm = np.argsort(p_inv)

    return perm, L, U

Check warning on line 35 in pytensor/link/numba/dispatch/linalg/decomposition/lu.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/numba/dispatch/linalg/decomposition/lu.py#L35

Added line #L35 was not covered by tests


def _lu_1(
a: np.ndarray,
permute_l: bool,
check_finite: bool,
p_indices: bool,
overwrite_a: bool,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Thin wrapper around scipy.linalg.lu. Used as an overload target to avoid side-effects on users to import Pytensor.

Called when permute_l is True and p_indices is False, and returns a tuple of (perm, L, U), where perm an integer
array of row swaps, such that L[perm] @ U = A.
"""
return typing_cast(
tuple[np.ndarray, np.ndarray, np.ndarray],
linalg.lu(
a,
permute_l=permute_l,
check_finite=check_finite,
p_indices=p_indices,
overwrite_a=overwrite_a,
),
)


def _lu_2(
a: np.ndarray,
permute_l: bool,
check_finite: bool,
p_indices: bool,
overwrite_a: bool,
) -> tuple[np.ndarray, np.ndarray]:
"""
Thin wrapper around scipy.linalg.lu. Used as an overload target to avoid side-effects on users to import Pytensor.

Called when permute_l is False and p_indices is True, and returns a tuple of (PL, U), where PL is the
permuted L matrix, PL = P @ L.
"""
return typing_cast(
tuple[np.ndarray, np.ndarray],
linalg.lu(
a,
permute_l=permute_l,
check_finite=check_finite,
p_indices=p_indices,
overwrite_a=overwrite_a,
),
)


def _lu_3(
a: np.ndarray,
permute_l: bool,
check_finite: bool,
p_indices: bool,
overwrite_a: bool,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Thin wrapper around scipy.linalg.lu. Used as an overload target to avoid side-effects on users to import Pytensor.

Called when permute_l is False and p_indices is False, and returns a tuple of (P, L, U), where P is the permutation
matrix, P @ L @ U = A.
"""
return typing_cast(
tuple[np.ndarray, np.ndarray, np.ndarray],
linalg.lu(
a,
permute_l=permute_l,
check_finite=check_finite,
p_indices=p_indices,
overwrite_a=overwrite_a,
),
)


@overload(_lu_1)
def lu_impl_1(
    a: np.ndarray,
    permute_l: bool,
    check_finite: bool,
    p_indices: bool,
    overwrite_a: bool,
) -> Callable[
    [np.ndarray, bool, bool, bool, bool], tuple[np.ndarray, np.ndarray, np.ndarray]
]:
    """
    Overload scipy.linalg.lu with a numba function. This function is called when p_indices is True and permute_l is
    False. Returns a tuple of (perm, L, U), where perm is an integer array of row indices such that
    L[perm] @ U = A.

    Only ``a`` and ``overwrite_a`` are used by the compiled implementation; the other flags are accepted so the
    signature matches ``_lu_1``.
    """
    ensure_lapack()
    _check_scipy_linalg_matrix(a, "lu")
    dtype = a.dtype

    def impl(
        a: np.ndarray,
        permute_l: bool,
        check_finite: bool,
        p_indices: bool,
        overwrite_a: bool,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        perm, L, U = _lu_factor_to_lu(a, dtype, overwrite_a)
        return perm, L, U

    return impl


@overload(_lu_2)
def lu_impl_2(
    a: np.ndarray,
    permute_l: bool,
    check_finite: bool,
    p_indices: bool,
    overwrite_a: bool,
) -> Callable[[np.ndarray, bool, bool, bool, bool], tuple[np.ndarray, np.ndarray]]:
    """
    Overload scipy.linalg.lu with a numba function. This function is called when permute_l is True and p_indices is
    False. Returns a tuple of (PL, U), where PL is the permuted L matrix, PL = P @ L.

    Only ``a`` and ``overwrite_a`` are used by the compiled implementation; the other flags are accepted so the
    signature matches ``_lu_2``.
    """
    ensure_lapack()
    _check_scipy_linalg_matrix(a, "lu")
    dtype = a.dtype

    def impl(
        a: np.ndarray,
        permute_l: bool,
        check_finite: bool,
        p_indices: bool,
        overwrite_a: bool,
    ) -> tuple[np.ndarray, np.ndarray]:
        perm, L, U = _lu_factor_to_lu(a, dtype, overwrite_a)
        # Reordering the rows of L by perm gives the already-permuted lower factor.
        PL = L[perm]

        return PL, U

    return impl


@overload(_lu_3)
def lu_impl_3(
    a: np.ndarray,
    permute_l: bool,
    check_finite: bool,
    p_indices: bool,
    overwrite_a: bool,
) -> Callable[
    [np.ndarray, bool, bool, bool, bool], tuple[np.ndarray, np.ndarray, np.ndarray]
]:
    """
    Overload scipy.linalg.lu with a numba function. This function is called when permute_l is False and p_indices is
    False. Returns a tuple of (P, L, U), such that P @ L @ U = A.

    Only ``a`` and ``overwrite_a`` are used by the compiled implementation; the other flags are accepted so the
    signature matches ``_lu_3``.
    """
    ensure_lapack()
    _check_scipy_linalg_matrix(a, "lu")
    dtype = a.dtype

    def impl(
        a: np.ndarray,
        permute_l: bool,
        check_finite: bool,
        p_indices: bool,
        overwrite_a: bool,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        perm, L, U = _lu_factor_to_lu(a, dtype, overwrite_a)
        # Selecting rows of the identity by perm yields the dense permutation matrix.
        P = np.eye(a.shape[-1], dtype=dtype)[perm]

        return P, L, U

    return impl
86 changes: 86 additions & 0 deletions pytensor/link/numba/dispatch/linalg/decomposition/lu_factor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from collections.abc import Callable

import numpy as np
from numba.core.extending import overload
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from scipy import linalg

from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK,
_get_underlying_float,
int_ptr_to_val,
val_to_int_ptr,
)
from pytensor.link.numba.dispatch.linalg.utils import (
_check_scipy_linalg_matrix,
)


def _getrf(A, overwrite_a=False) -> tuple[np.ndarray, np.ndarray, int]:
"""
Underlying LAPACK function used for LU factorization. Compared to scipy.linalg.lu_factorize, this function also
returns an info code with diagnostic information.
"""
(getrf,) = linalg.get_lapack_funcs("getrf", (A,))
A_copy, ipiv, info = getrf(A, overwrite_a=overwrite_a)

Check warning on line 25 in pytensor/link/numba/dispatch/linalg/decomposition/lu_factor.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/numba/dispatch/linalg/decomposition/lu_factor.py#L24-L25

Added lines #L24 - L25 were not covered by tests

return A_copy, ipiv, info

Check warning on line 27 in pytensor/link/numba/dispatch/linalg/decomposition/lu_factor.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/numba/dispatch/linalg/decomposition/lu_factor.py#L27

Added line #L27 was not covered by tests


@overload(_getrf)
def getrf_impl(
    A: np.ndarray, overwrite_a: bool = False
) -> Callable[[np.ndarray, bool], tuple[np.ndarray, np.ndarray, int]]:
    """
    Numba overload of ``_getrf`` that calls the LAPACK xGETRF routine selected for the dtype of ``A``.

    The compiled implementation returns ``(A_copy, IPIV, INFO)``: the factored matrix in Fortran order, the pivot
    indices exactly as written by LAPACK, and the integer info code. ``A`` itself is only reused as the output
    buffer when ``overwrite_a`` is true and ``A`` is already Fortran-contiguous; otherwise a copy is factored.
    """
    ensure_lapack()
    _check_scipy_linalg_matrix(A, "getrf")
    dtype = A.dtype
    w_type = _get_underlying_float(dtype)
    numba_getrf = _LAPACK().numba_xgetrf(dtype)

    def impl(
        A: np.ndarray, overwrite_a: bool = False
    ) -> tuple[np.ndarray, np.ndarray, int]:
        _M, _N = np.int32(A.shape[-2:])  # type: ignore

        if overwrite_a and A.flags.f_contiguous:
            A_copy = A
        else:
            A_copy = _copy_to_fortran_order(A)

        M = val_to_int_ptr(_M)  # type: ignore
        N = val_to_int_ptr(_N)  # type: ignore
        LDA = val_to_int_ptr(_M)  # type: ignore
        # xGETRF defines min(M, N) pivots. Allocating N entries would hand uninitialised trailing values back to
        # the caller whenever the matrix has more columns than rows.
        IPIV = np.empty(min(_M, _N), dtype=np.int32)  # type: ignore
        INFO = val_to_int_ptr(0)

        numba_getrf(M, N, A_copy.view(w_type).ctypes, LDA, IPIV.ctypes, INFO)

        return A_copy, IPIV, int_ptr_to_val(INFO)

    return impl


def _lu_factor(A: np.ndarray, overwrite_a: bool = False):
"""
Thin wrapper around scipy.linalg.lu_factor. Used as an overload target to avoid side-effects on users who import
Pytensor.
"""
return linalg.lu_factor(A, overwrite_a=overwrite_a)


@overload(_lu_factor)
def lu_factor_impl(
    A: np.ndarray, overwrite_a: bool = False
) -> Callable[[np.ndarray, bool], tuple[np.ndarray, np.ndarray]]:
    """
    Numba overload of ``_lu_factor``. The compiled implementation runs ``_getrf`` and returns the packed factors
    together with zero-based pivot indices.

    Any non-zero info code from ``_getrf`` is turned into ``np.linalg.LinAlgError``.

    NOTE(review): scipy.linalg.lu_factor is believed to only warn for a positive info code (exactly singular
    factor); raising here may be stricter than SciPy -- confirm this is intended.
    """
    ensure_lapack()
    _check_scipy_linalg_matrix(A, "lu_factor")

    def impl(A: np.ndarray, overwrite_a: bool = False) -> tuple[np.ndarray, np.ndarray]:
        lu_packed, pivots, info = _getrf(A, overwrite_a=overwrite_a)

        if info != 0:
            raise np.linalg.LinAlgError("LU decomposition failed")

        # LAPACK uses 1-based indexing, convert to 0-based
        pivots -= 1
        return lu_packed, pivots

    return impl
Loading