Skip to content

Decompose Tridiagonal Solve into core steps #1382

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pytensor/compile/mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,9 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
"fusion",
"inplace",
"scan_save_mem_prealloc",
# There are specific variants for the LU decompositions supported by JAX
"reuse_lu_decomposition_multiple_solves",
"scan_split_non_sequence_lu_decomposition_solve",
],
),
)
Expand Down
104 changes: 95 additions & 9 deletions pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from numpy import ndarray
from scipy import linalg

from pytensor.link.numba.dispatch import numba_funcify
from pytensor.link.numba.dispatch.basic import numba_njit
from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK,
Expand All @@ -20,6 +21,10 @@
_solve_check,
_trans_char_to_int,
)
from pytensor.tensor._linalg.solve.tridiagonal import (
LUFactorTridiagonal,
SolveLUFactorTridiagonal,
)


@numba_njit
Expand All @@ -34,7 +39,12 @@


def _gttrf(
dl: ndarray, d: ndarray, du: ndarray
dl: ndarray,
d: ndarray,
du: ndarray,
overwrite_dl: bool,
overwrite_d: bool,
overwrite_du: bool,
) -> tuple[ndarray, ndarray, ndarray, ndarray, ndarray, int]:
"""Placeholder for LU factorization of tridiagonal matrix."""
return # type: ignore
Expand All @@ -45,8 +55,12 @@
dl: ndarray,
d: ndarray,
du: ndarray,
overwrite_dl: bool,
overwrite_d: bool,
overwrite_du: bool,
) -> Callable[
[ndarray, ndarray, ndarray], tuple[ndarray, ndarray, ndarray, ndarray, ndarray, int]
[ndarray, ndarray, ndarray, bool, bool, bool],
tuple[ndarray, ndarray, ndarray, ndarray, ndarray, int],
]:
ensure_lapack()
_check_scipy_linalg_matrix(dl, "gttrf")
Expand All @@ -60,12 +74,24 @@
dl: ndarray,
d: ndarray,
du: ndarray,
overwrite_dl: bool,
overwrite_d: bool,
overwrite_du: bool,
) -> tuple[ndarray, ndarray, ndarray, ndarray, ndarray, int]:
n = np.int32(d.shape[-1])
ipiv = np.empty(n, dtype=np.int32)
du2 = np.empty(n - 2, dtype=dtype)
info = val_to_int_ptr(0)

if not overwrite_dl or not dl.flags.f_contiguous:
dl = dl.copy()

Check warning on line 87 in pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py#L87

Added line #L87 was not covered by tests

if not overwrite_d or not d.flags.f_contiguous:
d = d.copy()

Check warning on line 90 in pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py#L90

Added line #L90 was not covered by tests

if not overwrite_du or not du.flags.f_contiguous:
du = du.copy()

Check warning on line 93 in pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py#L93

Added line #L93 was not covered by tests

numba_gttrf(
val_to_int_ptr(n),
dl.view(w_type).ctypes,
Expand Down Expand Up @@ -133,10 +159,23 @@
nrhs = 1 if b.ndim == 1 else int(b.shape[-1])
info = val_to_int_ptr(0)

if overwrite_b and b.flags.f_contiguous:
b_copy = b
else:
b_copy = _copy_to_fortran_order_even_if_1d(b)
if not overwrite_b or not b.flags.f_contiguous:
b = _copy_to_fortran_order_even_if_1d(b)

Check warning on line 163 in pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py#L163

Added line #L163 was not covered by tests

if not dl.flags.f_contiguous:
dl = dl.copy()

Check warning on line 166 in pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py#L166

Added line #L166 was not covered by tests

if not d.flags.f_contiguous:
d = d.copy()

Check warning on line 169 in pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py#L169

Added line #L169 was not covered by tests

if not du.flags.f_contiguous:
du = du.copy()

Check warning on line 172 in pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py#L172

Added line #L172 was not covered by tests

if not du2.flags.f_contiguous:
du2 = du2.copy()

Check warning on line 175 in pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py#L175

Added line #L175 was not covered by tests

if not ipiv.flags.f_contiguous:
ipiv = ipiv.copy()

Check warning on line 178 in pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py#L178

Added line #L178 was not covered by tests

numba_gttrs(
val_to_int_ptr(_trans_char_to_int(trans)),
Expand All @@ -147,12 +186,12 @@
du.view(w_type).ctypes,
du2.view(w_type).ctypes,
ipiv.ctypes,
b_copy.view(w_type).ctypes,
b.view(w_type).ctypes,
val_to_int_ptr(n),
info,
)

return b_copy, int_ptr_to_val(info)
return b, int_ptr_to_val(info)

Check warning on line 194 in pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py#L194

Added line #L194 was not covered by tests

return impl

Expand Down Expand Up @@ -283,7 +322,9 @@

anorm = tridiagonal_norm(du, d, dl)

dl, d, du, du2, IPIV, INFO = _gttrf(dl, d, du)
dl, d, du, du2, IPIV, INFO = _gttrf(

Check warning on line 325 in pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py#L325

Added line #L325 was not covered by tests
dl, d, du, overwrite_dl=True, overwrite_d=True, overwrite_du=True
)
_solve_check(n, INFO)

X, INFO = _gttrs(
Expand All @@ -297,3 +338,48 @@
return X

return impl


@numba_funcify.register(LUFactorTridiagonal)
def numba_funcify_LUFactorTridiagonal(op: LUFactorTridiagonal, node, **kwargs):
overwrite_dl = op.overwrite_dl
overwrite_d = op.overwrite_d
overwrite_du = op.overwrite_du

@numba_njit(cache=False)
def lu_factor_tridiagonal(dl, d, du):
dl, d, du, du2, ipiv, _ = _gttrf(

Check warning on line 351 in pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py#L351

Added line #L351 was not covered by tests
dl,
d,
du,
overwrite_dl=overwrite_dl,
overwrite_d=overwrite_d,
overwrite_du=overwrite_du,
)
return dl, d, du, du2, ipiv

Check warning on line 359 in pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py#L359

Added line #L359 was not covered by tests

return lu_factor_tridiagonal


@numba_funcify.register(SolveLUFactorTridiagonal)
def numba_funcify_SolveLUFactorTridiagonal(
op: SolveLUFactorTridiagonal, node, **kwargs
):
overwrite_b = op.overwrite_b
transposed = op.transposed

@numba_njit(cache=False)
def solve_lu_factor_tridiagonal(dl, d, du, du2, ipiv, b):
x, _ = _gttrs(

Check warning on line 373 in pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py#L373

Added line #L373 was not covered by tests
dl,
d,
du,
du2,
ipiv,
b,
overwrite_b=overwrite_b,
trans=transposed,
)
return x

Check warning on line 383 in pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py#L383

Added line #L383 was not covered by tests

return solve_lu_factor_tridiagonal
60 changes: 55 additions & 5 deletions pytensor/tensor/_linalg/solve/rewriting.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
from collections.abc import Container
from copy import copy

from pytensor.compile import optdb
from pytensor.graph import Constant, graph_inputs
from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, node_rewriter
from pytensor.scan.op import Scan
from pytensor.scan.rewriting import scan_seqopt1
from pytensor.tensor._linalg.solve.tridiagonal import (
tridiagonal_lu_factor,
tridiagonal_lu_solve,
)
from pytensor.tensor.basic import atleast_Nd
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle
Expand All @@ -17,18 +22,32 @@
def decompose_A(A, assume_a, check_finite):
if assume_a == "gen":
return lu_factor(A, check_finite=check_finite)
elif assume_a == "tridiagonal":
# We didn't implement check_finite for tridiagonal LU factorization
return tridiagonal_lu_factor(A)
else:
raise NotImplementedError


def solve_lu_decomposed_system(A_decomp, b, transposed=False, *, core_solve_op: Solve):
if core_solve_op.assume_a == "gen":
b_ndim = core_solve_op.b_ndim
check_finite = core_solve_op.check_finite
assume_a = core_solve_op.assume_a
if assume_a == "gen":
return lu_solve(
A_decomp,
b,
b_ndim=b_ndim,
trans=transposed,
b_ndim=core_solve_op.b_ndim,
check_finite=core_solve_op.check_finite,
check_finite=check_finite,
)
elif assume_a == "tridiagonal":
# We didn't implement check_finite for tridiagonal LU solve
return tridiagonal_lu_solve(
A_decomp,
b,
b_ndim=b_ndim,
transposed=transposed,
)
else:
raise NotImplementedError
Expand Down Expand Up @@ -189,13 +208,15 @@ def _scan_split_non_sequence_lu_decomposition_solve(
@register_specialize
@node_rewriter([Blockwise])
def reuse_lu_decomposition_multiple_solves(fgraph, node):
return _split_lu_solve_steps(fgraph, node, eager=False, allowed_assume_a={"gen"})
return _split_lu_solve_steps(
fgraph, node, eager=False, allowed_assume_a={"gen", "tridiagonal"}
)


@node_rewriter([Scan])
def scan_split_non_sequence_lu_decomposition_solve(fgraph, node):
return _scan_split_non_sequence_lu_decomposition_solve(
fgraph, node, allowed_assume_a={"gen"}
fgraph, node, allowed_assume_a={"gen", "tridiagonal"}
)


Expand All @@ -207,3 +228,32 @@ def scan_split_non_sequence_lu_decomposition_solve(fgraph, node):
"scan_pushout",
position=2,
)


@node_rewriter([Blockwise])
def reuse_lu_decomposition_multiple_solves_jax(fgraph, node):
return _split_lu_solve_steps(fgraph, node, eager=False, allowed_assume_a={"gen"})


optdb["specialize"].register(
reuse_lu_decomposition_multiple_solves_jax.__name__,
in2out(reuse_lu_decomposition_multiple_solves_jax, ignore_newtrees=True),
"jax",
use_db_name_as_tag=False,
)


@node_rewriter([Scan])
def scan_split_non_sequence_lu_decomposition_solve_jax(fgraph, node):
return _scan_split_non_sequence_lu_decomposition_solve(
fgraph, node, allowed_assume_a={"gen"}
)


scan_seqopt1.register(
scan_split_non_sequence_lu_decomposition_solve_jax.__name__,
in2out(scan_split_non_sequence_lu_decomposition_solve_jax, ignore_newtrees=True),
"jax",
use_db_name_as_tag=False,
position=2,
)
Loading