Skip to content

Conversation

@ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Jan 8, 2026

Closes #1078

Instead of adding on_error to every Op, this PR choses the simpler approach, just catch and return nan by default.

This also allows us to simplify the numba implementations quite a lot and get rid of a couple of low-level dispatch for condition number and the like. Perhaps it even speeds up uncached compilation.

Only one Op had this special behavior: Cholesky. For helping with transitioning, on_error="raise" is now implemented symbolically with a FutureWarning. The default switched to on_error="nan", which is the "default" in all lingalg Ops (at least the ones I covered).

This is analogous to us returning 1/0->nan without any error/warning. If it's undefined/can't be computed we return nan, and let the users deal with it. Most of the times the error makes it hader to work with, not easier, as there is no symbolic try/except, while there is symbolic isnan(x).any().


I also removed the check_finite argument internally. Users can define it without it having any effect for API compat. The stated reason for this in Scipy is that some implementation of BLAS/Lapack can hang / crash the system with non-finite values. I'm taking a chance here, as those issues may never show up in the versions our users install / modern implementations. We can revisit this later if it proves to be still needed.

I'm marking this as a major breaking change.


Finally, I've removed some lapack overloads that were never used anywhere, and inlined others that were used only in one place. I believe it makes code more readable and perhaps uncached numba compilation faster

Copy link
Member

@jessegrabowski jessegrabowski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Requesting changes because I don't understand why you can just delete all the QR helpers.

from pytensor.link.numba.dispatch.linalg.utils import _check_linalg_matrix


def _xgeqrf(A: np.ndarray, overwrite_a: bool, lwork: int):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a bit surprised you can just delete all of these. They are used by the other dispatches below. If tests are passing, it might just be that coverage isn't good enough. For example xgeqrf is used by qr_full_no_pivot_impl

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You weren't using these overloaded functions, you were going directly to the lapack function (numba_foo, instead of foo or _foo)

Copy link
Member Author

@ricardoV94 ricardoV94 Jan 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You weren't using these overloaded functions, you were going directly to the lapack function from _LAPACK().numba_xgeqrf(dtype)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah my understanding was that I needed to give overloads for those even in that case. I guess not!

Inline when only used in one place, or remove if altogether unused
@ricardoV94 ricardoV94 merged commit e75bbb2 into pymc-devs:main Jan 11, 2026
116 of 118 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Add on_error argument to linalg functions

2 participants