Skip to content

Add JAX dispatch for CholeskySolve Op #1491

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 21, 2025

Conversation

jessegrabowski
Copy link
Member

@jessegrabowski jessegrabowski commented Jun 21, 2025

Description

Adds a JAX dispatch for the CholeskySolve Op. Nobody ever uses this function (although it's quite nice), so nobody cared that we didn't have this. Now it matters because of the rewrites introduced in #1461. Graphs that benefit from this rewrite (basically any PyMC model with an MvNormal....) will error in JAX mode because the Op is missing.

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pytensor--1491.org.readthedocs.build/en/1491/

Copy link

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR adds a new JAX dispatch implementation for the CholeskySolve Op to support models that use Cholesky-based solutions in JAX mode. The key changes include:

  • Adding tests for the JAX implementation of cho_solve in tests/link/jax/test_slinalg.py.
  • Updating the cho_solve function signature and docstring in pytensor/tensor/slinalg.py.
  • Implementing and registering a new JAX dispatch function for CholeskySolve in pytensor/link/jax/dispatch/slinalg.py.

Reviewed Changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.

File Description
tests/link/jax/test_slinalg.py Added new tests for the CholeskySolve dispatch functionality.
pytensor/tensor/slinalg.py Updated cho_solve function signature and documentation.
pytensor/link/jax/dispatch/slinalg.py Added JAX dispatch for the CholeskySolve Op.
Comments suppressed due to low confidence (1)

tests/link/jax/test_slinalg.py:338

  • [nitpick] The test function is named 'test_jax_chosolve', but the operator and primary function name is 'cho_solve'. For clarity and consistency, consider renaming the test to 'test_jax_cho_solve'.
@pytest.mark.parametrize("b_shape, lower", [((5,), True), ((5, 5), False)])

@ricardoV94 ricardoV94 added the enhancement New feature or request label Jun 21, 2025
Copy link

codecov bot commented Jun 21, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Please upload report for BASE (main@d3bbc20). Learn more about missing BASE report.
Report is 1 commits behind head on main.

Additional details and impacted files

Impacted file tree graph

@@           Coverage Diff           @@
##             main    #1491   +/-   ##
=======================================
  Coverage        ?   82.01%           
=======================================
  Files           ?      214           
  Lines           ?    50439           
  Branches        ?     8907           
=======================================
  Hits            ?    41370           
  Misses          ?     6861           
  Partials        ?     2208           
Files with missing lines Coverage Δ
pytensor/link/jax/dispatch/slinalg.py 94.50% <100.00%> (ø)
pytensor/tensor/slinalg.py 93.18% <100.00%> (ø)
🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@jessegrabowski jessegrabowski merged commit f72d7e5 into pymc-devs:main Jun 21, 2025
73 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request jax linalg Linear algebra Op implementation
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Add jax dispatch for ChoSolve
2 participants