Skip to content

Commit 2780b61

Browse files
committed
Fixed issue in jnp.argpartition and jnp.partition
1 parent 061ccd4 commit 2780b61

File tree

2 files changed

+63
-25
lines changed

2 files changed

+63
-25
lines changed

jax/_src/numpy/lax_numpy.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6400,7 +6400,11 @@ def partition(a: ArrayLike, kth: int, axis: int = -1) -> Array:
64006400
kth = _canonicalize_axis(kth, arr.shape[axis])
64016401

64026402
arr = swapaxes(arr, axis, -1)
6403-
bottom = -lax.top_k(-arr, kth + 1)[0]
6403+
if dtypes.isdtype(arr.dtype, "unsigned integer"):
6404+
# Here, we apply a trick to handle correctly 0 values for unsigned integers
6405+
bottom = -lax.top_k(-(arr + 1), kth + 1)[0] - 1
6406+
else:
6407+
bottom = -lax.top_k(-arr, kth + 1)[0]
64046408
top = lax.top_k(arr, arr.shape[-1] - kth - 1)[0]
64056409
out = lax.concatenate([bottom, top], dimension=arr.ndim - 1)
64066410
return swapaxes(out, -1, axis)
@@ -6467,7 +6471,11 @@ def argpartition(a: ArrayLike, kth: int, axis: int = -1) -> Array:
64676471
kth = _canonicalize_axis(kth, arr.shape[axis])
64686472

64696473
arr = swapaxes(arr, axis, -1)
6470-
bottom_ind = lax.top_k(-arr, kth + 1)[1]
6474+
if dtypes.isdtype(arr.dtype, "unsigned integer"):
6475+
# Here, we apply a trick to handle correctly 0 values for unsigned integers
6476+
bottom_ind = lax.top_k(-(arr + 1), kth + 1)[1]
6477+
else:
6478+
bottom_ind = lax.top_k(-arr, kth + 1)[1]
64716479

64726480
# To avoid issues with duplicate values, we compute the top indices via a proxy
64736481
set_to_zero = lambda a, i: a.at[i].set(0)

tests/lax_numpy_test.py

Lines changed: 53 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4254,6 +4254,20 @@ def testArgsortUnstable(self, dtype, shape, axis, descending):
42544254
actual = jnp.take_along_axis(x, indices, axis=-1 if axis is NO_VALUE else axis)
42554255
self.assertArraysEqual(actual, expected)
42564256

4257+
def _assertSamePartionedArrays(self, jnp_output, np_output, axis, kth, shape):
4258+
# Assert that pivot point is equal:
4259+
self.assertArraysEqual(
4260+
lax.index_in_dim(jnp_output, axis=axis, index=kth),
4261+
lax.index_in_dim(np_output, axis=axis, index=kth))
4262+
4263+
# Assert remaining values are correctly partitioned:
4264+
self.assertArraysEqual(
4265+
lax.sort(lax.slice_in_dim(jnp_output, start_index=0, limit_index=kth, axis=axis), dimension=axis),
4266+
lax.sort(lax.slice_in_dim(np_output, start_index=0, limit_index=kth, axis=axis), dimension=axis))
4267+
self.assertArraysEqual(
4268+
lax.sort(lax.slice_in_dim(jnp_output, start_index=kth + 1, limit_index=shape[axis], axis=axis), dimension=axis),
4269+
lax.sort(lax.slice_in_dim(np_output, start_index=kth + 1, limit_index=shape[axis], axis=axis), dimension=axis))
4270+
42574271
@jtu.sample_product(
42584272
[{'shape': shape, 'axis': axis, 'kth': kth}
42594273
for shape in nonzerodim_shapes
@@ -4266,19 +4280,21 @@ def testPartition(self, shape, dtype, axis, kth):
42664280
arg = rng(shape, dtype)
42674281
jnp_output = jnp.partition(arg, axis=axis, kth=kth)
42684282
np_output = np.partition(arg, axis=axis, kth=kth)
4283+
self._assertSamePartionedArrays(jnp_output, np_output, axis, kth, shape)
42694284

4270-
# Assert that pivot point is equal:
4271-
self.assertArraysEqual(
4272-
lax.index_in_dim(jnp_output, axis=axis, index=kth),
4273-
lax.index_in_dim(np_output, axis=axis, index=kth))
4274-
4275-
# Assert remaining values are correctly partitioned:
4276-
self.assertArraysEqual(
4277-
lax.sort(lax.slice_in_dim(jnp_output, start_index=0, limit_index=kth, axis=axis), dimension=axis),
4278-
lax.sort(lax.slice_in_dim(np_output, start_index=0, limit_index=kth, axis=axis), dimension=axis))
4279-
self.assertArraysEqual(
4280-
lax.sort(lax.slice_in_dim(jnp_output, start_index=kth + 1, limit_index=shape[axis], axis=axis), dimension=axis),
4281-
lax.sort(lax.slice_in_dim(np_output, start_index=kth + 1, limit_index=shape[axis], axis=axis), dimension=axis))
4285+
@jtu.sample_product(
4286+
kth=range(10),
4287+
dtype=unsigned_dtypes,
4288+
)
4289+
def testPartitionUnsignedWithZeros(self, kth, dtype):
4290+
# https://github.com/google/jax/issues/22137
4291+
max_val = np.iinfo(dtype).max
4292+
arg = jnp.array([[6, max_val, 0, 4, 3, 1, 0, 7, 5, 2]], dtype=dtype)
4293+
axis = -1
4294+
shape = arg.shape
4295+
jnp_output = jnp.partition(arg, axis=axis, kth=kth)
4296+
np_output = np.partition(arg, axis=axis, kth=kth)
4297+
self._assertSamePartionedArrays(jnp_output, np_output, axis, kth, shape)
42824298

42834299
@jtu.sample_product(
42844300
[{'shape': shape, 'axis': axis, 'kth': kth}
@@ -4305,19 +4321,33 @@ def testArgpartition(self, shape, dtype, axis, kth):
43054321
getvals = jax.vmap(getvals, in_axes=ax, out_axes=ax)
43064322
jnp_values = getvals(arg, jnp_output)
43074323
np_values = getvals(arg, np_output)
4324+
self._assertSamePartionedArrays(jnp_values, np_values, axis, kth, shape)
43084325

4309-
# Assert that pivot point is equal:
4310-
self.assertArraysEqual(
4311-
lax.index_in_dim(jnp_values, axis=axis, index=kth),
4312-
lax.index_in_dim(np_values, axis=axis, index=kth))
4326+
@jtu.sample_product(
4327+
kth=range(10),
4328+
dtype=unsigned_dtypes,
4329+
)
4330+
def testArgpartitionUnsignedWithZeros(self, kth, dtype):
4331+
# https://github.com/google/jax/issues/22137
4332+
max_val = np.iinfo(dtype).max
4333+
arg = jnp.array([[6, max_val, 0, 4, 3, 1, 0, 7, 5, 2, 3]], dtype=dtype)
4334+
axis = -1
4335+
shape = arg.shape
4336+
jnp_output = jnp.argpartition(arg, axis=axis, kth=kth)
4337+
np_output = np.argpartition(arg, axis=axis, kth=kth)
43134338

4314-
# Assert remaining values are correctly partitioned:
4315-
self.assertArraysEqual(
4316-
lax.sort(lax.slice_in_dim(jnp_values, start_index=0, limit_index=kth, axis=axis), dimension=axis),
4317-
lax.sort(lax.slice_in_dim(np_values, start_index=0, limit_index=kth, axis=axis), dimension=axis))
4318-
self.assertArraysEqual(
4319-
lax.sort(lax.slice_in_dim(jnp_values, start_index=kth + 1, limit_index=shape[axis], axis=axis), dimension=axis),
4320-
lax.sort(lax.slice_in_dim(np_values, start_index=kth + 1, limit_index=shape[axis], axis=axis), dimension=axis))
4339+
# Assert that all indices are present
4340+
self.assertArraysEqual(jnp.sort(jnp_output, axis), np.sort(np_output, axis), check_dtypes=False)
4341+
4342+
# Because JAX & numpy may treat duplicates differently, we must compare values
4343+
# rather than indices.
4344+
getvals = lambda x, ind: x[ind]
4345+
for ax in range(arg.ndim):
4346+
if ax != range(arg.ndim)[axis]:
4347+
getvals = jax.vmap(getvals, in_axes=ax, out_axes=ax)
4348+
jnp_values = getvals(arg, jnp_output)
4349+
np_values = getvals(arg, np_output)
4350+
self._assertSamePartionedArrays(jnp_values, np_values, axis, kth, shape)
43214351

43224352
@jtu.sample_product(
43234353
[dict(shifts=shifts, axis=axis)

0 commit comments

Comments
 (0)