Description
When we implemented `lu_solve` in #1218, we waffled back and forth on whether to make a new `Op` that would wrap and directly call `getrs`, or to make a function from existing primitives (e.g. two calls to `solve_triangular`).
Ultimately we went with the function approach, following the JAX implementation. But that's turning out to have a performance cost. We should go back and wrap `getrs` to get maximum speed out of solves.
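For reference, here is a minimal NumPy/SciPy sketch of the two strategies (plain functions, not the actual `Op` code). The composite version converts `getrf`-style pivots into a permutation and then makes two `solve_triangular` calls, as in the function approach; the direct version hands the packed factors to `scipy.linalg.lu_solve`, which is a thin wrapper around LAPACK `getrs`. The helper names `pivots_to_permutation`, `lu_solve_composite` and `lu_solve_getrs` are hypothetical, for illustration only.

```python
import numpy as np
from scipy.linalg import lu_factor, lu_solve, solve_triangular


def pivots_to_permutation(piv):
    """Turn getrf-style pivots (row i was swapped with row piv[i], applied in
    order; 0-based in SciPy) into perm such that (P @ A)[k] == A[perm[k]]."""
    perm = np.arange(len(piv))
    for i, p in enumerate(piv):
        perm[i], perm[p] = perm[p], perm[i]
    return perm


def lu_solve_composite(lu, piv, b):
    """Function-style solve from primitives: permute b, then two triangular solves."""
    perm = pivots_to_permutation(piv)
    # L y = P b, with L stored below the diagonal of `lu` (unit diagonal implied)
    y = solve_triangular(lu, b[perm], lower=True, unit_diagonal=True)
    # U x = y, with U stored on and above the diagonal of `lu`
    return solve_triangular(lu, y, lower=False)


def lu_solve_getrs(lu, piv, b):
    """Op-style solve: scipy.linalg.lu_solve passes the factors straight to getrs."""
    return lu_solve((lu, piv), b)


rng = np.random.default_rng(0)
A = rng.normal(size=(5, 5))
b = rng.normal(size=(5, 2))
lu, piv = lu_factor(A)
x_composite = lu_solve_composite(lu, piv, b)
x_getrs = lu_solve_getrs(lu, piv, b)
print(np.allclose(x_composite, x_getrs))
```

Both routes give the same answer. The difference is that the composite version materializes the permutation and issues two separate triangular solves, whereas `getrs` performs the row interchanges and both substitutions inside a single LAPACK call.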
The "hardest" part of the issue will be working out the backward sensitivity for the gradient. It should be the same as for `solve`, but we have to keep in mind that there's an extra permutation matrix `P` floating around that has to be accounted for.
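One way the backward pass could look, sketched in NumPy/SciPy under some assumptions: the `Op` takes the packed factors `lu`, the integer pivots `piv` and `b`, with `P @ A = L @ U` as in `getrf`. Pivots are integer-valued, so they get no gradient. Following the usual solve rule, `b_bar = A^{-T} x_bar` and `A_bar = -b_bar x^T`; substituting `A = P^T L U` makes `P` cancel out of the factor gradients, and `b_bar` can reuse the forward factors through `getrs` with `trans=1`. `lu_solve_vjp` is a hypothetical name, and the finite-difference check only verifies the formulas under these assumptions.

```python
import numpy as np
from scipy.linalg import lu_factor, lu_solve, solve_triangular


def lu_solve_vjp(lu, piv, b, x_bar):
    """Reverse-mode sensitivities of x = lu_solve((lu, piv), b).

    Hypothetical helper. `lu` holds L (strict lower, unit diagonal implied)
    and U (upper incl. diagonal); `piv` are getrf-style pivots, P @ A = L @ U.
    Returns (lu_bar, b_bar); the integer pivots get no gradient.
    """
    x = lu_solve((lu, piv), b)
    # As for solve: b_bar = A^{-T} x_bar. getrs handles this via trans=1,
    # reusing the forward factors and pivots.
    b_bar = lu_solve((lu, piv), x_bar, trans=1)
    # As for solve: A_bar = -b_bar x^T. With A = P^T L U the permutation cancels:
    #   U_bar = triu(L^T P A_bar)        = -triu(w x^T),           w = U^{-T} x_bar
    #   L_bar = tril(P A_bar U^T, k=-1)  = -tril(v (U x)^T, k=-1),  v = L^{-T} w
    w = solve_triangular(lu, x_bar, trans=1, lower=False)
    v = solve_triangular(lu, w, trans=1, lower=True, unit_diagonal=True)
    Ux = np.triu(lu) @ x
    lu_bar = -np.triu(w @ x.T) - np.tril(v @ Ux.T, k=-1)
    return lu_bar, b_bar


# Check against central finite differences, holding the pivots fixed.
rng = np.random.default_rng(0)
n, k = 4, 2
# Reversing the rows of a matrix with a large diagonal makes row swaps very likely
A = (rng.normal(size=(n, n)) + n * np.eye(n))[::-1]
b = rng.normal(size=(n, k))
x_bar = rng.normal(size=(n, k))
lu, piv = lu_factor(A)
lu_bar, b_bar = lu_solve_vjp(lu, piv, b, x_bar)


def loss(lu_, b_):
    # Scalar loss whose gradient with respect to x is x_bar
    return np.sum(x_bar * lu_solve((lu_, piv), b_))


eps = 1e-6
fd_lu = np.zeros_like(lu)
for i in range(n):
    for j in range(n):
        d = np.zeros_like(lu)
        d[i, j] = eps
        fd_lu[i, j] = (loss(lu + d, b) - loss(lu - d, b)) / (2 * eps)
fd_b = np.zeros_like(b)
for i in range(n):
    for j in range(k):
        d = np.zeros_like(b)
        d[i, j] = eps
        fd_b[i, j] = (loss(lu, b + d) - loss(lu, b - d)) / (2 * eps)

print(np.abs(lu_bar - fd_lu).max(), np.abs(b_bar - fd_b).max())
```

Because the masks `triu` and `tril(..., k=-1)` are disjoint, the two factor gradients can be summed back into a single packed `lu_bar` with the same layout as the forward input.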
Note, though, that the gradients are a nice-to-have, not a must-have. We're mostly using this `Op` for rewriting, so it is seldom differentiated directly in the first place. But if possible I'd rather have it than not.