Skip to content

Use lapack func instead of scipy.linalg.cholesky #1487

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 50 additions & 22 deletions pytensor/tensor/slinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
self,
*,
lower: bool = True,
check_finite: bool = True,
check_finite: bool = False,
on_error: Literal["raise", "nan"] = "raise",
overwrite_a: bool = False,
):
Expand Down Expand Up @@ -67,29 +67,55 @@
def perform(self, node, inputs, outputs):
    """Compute the Cholesky factor of ``inputs[0]`` via the LAPACK ``potrf`` routine.

    Calls ``potrf`` directly instead of ``scipy.linalg.cholesky`` to skip
    scipy's Python-level wrapper overhead.  Writes the factor (lower or upper
    triangular, per ``self.lower``) into ``outputs[0][0]``.

    Raises
    ------
    ValueError
        If the input is non-square, or contains non-finite values while
        ``check_finite`` is enabled and ``on_error == "raise"``, or LAPACK
        reports an illegal argument.
    scipy.linalg.LinAlgError
        If the matrix is not positive definite and ``on_error == "raise"``.
    """
    [x] = inputs
    [out] = outputs
    # Resolve the dtype-appropriate LAPACK routine (spotrf/dpotrf/...)
    (potrf,) = scipy_linalg.get_lapack_funcs(("potrf",), (x,))

    # Quick return for square empty array; potrf itself would not accept it
    if x.size == 0:
        out[0] = np.empty_like(x, dtype=potrf.dtype)
        return

    # Finite check is done here (not by scipy) since we call LAPACK directly
    if self.check_finite and not np.isfinite(x).all():
        if self.on_error == "nan":
            out[0] = np.full(x.shape, np.nan, dtype=potrf.dtype)
            return
        else:
            raise ValueError("array must not contain infs or NaNs")

    # Squareness check
    if x.shape[0] != x.shape[1]:
        raise ValueError(
            "Input array is expected to be square but has " f"the shape: {x.shape}."
        )

    # Scipy cholesky only makes use of overwrite_a when it is F_CONTIGUOUS
    # If we have a `C_CONTIGUOUS` array we transpose to benefit from it
    c_contiguous_input = self.overwrite_a and x.flags["C_CONTIGUOUS"]
    if c_contiguous_input:
        # A C-contiguous array is F-contiguous when viewed transposed;
        # factorizing the transpose with the opposite triangle gives the
        # same result after a final transpose.
        x = x.T
        lower = not self.lower
        overwrite_a = True
    else:
        lower = self.lower
        overwrite_a = self.overwrite_a

    # clean=True zeroes the unreferenced triangle of the returned factor
    c, info = potrf(x, lower=lower, overwrite_a=overwrite_a, clean=True)

    if info != 0:
        if self.on_error == "nan":
            # In "nan" mode any LAPACK failure yields a NaN-filled result
            out[0] = np.full(x.shape, np.nan, dtype=node.outputs[0].type.dtype)
        elif info > 0:
            # info > 0: the matrix is not positive definite
            raise scipy_linalg.LinAlgError(
                f"{info}-th leading minor of the array is not positive definite"
            )
        elif info < 0:
            # info < 0: an argument passed to potrf was invalid
            raise ValueError(
                f"LAPACK reported an illegal value in {-info}-th argument "
                f'on entry to "POTRF".'
            )
    else:
        # Transpose result if input was transposed
        out[0] = c.T if c_contiguous_input else c

def L_op(self, inputs, outputs, gradients):
"""
Expand Down Expand Up @@ -201,7 +227,9 @@

"""

return Blockwise(Cholesky(lower=lower, on_error=on_error))(x)
return Blockwise(
Cholesky(lower=lower, on_error=on_error, check_finite=check_finite)
)(x)


class SolveBase(Op):
Expand Down
2 changes: 1 addition & 1 deletion tests/link/numba/test_slinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ def test_cholesky_raises_on_nan_input():

x = pt.tensor(dtype=floatX, shape=(3, 3))
x = x.T.dot(x)
g = pt.linalg.cholesky(x)
g = pt.linalg.cholesky(x, check_finite=True)
f = pytensor.function([x], g, mode="NUMBA")

with pytest.raises(np.linalg.LinAlgError, match=r"Non-numeric values"):
Expand Down
20 changes: 20 additions & 0 deletions tests/tensor/test_slinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,26 @@ def test_cholesky():
check_upper_triangular(pd, ch_f)


def test_cholesky_performance(benchmark):
    """Benchmark the compiled cholesky graph on a small SPD matrix."""
    rng = np.random.default_rng(utt.fetch_seed())
    a = rng.standard_normal((10, 10)).astype(config.floatX)
    spd = np.dot(a, a.T)
    x = matrix()
    fn = function([x], cholesky(x))
    benchmark(fn, spd)


def test_cholesky_empty():
    """Cholesky of a 0x0 matrix yields an empty array with the configured dtype."""
    x = matrix()
    fn = function([x], cholesky(x))
    result = fn(np.empty([0, 0], dtype=config.floatX))
    assert result.size == 0
    assert result.dtype == config.floatX


def test_cholesky_indef():
x = matrix()
mat = np.array([[1, 0.2], [0.2, -2]]).astype(config.floatX)
Expand Down