Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion keras/src/backend/tensorflow/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2330,7 +2330,11 @@ def take_along_axis(x, indices, axis=None):
indices = tf.broadcast_to(indices, indices_shape)

# Correct the indices using "fill" mode which is the same as in jax
indices = tf.where(indices < 0, indices + x_shape[static_axis], indices)
indices = tf.where(
indices < 0,
indices + tf.cast(x_shape[static_axis], dtype=indices.dtype),
indices,
)

x = swapaxes(x, static_axis, -1)
indices = swapaxes(indices, static_axis, -1)
Expand Down
10 changes: 6 additions & 4 deletions keras/src/ops/numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8411,14 +8411,16 @@ def test_take(self, dtype):
expected_dtype,
)

@parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
def test_take_along_axis(self, dtype):
@parameterized.named_parameters(
named_product(dtype=ALL_DTYPES, indices_dtype=INT_DTYPES)
)
def test_take_along_axis(self, dtype, indices_dtype):
import jax.numpy as jnp

x = knp.ones((1,), dtype=dtype)
indices = knp.zeros((1,), dtype="int32")
indices = knp.zeros((1,), dtype=indices_dtype)
x_jax = jnp.ones((1,), dtype=dtype)
indices_jax = jnp.zeros((1,), dtype="int32")
indices_jax = jnp.zeros((1,), dtype=indices_dtype)
expected_dtype = standardize_dtype(
jnp.take_along_axis(x_jax, indices_jax, 0).dtype
)
Expand Down