ENH/DOC/TST: cluster.vq: use xp_capabilities #23000

Merged
merged 1 commit into from
May 17, 2025

Conversation

crusaderky
Contributor

@crusaderky crusaderky commented May 16, 2025

  • Add xp_capabilities to scipy.cluster.vq
  • In whiten, change default for check_finite to False on JAX and Dask
  • Test that whiten can run inside jax.jit and that it doesn't materialize the Dask graph
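
The second bullet can be pictured with a minimal sketch. It assumes the default is expressed as a `None` sentinel that is resolved per backend; `resolve_check_finite` and `LAZY_BACKENDS` are hypothetical names used only for illustration, not SciPy's actual helpers.

```python
# Hypothetical illustration, not SciPy's implementation.
# Backends whose arrays are lazy or traced: a finiteness check needs concrete
# values, which would materialize a Dask graph or break tracing under jax.jit.
LAZY_BACKENDS = {"jax.numpy", "dask.array"}


def resolve_check_finite(check_finite, backend):
    """Return the effective check_finite flag for a backend name.

    An explicit True/False from the caller always wins; None means
    "use the backend-dependent default".
    """
    if check_finite is None:
        return backend not in LAZY_BACKENDS
    return bool(check_finite)


print(resolve_check_finite(None, "numpy"))       # True
print(resolve_check_finite(None, "jax.numpy"))   # False
print(resolve_check_finite(True, "dask.array"))  # True
```

Under this assumption, NumPy callers keep the existing behaviour, while an explicit `check_finite=True` still opts in on a lazy backend.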

See Also

Demo render

[two images omitted]

@github-actions github-actions bot added the scipy.stats, scipy.special, scipy.cluster, scipy._lib, scipy.constants, array types, Documentation, enhancement, and maintenance labels May 16, 2025
Comment on lines -98 to -100
    @pytest.fixture
    def whiten_lock(self):
        return Lock()
Contributor Author

There is now machinery that detects eager_warns, so this lock fixture is no longer needed.

@@ -208,10 +212,9 @@ def vq(obs, code_book, check_finite=True):
    code_book = _asarray(code_book, xp=xp, check_finite=check_finite)
    ct = xp.result_type(obs, code_book)

    c_obs = xp.astype(obs, ct, copy=False)
    c_code_book = xp.astype(code_book, ct, copy=False)

    if xp.isdtype(ct, kind='real floating'):
Contributor Author

@crusaderky crusaderky May 16, 2025

I tried using this code path only on numpy and py_vq everywhere else, but py_vq calls spatial.distance.cdist, making the whole effort futile.

@crusaderky crusaderky added maintenance and removed scipy.stats, scipy.special, scipy._lib, scipy.constants, maintenance, enhancement labels May 16, 2025
@crusaderky crusaderky changed the title [DNM] ENH/DOC/TST: cluster.vq: add xp_capabilities [DNM] ENH/DOC/TST: cluster.vq: use xp_capabilities May 16, 2025
@crusaderky crusaderky marked this pull request as draft May 16, 2025 20:04
@crusaderky crusaderky marked this pull request as ready for review May 16, 2025 20:14
@crusaderky crusaderky added this to the 1.16.0 milestone May 16, 2025
@crusaderky crusaderky changed the title [DNM] ENH/DOC/TST: cluster.vq: use xp_capabilities ENH/DOC/TST: cluster.vq: use xp_capabilities May 16, 2025
Comment on lines +105 to +109
#         f0   f1   f2
obs = [[  1.,  1.,  1.],  #o0
       [  2.,  2.,  2.],  #o1
       [  3.,  3.,  3.],  #o2
       [  4.,  4.,  4.]]  #o3
Contributor Author

Workaround for numpy/numpydoc#624

Member

@rgommers rgommers left a comment

Local test results showing the improvement, from pixi r test-cpu -t scipy.cluster.tests.test_vq (and test-cuda for GPU):

  • CPU on main: 1 failed, 116 passed, 37 skipped, 1 xfailed in 17.04s
  • CPU on this PR: 131 passed, 5 skipped, 1 xfailed in 15.41s
  • GPU on main: 58 passed, 97 skipped in 5.12s
  • GPU on this PR: 69 passed, 68 skipped in 5.20s

So array type support is improved, and fewer tests run in total on GPU. The latter needs an explanation: the tagging is more efficient, so tests aren't generated at all when that's feasible, rather than generated and then skipped. Illustrating with verbose output for test_vq:

This PR, for test-cuda:

scipy/cluster/tests/test_vq.py::TestVq::test_vq[numpy] PASSED                              [ 24%]
scipy/cluster/tests/test_vq.py::TestVq::test_vq[array_api_strict] PASSED                   [ 24%]
scipy/cluster/tests/test_vq.py::TestVq::test_vq[torch] SKIPPED (uses spatial.distance....) [ 25%]
scipy/cluster/tests/test_vq.py::TestVq::test_vq[cupy] SKIPPED (uses spatial.distance.c...) [ 26%]
scipy/cluster/tests/test_vq.py::TestVq::test_vq[jax.numpy] SKIPPED (uses spatial.dista...) [ 27%]
scipy/cluster/tests/test_vq.py::TestVq::test_vq_matrix SKIPPED (`np.matrix` unsupporte...) [ 27%]

On main for test-cuda:

scipy/cluster/tests/test_vq.py::TestVq::test_vq[numpy] PASSED                                                                                       [ 20%]
scipy/cluster/tests/test_vq.py::TestVq::test_vq[array_api_strict] SKIPPED (`_vq` only supports NumPy backend)                                       [ 20%]
scipy/cluster/tests/test_vq.py::TestVq::test_vq[torch] SKIPPED (`_vq` only supports NumPy backend)                                                  [ 21%]
scipy/cluster/tests/test_vq.py::TestVq::test_vq[cupy] SKIPPED (`_vq` only supports NumPy backend)                                                   [ 21%]
scipy/cluster/tests/test_vq.py::TestVq::test_vq[jax.numpy] SKIPPED (`_vq` only supports NumPy backend)                                              [ 22%]
scipy/cluster/tests/test_vq.py::TestVq::test_vq_matrix[numpy] SKIPPED (`np.matrix` unsupported in array API mode)                                   [ 23%]
scipy/cluster/tests/test_vq.py::TestVq::test_vq_matrix[array_api_strict] SKIPPED (`np.matrix` unsupported in array API mode)                        [ 23%]
scipy/cluster/tests/test_vq.py::TestVq::test_vq_matrix[torch] SKIPPED (`np.matrix` unsupported in array API mode)                                   [ 24%]
scipy/cluster/tests/test_vq.py::TestVq::test_vq_matrix[cupy] SKIPPED (`np.matrix` unsupported in array API mode)                                    [ 25%]
scipy/cluster/tests/test_vq.py::TestVq::test_vq_matrix[jax.numpy] SKIPPED (`np.matrix` unsupported in array API mode)                               [ 25%]
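
The drop in collected tests visible in the test_vq_matrix rows of the two listings above can be modelled with a toy sketch. This is not pytest or SciPy code; the backend list is copied from the listings, and both function names and the `needs_backend` flag are hypothetical.

```python
# Toy model of the two collection strategies; all names are hypothetical.
BACKENDS = ["numpy", "array_api_strict", "torch", "cupy", "jax.numpy"]


def generate_then_skip(test_name, supported):
    """Older style: one test item per backend, unsupported ones skipped."""
    return [
        (f"{test_name}[{backend}]", "run" if backend in supported else "skip")
        for backend in BACKENDS
    ]


def generate_if_feasible(test_name, supported, needs_backend=True):
    """Newer style: a test that can never use the backend parameter is
    collected once instead of once per backend."""
    if not needs_backend:
        return [(test_name, "skip")]
    return generate_then_skip(test_name, supported)


old = generate_then_skip("test_vq_matrix", supported=set())
new = generate_if_feasible("test_vq_matrix", supported=set(),
                           needs_backend=False)
print(len(old), len(new))  # 5 1
```

In this model the same test accounts for five skipped items under the older strategy and one under the newer one, which is the kind of reduction that lowers the total count without removing coverage.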

Everything LGTM, docs render well too. Thanks @crusaderky!

@rgommers rgommers merged commit 7ef6aeb into scipy:main May 17, 2025
42 of 43 checks passed
@crusaderky crusaderky deleted the vq branch May 17, 2025 23:15
@tylerjereddy
Contributor

@rgommers does this need a release note per discussion in scipy/scipy-stubs#526? From a surface-level scan of the prose in the PR here, it sounds like this was done backwards-compatible for NumPy, just array API support that takes some liberties for some backends, which I wouldn't normally mention I don't think. Correct analysis?

@rgommers
Member

rgommers commented Jun 5, 2025

> it sounds like this was done backwards-compatible for NumPy, just array API support that takes some liberties for some backends, which I wouldn't normally mention I don't think. Correct analysis?

Yep, 💯 correct, so no mention needed

Labels
array types, Documentation, maintenance, scipy.cluster
4 participants