Skip to content

Commit f48068a

Browse files
authored
Use LAPACK functions for cho_solve, lu_factor, solve_triangular (#1605)
* Use lapack instead of `scipy_linalg.cho_solve` * Use lapack instead of `scipy_linalg.lu_factor` * Use lapack instead of `scipy_linalg.solve_triangular` * Add empty test for lu_factor * Tidy imports * remove ndim check
1 parent ac6c4e0 commit f48068a

File tree

2 files changed

+140
-21
lines changed

2 files changed

+140
-21
lines changed

pytensor/tensor/slinalg.py

Lines changed: 88 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import numpy as np
88
import scipy.linalg as scipy_linalg
99
from numpy.exceptions import ComplexWarning
10-
from scipy.linalg import get_lapack_funcs
10+
from scipy.linalg import LinAlgError, LinAlgWarning, get_lapack_funcs
1111

1212
import pytensor
1313
from pytensor import ifelse
@@ -384,15 +384,28 @@ def make_node(self, *inputs):
384384
return Apply(self, [A, b], [out])
385385

386386
def perform(self, node, inputs, output_storage):
387-
C, b = inputs
388-
rval = scipy_linalg.cho_solve(
389-
(C, self.lower),
390-
b,
391-
check_finite=self.check_finite,
392-
overwrite_b=self.overwrite_b,
393-
)
387+
c, b = inputs
388+
389+
(potrs,) = get_lapack_funcs(("potrs",), (c, b))
394390

395-
output_storage[0][0] = rval
391+
if self.check_finite and not (np.isfinite(c).all() and np.isfinite(b).all()):
392+
raise ValueError("array must not contain infs or NaNs")
393+
394+
if c.shape[0] != c.shape[1]:
395+
raise ValueError("The factored matrix c is not square.")
396+
if c.shape[1] != b.shape[0]:
397+
raise ValueError(f"incompatible dimensions ({c.shape} and {b.shape})")
398+
399+
# Quick return for empty arrays
400+
if b.size == 0:
401+
output_storage[0][0] = np.empty_like(b, dtype=potrs.dtype)
402+
return
403+
404+
x, info = potrs(c, b, lower=self.lower, overwrite_b=self.overwrite_b)
405+
if info != 0:
406+
raise ValueError(f"illegal value in {-info}th argument of internal potrs")
407+
408+
output_storage[0][0] = x
396409

397410
def L_op(self, *args, **kwargs):
398411
# TODO: Base impl should work, let's try it
@@ -696,9 +709,27 @@ def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
696709
def perform(self, node, inputs, outputs):
697710
A = inputs[0]
698711

699-
LU, p = scipy_linalg.lu_factor(
700-
A, overwrite_a=self.overwrite_a, check_finite=self.check_finite
701-
)
712+
# Quick return for empty arrays
713+
if A.size == 0:
714+
outputs[0][0] = np.empty_like(A)
715+
outputs[1][0] = np.array([], dtype=np.int32)
716+
return
717+
718+
if self.check_finite and not np.isfinite(A).all():
719+
raise ValueError("array must not contain infs or NaNs")
720+
721+
(getrf,) = get_lapack_funcs(("getrf",), (A,))
722+
LU, p, info = getrf(A, overwrite_a=self.overwrite_a)
723+
if info < 0:
724+
raise ValueError(
725+
f"illegal value in {-info}th argument of internal getrf (lu_factor)"
726+
)
727+
if info > 0:
728+
warnings.warn(
729+
f"Diagonal number {info} is exactly zero. Singular matrix.",
730+
LinAlgWarning,
731+
stacklevel=2,
732+
)
702733

703734
outputs[0][0] = LU
704735
outputs[1][0] = p
@@ -865,15 +896,51 @@ def __init__(self, *, unit_diagonal=False, **kwargs):
865896

866897
def perform(self, node, inputs, outputs):
867898
A, b = inputs
868-
outputs[0][0] = scipy_linalg.solve_triangular(
869-
A,
870-
b,
871-
lower=self.lower,
872-
trans=0,
873-
unit_diagonal=self.unit_diagonal,
874-
check_finite=self.check_finite,
875-
overwrite_b=self.overwrite_b,
876-
)
899+
900+
if self.check_finite and not (np.isfinite(A).all() and np.isfinite(b).all()):
901+
raise ValueError("array must not contain infs or NaNs")
902+
903+
if len(A.shape) != 2 or A.shape[0] != A.shape[1]:
904+
raise ValueError("expected square matrix")
905+
906+
if A.shape[0] != b.shape[0]:
907+
raise ValueError(f"shapes of a {A.shape} and b {b.shape} are incompatible")
908+
909+
(trtrs,) = get_lapack_funcs(("trtrs",), (A, b))
910+
911+
# Quick return for empty arrays
912+
if b.size == 0:
913+
outputs[0][0] = np.empty_like(b, dtype=trtrs.dtype)
914+
return
915+
916+
if A.flags["F_CONTIGUOUS"]:
917+
x, info = trtrs(
918+
A,
919+
b,
920+
overwrite_b=self.overwrite_b,
921+
lower=self.lower,
922+
trans=0,
923+
unitdiag=self.unit_diagonal,
924+
)
925+
else:
926+
# transposed system is solved since trtrs expects Fortran ordering
927+
x, info = trtrs(
928+
A.T,
929+
b,
930+
overwrite_b=self.overwrite_b,
931+
lower=not self.lower,
932+
trans=1,
933+
unitdiag=self.unit_diagonal,
934+
)
935+
936+
if info > 0:
937+
raise LinAlgError(
938+
f"singular matrix: resolution failed at diagonal {info-1}"
939+
)
940+
elif info < 0:
941+
raise ValueError(f"illegal value in {-info}-th argument of internal trtrs")
942+
943+
outputs[0][0] = x
877944

878945
def L_op(self, inputs, outputs, output_gradients):
879946
res = super().L_op(inputs, outputs, output_gradients)

tests/tensor/test_slinalg.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -513,6 +513,31 @@ def solve_op(A, b):
513513

514514
utt.verify_grad(solve_op, [A_val, b_val], 3, rng, eps=eps)
515515

516+
def test_solve_triangular_empty(self):
517+
rng = np.random.default_rng(utt.fetch_seed())
518+
A = pt.tensor("A", shape=(5, 5))
519+
b = pt.tensor("b", shape=(5, 0))
520+
521+
A_val = rng.random((5, 5)).astype(config.floatX)
522+
b_empty = np.empty([5, 0], dtype=config.floatX)
523+
524+
A_func = functools.partial(self.A_func, lower=True, unit_diagonal=True)
525+
526+
x = solve_triangular(
527+
A_func(A),
528+
b,
529+
lower=True,
530+
trans=0,
531+
unit_diagonal=True,
532+
b_ndim=len((5, 0)),
533+
)
534+
535+
f = function([A, b], x)
536+
537+
res = f(A_val, b_empty)
538+
assert res.size == 0
539+
assert res.dtype == config.floatX
540+
516541

517542
class TestCholeskySolve(utt.InferShapeTester):
518543
def setup_method(self):
@@ -797,6 +822,18 @@ def test_lu_factor():
797822
)
798823

799824

825+
def test_lu_factor_empty():
826+
A = matrix()
827+
f = function([A], lu_factor(A))
828+
829+
A_empty = np.empty([0, 0], dtype=config.floatX)
830+
LU, pt_p_idx = f(A_empty)
831+
832+
assert LU.size == 0
833+
assert LU.dtype == config.floatX
834+
assert pt_p_idx.size == 0
835+
836+
800837
def test_cho_solve():
801838
rng = np.random.default_rng(utt.fetch_seed())
802839
A = matrix()
@@ -814,6 +851,21 @@ def test_cho_solve():
814851
)
815852

816853

854+
def test_cho_solve_empty():
855+
rng = np.random.default_rng(utt.fetch_seed())
856+
A = matrix()
857+
b = matrix()
858+
y = cho_solve((A, True), b)
859+
cho_solve_lower_func = function([A, b], y)
860+
861+
A_empty = np.tril(np.asarray(rng.random((5, 5)), dtype=config.floatX))
862+
b_empty = np.empty([5, 0], dtype=config.floatX)
863+
864+
res = cho_solve_lower_func(A_empty, b_empty)
865+
assert res.size == 0
866+
assert res.dtype == config.floatX
867+
868+
817869
def test_expm():
818870
rng = np.random.default_rng(utt.fetch_seed())
819871
A = rng.standard_normal((5, 5)).astype(config.floatX)

0 commit comments

Comments
 (0)