|
7 | 7 | import numpy as np |
8 | 8 | import scipy.linalg as scipy_linalg |
9 | 9 | from numpy.exceptions import ComplexWarning |
10 | | -from scipy.linalg import get_lapack_funcs |
| 10 | +from scipy.linalg import LinAlgError, LinAlgWarning, get_lapack_funcs |
11 | 11 |
|
12 | 12 | import pytensor |
13 | 13 | from pytensor import ifelse |
@@ -384,15 +384,28 @@ def make_node(self, *inputs): |
384 | 384 | return Apply(self, [A, b], [out]) |
385 | 385 |
|
def perform(self, node, inputs, output_storage):
    """Solve a linear system from a Cholesky factor using LAPACK ``potrs``.

    Parameters
    ----------
    node
        The Apply node being evaluated; not used by this implementation.
    inputs
        Pair ``(c, b)``: the triangular Cholesky factor and the right-hand
        side.
    output_storage
        ``output_storage[0][0]`` is set to the solution array.

    Raises
    ------
    ValueError
        If ``self.check_finite`` is set and ``c`` or ``b`` contains infs or
        NaNs, if ``c`` is not a square 2-D array, if the leading dimension
        of ``b`` does not match ``c``, or if ``potrs`` reports a non-zero
        ``info``.
    """
    c, b = inputs

    if self.check_finite and not (np.isfinite(c).all() and np.isfinite(b).all()):
        raise ValueError("array must not contain infs or NaNs")

    # Check the rank before indexing ``shape[1]`` so that a 1-D factor yields
    # the intended ValueError rather than an IndexError.
    if c.ndim != 2 or c.shape[0] != c.shape[1]:
        raise ValueError("The factored matrix c is not square.")
    if c.shape[1] != b.shape[0]:
        raise ValueError(f"incompatible dimensions ({c.shape} and {b.shape})")

    # The routine is selected from the dtypes of both operands; its ``dtype``
    # attribute is also what the empty-result shortcut below reports.
    (potrs,) = get_lapack_funcs(("potrs",), (c, b))

    # Quick return for empty arrays: nothing to solve, but keep the result
    # dtype consistent with what the LAPACK routine would produce.
    if b.size == 0:
        output_storage[0][0] = np.empty_like(b, dtype=potrs.dtype)
        return

    x, info = potrs(c, b, lower=self.lower, overwrite_b=self.overwrite_b)
    if info != 0:
        raise ValueError(f"illegal value in {-info}th argument of internal potrs")

    output_storage[0][0] = x
396 | 409 |
|
397 | 410 | def L_op(self, *args, **kwargs): |
398 | 411 | # TODO: Base impl should work, let's try it |
@@ -696,9 +709,27 @@ def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op": |
def perform(self, node, inputs, outputs):
    """Compute a pivoted LU factorisation with LAPACK ``getrf``.

    Parameters
    ----------
    node
        The Apply node being evaluated; not used by this implementation.
    inputs
        Sequence whose first element is the matrix ``A`` to factorise.
    outputs
        ``outputs[0][0]`` is set to the packed LU array and
        ``outputs[1][0]`` to the pivot indices returned by ``getrf``.

    Raises
    ------
    ValueError
        If ``self.check_finite`` is set and ``A`` contains infs or NaNs, or
        if ``getrf`` reports a negative ``info``.

    Warns
    -----
    LinAlgWarning
        If ``getrf`` reports a positive ``info`` (an exactly zero diagonal
        entry); the factorisation is still stored in ``outputs``.
    """
    A = inputs[0]

    # Quick return for empty arrays: there is nothing to factorise, so hand
    # back an empty LU of the same shape and an empty int32 pivot vector.
    if A.size == 0:
        outputs[0][0] = np.empty_like(A)
        outputs[1][0] = np.array([], dtype=np.int32)
        return

    if self.check_finite and not np.isfinite(A).all():
        raise ValueError("array must not contain infs or NaNs")

    (getrf,) = get_lapack_funcs(("getrf",), (A,))
    LU, p, info = getrf(A, overwrite_a=self.overwrite_a)

    if info < 0:
        raise ValueError(
            f"illegal value in {-info}th argument of internal getrf (lu_factor)"
        )
    elif info > 0:
        # A singular factor is reported but not treated as fatal: the result
        # is still written to the outputs below.
        warnings.warn(
            f"Diagonal number {info} is exactly zero. Singular matrix.",
            LinAlgWarning,
            stacklevel=2,
        )

    outputs[0][0] = LU
    outputs[1][0] = p
@@ -865,15 +896,51 @@ def __init__(self, *, unit_diagonal=False, **kwargs): |
865 | 896 |
|
def perform(self, node, inputs, outputs):
    """Solve a triangular system ``A x = b`` using LAPACK ``trtrs``.

    Parameters
    ----------
    node
        The Apply node being evaluated; not used by this implementation.
    inputs
        Pair ``(A, b)``: the triangular matrix and the right-hand side.
    outputs
        ``outputs[0][0]`` is set to the solution array.

    Raises
    ------
    ValueError
        If ``self.check_finite`` is set and ``A`` or ``b`` contains infs or
        NaNs, if ``A`` is not a square 2-D array, if the leading dimensions
        of ``A`` and ``b`` differ, or if ``trtrs`` reports a negative
        ``info``.
    LinAlgError
        If ``trtrs`` reports a positive ``info`` (singular matrix).
    """
    A, b = inputs

    if self.check_finite and not (np.isfinite(A).all() and np.isfinite(b).all()):
        raise ValueError("array must not contain infs or NaNs")

    if A.ndim != 2 or A.shape[0] != A.shape[1]:
        raise ValueError("expected square matrix")

    if A.shape[0] != b.shape[0]:
        raise ValueError(f"shapes of a {A.shape} and b {b.shape} are incompatible")

    (trtrs,) = get_lapack_funcs(("trtrs",), (A, b))

    # Quick return for empty arrays
    if b.size == 0:
        outputs[0][0] = np.empty_like(b, dtype=trtrs.dtype)
        return

    if A.flags["F_CONTIGUOUS"]:
        matrix, lower, trans = A, self.lower, 0
    else:
        # trtrs expects Fortran ordering, so solve the transposed system
        # instead: A.T viewed with the opposite triangle and trans=1
        # describes the same equations.
        matrix, lower, trans = A.T, not self.lower, 1

    x, info = trtrs(
        matrix,
        b,
        overwrite_b=self.overwrite_b,
        lower=lower,
        trans=trans,
        unitdiag=self.unit_diagonal,
    )

    if info > 0:
        raise LinAlgError(
            f"singular matrix: resolution failed at diagonal {info - 1}"
        )
    elif info < 0:
        raise ValueError(f"illegal value in {-info}-th argument of internal trtrs")

    outputs[0][0] = x
877 | 944 |
|
878 | 945 | def L_op(self, inputs, outputs, output_gradients): |
879 | 946 | res = super().L_op(inputs, outputs, output_gradients) |
|
0 commit comments