@@ -4254,6 +4254,20 @@ def testArgsortUnstable(self, dtype, shape, axis, descending):
42544254 actual = jnp .take_along_axis (x , indices , axis = - 1 if axis is NO_VALUE else axis )
42554255 self .assertArraysEqual (actual , expected )
42564256
4257+ def _assertSamePartionedArrays (self , jnp_output , np_output , axis , kth , shape ):
4258+ # Assert that pivot point is equal:
4259+ self .assertArraysEqual (
4260+ lax .index_in_dim (jnp_output , axis = axis , index = kth ),
4261+ lax .index_in_dim (np_output , axis = axis , index = kth ))
4262+
4263+ # Assert remaining values are correctly partitioned:
4264+ self .assertArraysEqual (
4265+ lax .sort (lax .slice_in_dim (jnp_output , start_index = 0 , limit_index = kth , axis = axis ), dimension = axis ),
4266+ lax .sort (lax .slice_in_dim (np_output , start_index = 0 , limit_index = kth , axis = axis ), dimension = axis ))
4267+ self .assertArraysEqual (
4268+ lax .sort (lax .slice_in_dim (jnp_output , start_index = kth + 1 , limit_index = shape [axis ], axis = axis ), dimension = axis ),
4269+ lax .sort (lax .slice_in_dim (np_output , start_index = kth + 1 , limit_index = shape [axis ], axis = axis ), dimension = axis ))
4270+
42574271 @jtu .sample_product (
42584272 [{'shape' : shape , 'axis' : axis , 'kth' : kth }
42594273 for shape in nonzerodim_shapes
@@ -4266,19 +4280,21 @@ def testPartition(self, shape, dtype, axis, kth):
42664280 arg = rng (shape , dtype )
42674281 jnp_output = jnp .partition (arg , axis = axis , kth = kth )
42684282 np_output = np .partition (arg , axis = axis , kth = kth )
4283+ self ._assertSamePartionedArrays (jnp_output , np_output , axis , kth , shape )
42694284
4270- # Assert that pivot point is equal:
4271- self .assertArraysEqual (
4272- lax .index_in_dim (jnp_output , axis = axis , index = kth ),
4273- lax .index_in_dim (np_output , axis = axis , index = kth ))
4274-
4275- # Assert remaining values are correctly partitioned:
4276- self .assertArraysEqual (
4277- lax .sort (lax .slice_in_dim (jnp_output , start_index = 0 , limit_index = kth , axis = axis ), dimension = axis ),
4278- lax .sort (lax .slice_in_dim (np_output , start_index = 0 , limit_index = kth , axis = axis ), dimension = axis ))
4279- self .assertArraysEqual (
4280- lax .sort (lax .slice_in_dim (jnp_output , start_index = kth + 1 , limit_index = shape [axis ], axis = axis ), dimension = axis ),
4281- lax .sort (lax .slice_in_dim (np_output , start_index = kth + 1 , limit_index = shape [axis ], axis = axis ), dimension = axis ))
4285+ @jtu .sample_product (
4286+ kth = range (10 ),
4287+ dtype = unsigned_dtypes ,
4288+ )
4289+ def testPartitionUnsignedWithZeros (self , kth , dtype ):
4290+ # https://github.com/google/jax/issues/22137
4291+ max_val = np .iinfo (dtype ).max
4292+ arg = jnp .array ([[6 , max_val , 0 , 4 , 3 , 1 , 0 , 7 , 5 , 2 ]], dtype = dtype )
4293+ axis = - 1
4294+ shape = arg .shape
4295+ jnp_output = jnp .partition (arg , axis = axis , kth = kth )
4296+ np_output = np .partition (arg , axis = axis , kth = kth )
4297+ self ._assertSamePartionedArrays (jnp_output , np_output , axis , kth , shape )
42824298
42834299 @jtu .sample_product (
42844300 [{'shape' : shape , 'axis' : axis , 'kth' : kth }
@@ -4305,19 +4321,33 @@ def testArgpartition(self, shape, dtype, axis, kth):
43054321 getvals = jax .vmap (getvals , in_axes = ax , out_axes = ax )
43064322 jnp_values = getvals (arg , jnp_output )
43074323 np_values = getvals (arg , np_output )
4324+ self ._assertSamePartionedArrays (jnp_values , np_values , axis , kth , shape )
43084325
4309- # Assert that pivot point is equal:
4310- self .assertArraysEqual (
4311- lax .index_in_dim (jnp_values , axis = axis , index = kth ),
4312- lax .index_in_dim (np_values , axis = axis , index = kth ))
4326+ @jtu .sample_product (
4327+ kth = range (10 ),
4328+ dtype = unsigned_dtypes ,
4329+ )
4330+ def testArgpartitionUnsignedWithZeros (self , kth , dtype ):
4331+ # https://github.com/google/jax/issues/22137
4332+ max_val = np .iinfo (dtype ).max
4333+ arg = jnp .array ([[6 , max_val , 0 , 4 , 3 , 1 , 0 , 7 , 5 , 2 , 3 ]], dtype = dtype )
4334+ axis = - 1
4335+ shape = arg .shape
4336+ jnp_output = jnp .argpartition (arg , axis = axis , kth = kth )
4337+ np_output = np .argpartition (arg , axis = axis , kth = kth )
43134338
4314- # Assert remaining values are correctly partitioned:
4315- self .assertArraysEqual (
4316- lax .sort (lax .slice_in_dim (jnp_values , start_index = 0 , limit_index = kth , axis = axis ), dimension = axis ),
4317- lax .sort (lax .slice_in_dim (np_values , start_index = 0 , limit_index = kth , axis = axis ), dimension = axis ))
4318- self .assertArraysEqual (
4319- lax .sort (lax .slice_in_dim (jnp_values , start_index = kth + 1 , limit_index = shape [axis ], axis = axis ), dimension = axis ),
4320- lax .sort (lax .slice_in_dim (np_values , start_index = kth + 1 , limit_index = shape [axis ], axis = axis ), dimension = axis ))
4339+ # Assert that all indices are present
4340+ self .assertArraysEqual (jnp .sort (jnp_output , axis ), np .sort (np_output , axis ), check_dtypes = False )
4341+
4342+ # Because JAX & numpy may treat duplicates differently, we must compare values
4343+ # rather than indices.
4344+ getvals = lambda x , ind : x [ind ]
4345+ for ax in range (arg .ndim ):
4346+ if ax != range (arg .ndim )[axis ]:
4347+ getvals = jax .vmap (getvals , in_axes = ax , out_axes = ax )
4348+ jnp_values = getvals (arg , jnp_output )
4349+ np_values = getvals (arg , np_output )
4350+ self ._assertSamePartionedArrays (jnp_values , np_values , axis , kth , shape )
43214351
43224352 @jtu .sample_product (
43234353 [dict (shifts = shifts , axis = axis )
0 commit comments