Skip to content

Commit 1af93ab

Browse files
author
jax authors
committed
Merge pull request #22288 from vfdev-5:fix-22137-partition-on-unsigned-dtypes
PiperOrigin-RevId: 650201747
2 parents 740945a + 2780b61 commit 1af93ab

File tree

2 files changed

+63
-25
lines changed

2 files changed

+63
-25
lines changed

jax/_src/numpy/lax_numpy.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6399,7 +6399,11 @@ def partition(a: ArrayLike, kth: int, axis: int = -1) -> Array:
63996399
kth = _canonicalize_axis(kth, arr.shape[axis])
64006400

64016401
arr = swapaxes(arr, axis, -1)
6402-
bottom = -lax.top_k(-arr, kth + 1)[0]
6402+
if dtypes.isdtype(arr.dtype, "unsigned integer"):
6403+
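# Unsigned negation wraps around (-0 == 0), so plain -arr would rank zeros as the largest values; -(arr + 1) equals max - arr, which reverses the order correctly, and the trailing -1 restores the original values.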
6404+
bottom = -lax.top_k(-(arr + 1), kth + 1)[0] - 1
6405+
else:
6406+
bottom = -lax.top_k(-arr, kth + 1)[0]
64036407
top = lax.top_k(arr, arr.shape[-1] - kth - 1)[0]
64046408
out = lax.concatenate([bottom, top], dimension=arr.ndim - 1)
64056409
return swapaxes(out, -1, axis)
@@ -6466,7 +6470,11 @@ def argpartition(a: ArrayLike, kth: int, axis: int = -1) -> Array:
64666470
kth = _canonicalize_axis(kth, arr.shape[axis])
64676471

64686472
arr = swapaxes(arr, axis, -1)
6469-
bottom_ind = lax.top_k(-arr, kth + 1)[1]
6473+
if dtypes.isdtype(arr.dtype, "unsigned integer"):
6474+
# Unsigned negation wraps around (-0 == 0), so plain -arr would rank zeros as the largest values; -(arr + 1) equals max - arr, which reverses the order correctly.
6475+
bottom_ind = lax.top_k(-(arr + 1), kth + 1)[1]
6476+
else:
6477+
bottom_ind = lax.top_k(-arr, kth + 1)[1]
64706478

64716479
# To avoid issues with duplicate values, we compute the top indices via a proxy
64726480
set_to_zero = lambda a, i: a.at[i].set(0)

tests/lax_numpy_test.py

Lines changed: 53 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4254,6 +4254,20 @@ def testArgsortUnstable(self, dtype, shape, axis, descending):
42544254
actual = jnp.take_along_axis(x, indices, axis=-1 if axis is NO_VALUE else axis)
42554255
self.assertArraysEqual(actual, expected)
42564256

4257+
def _assertSamePartionedArrays(self, jnp_output, np_output, axis, kth, shape):
4258+
# Assert that pivot point is equal:
4259+
self.assertArraysEqual(
4260+
lax.index_in_dim(jnp_output, axis=axis, index=kth),
4261+
lax.index_in_dim(np_output, axis=axis, index=kth))
4262+
4263+
# Assert remaining values are correctly partitioned:
4264+
self.assertArraysEqual(
4265+
lax.sort(lax.slice_in_dim(jnp_output, start_index=0, limit_index=kth, axis=axis), dimension=axis),
4266+
lax.sort(lax.slice_in_dim(np_output, start_index=0, limit_index=kth, axis=axis), dimension=axis))
4267+
self.assertArraysEqual(
4268+
lax.sort(lax.slice_in_dim(jnp_output, start_index=kth + 1, limit_index=shape[axis], axis=axis), dimension=axis),
4269+
lax.sort(lax.slice_in_dim(np_output, start_index=kth + 1, limit_index=shape[axis], axis=axis), dimension=axis))
4270+
42574271
@jtu.sample_product(
42584272
[{'shape': shape, 'axis': axis, 'kth': kth}
42594273
for shape in nonzerodim_shapes
@@ -4266,19 +4280,21 @@ def testPartition(self, shape, dtype, axis, kth):
42664280
arg = rng(shape, dtype)
42674281
jnp_output = jnp.partition(arg, axis=axis, kth=kth)
42684282
np_output = np.partition(arg, axis=axis, kth=kth)
4283+
self._assertSamePartionedArrays(jnp_output, np_output, axis, kth, shape)
42694284

4270-
# Assert that pivot point is equal:
4271-
self.assertArraysEqual(
4272-
lax.index_in_dim(jnp_output, axis=axis, index=kth),
4273-
lax.index_in_dim(np_output, axis=axis, index=kth))
4274-
4275-
# Assert remaining values are correctly partitioned:
4276-
self.assertArraysEqual(
4277-
lax.sort(lax.slice_in_dim(jnp_output, start_index=0, limit_index=kth, axis=axis), dimension=axis),
4278-
lax.sort(lax.slice_in_dim(np_output, start_index=0, limit_index=kth, axis=axis), dimension=axis))
4279-
self.assertArraysEqual(
4280-
lax.sort(lax.slice_in_dim(jnp_output, start_index=kth + 1, limit_index=shape[axis], axis=axis), dimension=axis),
4281-
lax.sort(lax.slice_in_dim(np_output, start_index=kth + 1, limit_index=shape[axis], axis=axis), dimension=axis))
4285+
@jtu.sample_product(
4286+
kth=range(10),
4287+
dtype=unsigned_dtypes,
4288+
)
4289+
def testPartitionUnsignedWithZeros(self, kth, dtype):
4290+
# Regression test for https://github.com/google/jax/issues/22137 (partition of unsigned integer arrays containing zeros).
4291+
max_val = np.iinfo(dtype).max
4292+
arg = jnp.array([[6, max_val, 0, 4, 3, 1, 0, 7, 5, 2]], dtype=dtype)
4293+
axis = -1
4294+
shape = arg.shape
4295+
jnp_output = jnp.partition(arg, axis=axis, kth=kth)
4296+
np_output = np.partition(arg, axis=axis, kth=kth)
4297+
self._assertSamePartionedArrays(jnp_output, np_output, axis, kth, shape)
42824298

42834299
@jtu.sample_product(
42844300
[{'shape': shape, 'axis': axis, 'kth': kth}
@@ -4305,19 +4321,33 @@ def testArgpartition(self, shape, dtype, axis, kth):
43054321
getvals = jax.vmap(getvals, in_axes=ax, out_axes=ax)
43064322
jnp_values = getvals(arg, jnp_output)
43074323
np_values = getvals(arg, np_output)
4324+
self._assertSamePartionedArrays(jnp_values, np_values, axis, kth, shape)
43084325

4309-
# Assert that pivot point is equal:
4310-
self.assertArraysEqual(
4311-
lax.index_in_dim(jnp_values, axis=axis, index=kth),
4312-
lax.index_in_dim(np_values, axis=axis, index=kth))
4326+
@jtu.sample_product(
4327+
kth=range(10),
4328+
dtype=unsigned_dtypes,
4329+
)
4330+
def testArgpartitionUnsignedWithZeros(self, kth, dtype):
4331+
# Regression test for https://github.com/google/jax/issues/22137 (argpartition of unsigned integer arrays containing zeros).
4332+
max_val = np.iinfo(dtype).max
4333+
arg = jnp.array([[6, max_val, 0, 4, 3, 1, 0, 7, 5, 2, 3]], dtype=dtype)
4334+
axis = -1
4335+
shape = arg.shape
4336+
jnp_output = jnp.argpartition(arg, axis=axis, kth=kth)
4337+
np_output = np.argpartition(arg, axis=axis, kth=kth)
43134338

4314-
# Assert remaining values are correctly partitioned:
4315-
self.assertArraysEqual(
4316-
lax.sort(lax.slice_in_dim(jnp_values, start_index=0, limit_index=kth, axis=axis), dimension=axis),
4317-
lax.sort(lax.slice_in_dim(np_values, start_index=0, limit_index=kth, axis=axis), dimension=axis))
4318-
self.assertArraysEqual(
4319-
lax.sort(lax.slice_in_dim(jnp_values, start_index=kth + 1, limit_index=shape[axis], axis=axis), dimension=axis),
4320-
lax.sort(lax.slice_in_dim(np_values, start_index=kth + 1, limit_index=shape[axis], axis=axis), dimension=axis))
4339+
# Assert that all indices are present
4340+
self.assertArraysEqual(jnp.sort(jnp_output, axis), np.sort(np_output, axis), check_dtypes=False)
4341+
4342+
# Because JAX & numpy may treat duplicates differently, we must compare values
4343+
# rather than indices.
4344+
getvals = lambda x, ind: x[ind]
4345+
for ax in range(arg.ndim):
4346+
if ax != range(arg.ndim)[axis]:
4347+
getvals = jax.vmap(getvals, in_axes=ax, out_axes=ax)
4348+
jnp_values = getvals(arg, jnp_output)
4349+
np_values = getvals(arg, np_output)
4350+
self._assertSamePartionedArrays(jnp_values, np_values, axis, kth, shape)
43214351

43224352
@jtu.sample_product(
43234353
[dict(shifts=shifts, axis=axis)

0 commit comments

Comments
 (0)