Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6400,7 +6400,11 @@ def partition(a: ArrayLike, kth: int, axis: int = -1) -> Array:
kth = _canonicalize_axis(kth, arr.shape[axis])

arr = swapaxes(arr, axis, -1)
bottom = -lax.top_k(-arr, kth + 1)[0]
if dtypes.isdtype(arr.dtype, "unsigned integer"):
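# Negating an unsigned 0 wraps back to 0, so -arr alone does not reverse the ordering.
# Shifting by one first makes -(arr + 1) order-reversing for every value, zeros included.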
bottom = -lax.top_k(-(arr + 1), kth + 1)[0] - 1
else:
bottom = -lax.top_k(-arr, kth + 1)[0]
top = lax.top_k(arr, arr.shape[-1] - kth - 1)[0]
out = lax.concatenate([bottom, top], dimension=arr.ndim - 1)
return swapaxes(out, -1, axis)
Expand Down Expand Up @@ -6467,7 +6471,11 @@ def argpartition(a: ArrayLike, kth: int, axis: int = -1) -> Array:
kth = _canonicalize_axis(kth, arr.shape[axis])

arr = swapaxes(arr, axis, -1)
bottom_ind = lax.top_k(-arr, kth + 1)[1]
if dtypes.isdtype(arr.dtype, "unsigned integer"):
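# Negating an unsigned 0 wraps back to 0, so -arr alone does not reverse the ordering.
# Shifting by one first makes -(arr + 1) order-reversing for every value, zeros included.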
bottom_ind = lax.top_k(-(arr + 1), kth + 1)[1]
else:
bottom_ind = lax.top_k(-arr, kth + 1)[1]

# To avoid issues with duplicate values, we compute the top indices via a proxy
set_to_zero = lambda a, i: a.at[i].set(0)
Expand Down
76 changes: 53 additions & 23 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4254,6 +4254,20 @@ def testArgsortUnstable(self, dtype, shape, axis, descending):
actual = jnp.take_along_axis(x, indices, axis=-1 if axis is NO_VALUE else axis)
self.assertArraysEqual(actual, expected)

def _assertSamePartionedArrays(self, jnp_output, np_output, axis, kth, shape):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I refactored the assertion logic shared by testPartition and testArgpartition into this helper method and reused it in the newly added tests. If there is a better way to structure these tests, I'm happy to implement it.

# Assert that pivot point is equal:
self.assertArraysEqual(
lax.index_in_dim(jnp_output, axis=axis, index=kth),
lax.index_in_dim(np_output, axis=axis, index=kth))

# Assert remaining values are correctly partitioned:
self.assertArraysEqual(
lax.sort(lax.slice_in_dim(jnp_output, start_index=0, limit_index=kth, axis=axis), dimension=axis),
lax.sort(lax.slice_in_dim(np_output, start_index=0, limit_index=kth, axis=axis), dimension=axis))
self.assertArraysEqual(
lax.sort(lax.slice_in_dim(jnp_output, start_index=kth + 1, limit_index=shape[axis], axis=axis), dimension=axis),
lax.sort(lax.slice_in_dim(np_output, start_index=kth + 1, limit_index=shape[axis], axis=axis), dimension=axis))

@jtu.sample_product(
[{'shape': shape, 'axis': axis, 'kth': kth}
for shape in nonzerodim_shapes
Expand All @@ -4266,19 +4280,21 @@ def testPartition(self, shape, dtype, axis, kth):
arg = rng(shape, dtype)
jnp_output = jnp.partition(arg, axis=axis, kth=kth)
np_output = np.partition(arg, axis=axis, kth=kth)
self._assertSamePartionedArrays(jnp_output, np_output, axis, kth, shape)

# Assert that pivot point is equal:
self.assertArraysEqual(
lax.index_in_dim(jnp_output, axis=axis, index=kth),
lax.index_in_dim(np_output, axis=axis, index=kth))

# Assert remaining values are correctly partitioned:
self.assertArraysEqual(
lax.sort(lax.slice_in_dim(jnp_output, start_index=0, limit_index=kth, axis=axis), dimension=axis),
lax.sort(lax.slice_in_dim(np_output, start_index=0, limit_index=kth, axis=axis), dimension=axis))
self.assertArraysEqual(
lax.sort(lax.slice_in_dim(jnp_output, start_index=kth + 1, limit_index=shape[axis], axis=axis), dimension=axis),
lax.sort(lax.slice_in_dim(np_output, start_index=kth + 1, limit_index=shape[axis], axis=axis), dimension=axis))
@jtu.sample_product(
# kth sweeps every position 0..9 of the 10-element row built below.
kth=range(10),
dtype=unsigned_dtypes,
)
def testPartitionUnsignedWithZeros(self, kth, dtype):
# Regression test for https://github.com/google/jax/issues/22137
# Checks jnp.partition against np.partition on unsigned input containing zeros.
max_val = np.iinfo(dtype).max
# A single row holding two zeros and the dtype's maximum value, so both
# ends of the unsigned range are present in the input.
arg = jnp.array([[6, max_val, 0, 4, 3, 1, 0, 7, 5, 2]], dtype=dtype)
axis = -1
shape = arg.shape
jnp_output = jnp.partition(arg, axis=axis, kth=kth)
np_output = np.partition(arg, axis=axis, kth=kth)
# The helper compares the pivot element and the sorted contents on each
# side of the pivot, so element order within each side does not matter.
self._assertSamePartionedArrays(jnp_output, np_output, axis, kth, shape)

@jtu.sample_product(
[{'shape': shape, 'axis': axis, 'kth': kth}
Expand All @@ -4305,19 +4321,33 @@ def testArgpartition(self, shape, dtype, axis, kth):
getvals = jax.vmap(getvals, in_axes=ax, out_axes=ax)
jnp_values = getvals(arg, jnp_output)
np_values = getvals(arg, np_output)
self._assertSamePartionedArrays(jnp_values, np_values, axis, kth, shape)

# Assert that pivot point is equal:
self.assertArraysEqual(
lax.index_in_dim(jnp_values, axis=axis, index=kth),
lax.index_in_dim(np_values, axis=axis, index=kth))
@jtu.sample_product(
# NOTE(review): kth covers 0..9 but the row built below has 11 elements, so
# kth=10 is never exercised — confirm whether that is intended.
kth=range(10),
dtype=unsigned_dtypes,
)
def testArgpartitionUnsignedWithZeros(self, kth, dtype):
# Regression test for https://github.com/google/jax/issues/22137
# Checks jnp.argpartition against np.argpartition on unsigned input containing zeros.
max_val = np.iinfo(dtype).max
# A single row holding two zeros, a duplicated 3, and the dtype's maximum value.
arg = jnp.array([[6, max_val, 0, 4, 3, 1, 0, 7, 5, 2, 3]], dtype=dtype)
axis = -1
shape = arg.shape
jnp_output = jnp.argpartition(arg, axis=axis, kth=kth)
np_output = np.argpartition(arg, axis=axis, kth=kth)

# NOTE(review): the next seven lines reference jnp_values / np_values before
# they are assigned; they look like removed-side diff residue from the old
# testArgpartition body rather than part of this test — confirm against the real file.
# Assert remaining values are correctly partitioned:
self.assertArraysEqual(
lax.sort(lax.slice_in_dim(jnp_values, start_index=0, limit_index=kth, axis=axis), dimension=axis),
lax.sort(lax.slice_in_dim(np_values, start_index=0, limit_index=kth, axis=axis), dimension=axis))
self.assertArraysEqual(
lax.sort(lax.slice_in_dim(jnp_values, start_index=kth + 1, limit_index=shape[axis], axis=axis), dimension=axis),
lax.sort(lax.slice_in_dim(np_values, start_index=kth + 1, limit_index=shape[axis], axis=axis), dimension=axis))
# Assert that all indices are present
self.assertArraysEqual(jnp.sort(jnp_output, axis), np.sort(np_output, axis), check_dtypes=False)

# Because JAX & numpy may treat duplicates differently, we must compare values
# rather than indices.
getvals = lambda x, ind: x[ind]
# Wrap the 1-D gather in jax.vmap once per axis other than the partition axis.
for ax in range(arg.ndim):
if ax != range(arg.ndim)[axis]:
getvals = jax.vmap(getvals, in_axes=ax, out_axes=ax)
jnp_values = getvals(arg, jnp_output)
np_values = getvals(arg, np_output)
# The helper compares the pivot element and the sorted contents on each
# side of the pivot, so element order within each side does not matter.
self._assertSamePartionedArrays(jnp_values, np_values, axis, kth, shape)

@jtu.sample_product(
[dict(shifts=shifts, axis=axis)
Expand Down