Skip to content

Conversation

@vfdev-5
Copy link
Member

@vfdev-5 vfdev-5 commented Jul 5, 2024

Fixes #22137

Description:

  • Fixed issue in jnp.argpartition and jnp.partition with unsigned integers with zeros

Thanks @pearu for the trick with -lax.top_k(-(arr + 1), kth + 1)[0] - 1

Details on the problem with unsigned integers:

import jax.numpy as jnp
from jax import lax

a = jnp.array([0, 1, 2, 254, 255], dtype=jnp.uint8)
print(-a)   # [  0 255 254   2   1]

# On main now:
print(lax.top_k(-a, 3)[0])  # [255 254   2]
# but it should be [0  255 254]
print(-lax.top_k(-a, 3)[0])  # [  1   2 254]
# but it should be [0  1   2]

# On this PR using the trick suggested by Pearu
print(lax.top_k(-(a + 1), 3)[0])  # [255 254 253]
print(-lax.top_k(-(a + 1), 3)[0] - 1)  # [0 1 2]

@vfdev-5 vfdev-5 requested a review from jakevdp July 5, 2024 13:52
actual = jnp.take_along_axis(x, indices, axis=-1 if axis is NO_VALUE else axis)
self.assertArraysEqual(actual, expected)

def _assertSamePartionedArrays(self, jnp_output, np_output, axis, kth, shape):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I refactored the assertion part between testPartition and testArgpartition functions and reused this method in added tests. If there can be another way of doing the tests I'm happy to implement it

Copy link
Collaborator

@jakevdp jakevdp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Jul 8, 2024
@vfdev-5 vfdev-5 marked this pull request as ready for review July 8, 2024 05:36
@copybara-service copybara-service bot merged commit 1af93ab into jax-ml:main Jul 8, 2024
@vfdev-5 vfdev-5 deleted the fix-22137-partition-on-unsigned-dtypes branch July 8, 2024 12:06
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

pull ready Ready for copybara import and testing

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Wrong result for unsigned dtype input into jax.numpy.partition and jax.numpy.argpartition

4 participants