
Replace lu_solve with an Op that directly calls getrs #1480

Closed
@jessegrabowski

Description

When we implemented lu_solve in #1218, we waffled back and forth on whether to make a new Op that would wrap and directly call getrs, or to build a function from existing primitives (e.g. two calls to solve_triangular).

Ultimately we went with the function approach, following the JAX implementation. But that's turning out to have a performance cost. We should go back and wrap getrs to get maximum speed out of solves.
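To make the two options concrete, here is a minimal NumPy/SciPy sketch (not PyTensor code) of both. The composed version applies the pivot row swaps and then calls solve_triangular twice. The direct version hands the packed factors straight to LAPACK's getrs through scipy.linalg.lapack. It assumes factors in the packed (lu, piv) format returned by scipy.linalg.lu_factor, and the helper names lu_solve_composed and lu_solve_getrs are hypothetical. It only checks that the two agree; it does not measure the performance gap.

```python
import numpy as np
from scipy.linalg import lu_factor, solve_triangular
from scipy.linalg.lapack import get_lapack_funcs


def lu_solve_composed(lu, piv, b):
    """Hypothetical helper: solve A x = b from packed LU factors using
    existing primitives (row swaps + two triangular solves)."""
    pb = np.array(b, dtype=lu.dtype, copy=True)
    # SciPy's piv follows LAPACK's sequential-swap convention (0-based):
    # row i was interchanged with row piv[i], applied in order i = 0, 1, ...
    for i, p in enumerate(piv):
        if p != i:
            pb[[i, p]] = pb[[p, i]]
    # L is stored below the diagonal of lu with an implicit unit diagonal
    y = solve_triangular(lu, pb, lower=True, unit_diagonal=True)
    # U is stored on and above the diagonal of lu
    return solve_triangular(lu, y, lower=False)


def lu_solve_getrs(lu, piv, b):
    """Hypothetical helper: solve A x = b with a single LAPACK getrs call,
    which does the row swaps and both triangular solves internally."""
    (getrs,) = get_lapack_funcs(("getrs",), (lu, b))
    x, info = getrs(lu, piv, b)
    if info != 0:
        raise ValueError(f"getrs failed with info={info}")
    return x


rng = np.random.default_rng(0)
A = rng.normal(size=(5, 5))
b = rng.normal(size=(5, 2))
lu, piv = lu_factor(A)

x_composed = lu_solve_composed(lu, piv, b)
x_getrs = lu_solve_getrs(lu, piv, b)
print(np.allclose(x_composed, x_getrs), np.allclose(A @ x_getrs, b))
```

The composed version makes three separate passes over b (permutation, lower solve, upper solve), while getrs does the same work in one call. That single call is what a dedicated Op would wrap.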

The "hardest" part of this issue will be working out the backward sensitivity for the gradient. It should be the same as for solve, but we have to keep in mind that there's an extra permutation matrix P (from the pivoting, A = PLU) floating around that has to be accounted for.
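As a sanity check on the idea that the backward pass matches solve, here is a NumPy/SciPy sketch (not a PyTensor grad implementation) of the reverse-mode rule for x = A^{-1} b, computed from the LU factors and checked against finite differences. It assumes gradients are wanted with respect to the original matrix A and the right-hand side b. The gradient with respect to the packed LU factors themselves (if those are the Op's inputs) is not covered. The helper name lu_solve_vjp is hypothetical. The sketch suggests that P does not need to be formed explicitly: the transposed solve with the same factors (trans=1, which getrs supports) applies the permutation itself.

```python
import numpy as np
from scipy.linalg import lu_factor, lu_solve


def lu_solve_vjp(lu_and_piv, x, g):
    """Hypothetical helper: reverse-mode sensitivities of x = A^{-1} b,
    given A's LU factors, the forward output x, and the cotangent g of x."""
    # b_bar = A^{-T} g. Since A = P L U, A^T = U^T L^T P^T; trans=1 solves
    # with the transposed factors and applies the row interchanges last.
    b_bar = lu_solve(lu_and_piv, g, trans=1)
    # A_bar = -b_bar x^T, the same rule as for an ordinary solve
    A_bar = -b_bar @ x.T
    return A_bar, b_bar


def central_diff(f, X, eps=1e-6):
    """Elementwise central finite-difference gradient of scalar f at X."""
    grad = np.zeros_like(X)
    for idx in np.ndindex(X.shape):
        E = np.zeros_like(X)
        E[idx] = eps
        grad[idx] = (f(X + E) - f(X - E)) / (2 * eps)
    return grad


rng = np.random.default_rng(1)
n, k = 4, 3
A = rng.normal(size=(n, n)) + n * np.eye(n)  # kept well-conditioned
b = rng.normal(size=(n, k))
g = rng.normal(size=(n, k))  # cotangent of the output x

lu_and_piv = lu_factor(A)
x = lu_solve(lu_and_piv, b)
A_bar, b_bar = lu_solve_vjp(lu_and_piv, x, g)

# Scalar loss sum(g * x), whose gradient with respect to x is g
fd_A = central_diff(lambda A_: np.sum(g * np.linalg.solve(A_, b)), A)
fd_b = central_diff(lambda b_: np.sum(g * np.linalg.solve(A, b_)), b)
print(np.allclose(A_bar, fd_A, atol=1e-5), np.allclose(b_bar, fd_b, atol=1e-5))
```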

Note, though, that the gradient is just a nice-to-have, not a must-have. We're mostly introducing this Op through rewrites, so it is seldom differentiated directly. But if possible I'd rather have it than not.
