-
Notifications
You must be signed in to change notification settings - Fork 155
Do not raise in linalg Ops #1834
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
2fae1c3 to
1822b42
Compare
64b1ee6 to
a4723e6
Compare
jessegrabowski
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Requesting changes because I don't understand why you can just delete all the QR helpers.
| from pytensor.link.numba.dispatch.linalg.utils import _check_linalg_matrix | ||
|
|
||
|
|
||
| def _xgeqrf(A: np.ndarray, overwrite_a: bool, lwork: int): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm a bit surprised you can just delete all of these. They are used by the other dispatches below. If tests are passing, it might just be that coverage isn't good enough. For example xgeqrf is used by qr_full_no_pivot_impl
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You weren't using these overloaded functions, you were going directly to the lapack function (numba_foo, instead of foo or _foo)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You weren't using these overloaded functions, you were going directly to the lapack function from _LAPACK().numba_xgeqrf(dtype)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah my understanding was that I needed to give overloads for those even in that case. I guess not!
a4723e6 to
5362d66
Compare
Inline when only used in one place, or remove if altogether unused
5362d66 to
9a19551
Compare
Closes #1078
Instead of adding
on_errorto every Op, this PR choses the simpler approach, just catch and return nan by default.This also allows us to simplify the numba implementations quite a lot and get rid of a couple of low-level dispatch for condition number and the like. Perhaps it even speeds up uncached compilation.
Only one Op had this special behavior: Cholesky. For helping with transitioning,
on_error="raise"is now implemented symbolically with a FutureWarning. The default switched toon_error="nan", which is the "default" in all lingalg Ops (at least the ones I covered).This is analogous to us returning 1/0->nan without any error/warning. If it's undefined/can't be computed we return nan, and let the users deal with it. Most of the times the error makes it hader to work with, not easier, as there is no symbolic
try/except, while there is symbolicisnan(x).any().I also removed the
check_finiteargument internally. Users can define it without it having any effect for API compat. The stated reason for this in Scipy is that some implementation of BLAS/Lapack can hang / crash the system with non-finite values. I'm taking a chance here, as those issues may never show up in the versions our users install / modern implementations. We can revisit this later if it proves to be still needed.I'm marking this as a major breaking change.
Finally, I've removed some lapack overloads that were never used anywhere, and inlined others that were used only in one place. I believe it makes code more readable and perhaps uncached numba compilation faster