ENH/DOC/TST: `cluster.vq`: use `xp_capabilities` #23000
    @pytest.fixture
    def whiten_lock(self):
        return Lock()
There is now machinery that detects eager_warns
@@ -208,10 +212,9 @@ def vq(obs, code_book, check_finite=True):
     code_book = _asarray(code_book, xp=xp, check_finite=check_finite)
     ct = xp.result_type(obs, code_book)

     c_obs = xp.astype(obs, ct, copy=False)
     c_code_book = xp.astype(code_book, ct, copy=False)

     if xp.isdtype(ct, kind='real floating'):
I tried using this code path only on numpy and `py_vq` everywhere else, but `py_vq` calls `spatial.distance.cdist`, making the whole effort futile.
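For readers unfamiliar with what the two code paths compute, here is a minimal NumPy-only sketch: promote both inputs to a common dtype, as the diff above does, then assign each observation to its nearest code. The function name `vq_sketch` is hypothetical, and the brute-force distance computation merely stands in for the compiled path and for `py_vq`; it is not SciPy's implementation.

```python
import numpy as np


def vq_sketch(obs, code_book):
    """Assign each observation to its nearest code (brute force, NumPy only)."""
    obs = np.asarray(obs)
    code_book = np.asarray(code_book)

    # Promote both inputs to a common dtype, mirroring the
    # xp.result_type / xp.astype calls in the diff.
    ct = np.result_type(obs, code_book)
    c_obs = obs.astype(ct, copy=False)
    c_code_book = code_book.astype(ct, copy=False)

    # Pairwise Euclidean distances, shape (n_obs, n_codes).
    diff = c_obs[:, np.newaxis, :] - c_code_book[np.newaxis, :, :]
    dist = np.sqrt((diff ** 2).sum(axis=-1))

    # Index of the closest code for each observation, and its distance.
    codes = dist.argmin(axis=1)
    min_dist = dist[np.arange(dist.shape[0]), codes]
    return codes, min_dist


obs = np.array([[1.0, 1.0], [2.0, 2.0], [9.0, 9.0]])
code_book = np.array([[0.0, 0.0], [10.0, 10.0]], dtype=np.float32)
codes, min_dist = vq_sketch(obs, code_book)
```

The pairwise-distance step is the part that `spatial.distance.cdist` performs in `py_vq`, which is why that path inherits whatever backend limitations `cdist` has.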
`cluster.vq`: add `xp_capabilities`
`cluster.vq`: use `xp_capabilities`
    #         f0    f1    f2
    obs = [[  1.,   1.,   1.],  #o0
           [  2.,   2.,   2.],  #o1
           [  3.,   3.,   3.],  #o2
           [  4.,   4.,   4.]]  #o3
workaround to numpy/numpydoc#624
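To make the row/column convention of the snippet above concrete, here is a small NumPy sketch, assuming rows are observations (`o0` to `o3`) and columns are features (`f0` to `f2`). It rescales each feature by its standard deviation, which is the usual meaning of whitening in this module; zero-variance columns are not handled in this sketch.

```python
import numpy as np

#                f0  f1  f2
obs = np.array([[1., 1., 1.],   # o0
                [2., 2., 2.],   # o1
                [3., 3., 3.],   # o2
                [4., 4., 4.]])  # o3

# Per-feature (column-wise) standard deviation across observations.
std_dev = obs.std(axis=0)

# Rescale each feature so that it has unit standard deviation.
whitened = obs / std_dev
```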
Local test results showing the improvement, from `pixi r test-cpu -t scipy.cluster.tests.test_vq` (and `test-cuda` for GPU):
- CPU on `main`: 1 failed, 116 passed, 37 skipped, 1 xfailed in 17.04s
- CPU on this PR: 131 passed, 5 skipped, 1 xfailed in 15.41s
- GPU on `main`: 58 passed, 97 skipped in 5.12s
- GPU on this PR: 69 passed, 68 skipped in 5.20s
So array type support is improved, and fewer tests run in total (155 on `main` versus 137 on this PR, on both CPU and GPU). The latter needs an explanation: the tagging is more efficient, so tests aren't generated at all when that's feasible, rather than generated and then skipped. Illustrating with verbose output for `test_vq`:
This PR, for `test-cuda`:
scipy/cluster/tests/test_vq.py::TestVq::test_vq[numpy] PASSED [ 24%]
scipy/cluster/tests/test_vq.py::TestVq::test_vq[array_api_strict] PASSED [ 24%]
scipy/cluster/tests/test_vq.py::TestVq::test_vq[torch] SKIPPED (uses spatial.distance....) [ 25%]
scipy/cluster/tests/test_vq.py::TestVq::test_vq[cupy] SKIPPED (uses spatial.distance.c...) [ 26%]
scipy/cluster/tests/test_vq.py::TestVq::test_vq[jax.numpy] SKIPPED (uses spatial.dista...) [ 27%]
scipy/cluster/tests/test_vq.py::TestVq::test_vq_matrix SKIPPED (`np.matrix` unsupporte...) [ 27%]
On `main` for `test-cuda`:
scipy/cluster/tests/test_vq.py::TestVq::test_vq[numpy] PASSED [ 20%]
scipy/cluster/tests/test_vq.py::TestVq::test_vq[array_api_strict] SKIPPED (`_vq` only supports NumPy backend) [ 20%]
scipy/cluster/tests/test_vq.py::TestVq::test_vq[torch] SKIPPED (`_vq` only supports NumPy backend) [ 21%]
scipy/cluster/tests/test_vq.py::TestVq::test_vq[cupy] SKIPPED (`_vq` only supports NumPy backend) [ 21%]
scipy/cluster/tests/test_vq.py::TestVq::test_vq[jax.numpy] SKIPPED (`_vq` only supports NumPy backend) [ 22%]
scipy/cluster/tests/test_vq.py::TestVq::test_vq_matrix[numpy] SKIPPED (`np.matrix` unsupported in array API mode) [ 23%]
scipy/cluster/tests/test_vq.py::TestVq::test_vq_matrix[array_api_strict] SKIPPED (`np.matrix` unsupported in array API mode) [ 23%]
scipy/cluster/tests/test_vq.py::TestVq::test_vq_matrix[torch] SKIPPED (`np.matrix` unsupported in array API mode) [ 24%]
scipy/cluster/tests/test_vq.py::TestVq::test_vq_matrix[cupy] SKIPPED (`np.matrix` unsupported in array API mode) [ 25%]
scipy/cluster/tests/test_vq.py::TestVq::test_vq_matrix[jax.numpy] SKIPPED (`np.matrix` unsupported in array API mode) [ 25%]
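The difference between the two listings can be reproduced with a toy model of test collection. This is only a sketch of the counting, with hypothetical helper names; it is not the pytest machinery SciPy actually uses.

```python
BACKENDS = ["numpy", "array_api_strict", "torch", "cupy", "jax.numpy"]


def per_backend_ids(test_name, backends):
    """One test ID per backend, whether or not that backend ends up skipped."""
    return [f"{test_name}[{backend}]" for backend in backends]


def single_id(test_name):
    """A test that cannot use any array backend is generated only once."""
    return [test_name]


# On main: both tests are parametrized over every backend.
on_main = (per_backend_ids("test_vq", BACKENDS)
           + per_backend_ids("test_vq_matrix", BACKENDS))

# On this PR: test_vq_matrix is no longer parametrized per backend.
on_pr = per_backend_ids("test_vq", BACKENDS) + single_id("test_vq_matrix")

print(len(on_main), len(on_pr))  # 10 6
```

The four IDs that disappear are all skips, which is why the skipped count drops without any loss of coverage.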
Everything LGTM, docs render well too. Thanks @crusaderky!
@rgommers does this need a release note per the discussion in scipy/scipy-stubs#526? From a surface-level scan of the prose in this PR, it sounds like this was done in a backwards-compatible way for NumPy, with only the array API support taking some liberties for some backends, which I don't think I would normally mention. Correct analysis?
Yep, 💯 correct, so no mention needed
- Add `xp_capabilities` to `scipy.cluster.vq`
- `whiten`: change default for `check_finite` to False on JAX and Dask
- `whiten` can run inside `jax.jit` and doesn't materialize the Dask graph

See Also
- `cluster.hierarchy`: use `xp_capabilities` #22960
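One way to see why `check_finite` defaults to False on lazy backends: the check needs a concrete Python boolean, which a traced or lazy array cannot provide without computing its values. The sketch below is NumPy-only and uses a hypothetical helper, `as_checked_array`; it is not SciPy's `_asarray`.

```python
import numpy as np


def as_checked_array(x, check_finite=True):
    """Hypothetical helper: convert to an array, optionally rejecting inf/NaN."""
    arr = np.asarray(x, dtype=float)
    # bool(...) needs actual values. On a lazy or traced array this step
    # would force computation or fail, so it is skipped when
    # check_finite is False.
    if check_finite and not bool(np.isfinite(arr).all()):
        raise ValueError("array must not contain infs or NaNs")
    return arr


ok = as_checked_array([1.0, 2.0])
unchecked = as_checked_array([1.0, np.nan], check_finite=False)

try:
    as_checked_array([1.0, np.inf])
    raised = False
except ValueError:
    raised = True
```

Callers who want the validation on such backends can presumably still pass `check_finite=True` explicitly, at the cost of evaluating the array.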