keras/src/backend/tensorflow/numpy.py (34 additions, 18 deletions)
@@ -2260,9 +2260,7 @@ def fix_negative_indices(i):


 def take_along_axis(x, indices, axis=None):
-    from keras.src.ops.operation_utils import (
-        compute_take_along_axis_output_shape,
-    )
+    from keras.src.ops import operation_utils

     x = convert_to_tensor(x)
     indices = convert_to_tensor(indices, "int64")
@@ -2276,28 +2274,46 @@ def take_along_axis(x, indices, axis=None):

     # Compute the static output shape as later on, all shape manipulations
     # use dynamic shapes.
-    static_output_shape = compute_take_along_axis_output_shape(
+    static_output_shape = operation_utils.compute_take_along_axis_output_shape(
         x.shape, indices.shape, axis
     )
     rank = x.ndim
     static_axis = axis
     axis = axis + rank if axis < 0 else axis

-    # Broadcast shapes to match, ensure that the axis of interest is not
-    # broadcast.
-    x_shape_original = tf.shape(x, out_type=indices.dtype)
-    indices_shape_original = tf.shape(indices, out_type=indices.dtype)
-    x_shape = tf.tensor_scatter_nd_update(x_shape_original, [[axis]], [1])
-    indices_shape = tf.tensor_scatter_nd_update(
-        indices_shape_original, [[axis]], [1]
-    )
-    broadcasted_shape = tf.broadcast_dynamic_shape(x_shape, indices_shape)
-    x_shape = tf.tensor_scatter_nd_update(
-        broadcasted_shape, [[axis]], [x_shape_original[axis]]
-    )
-    indices_shape = tf.tensor_scatter_nd_update(
-        broadcasted_shape, [[axis]], [indices_shape_original[axis]]
-    )
+    if axis >= rank:
+        raise ValueError(f"Invalid axis: {static_axis} for input rank: {rank}")
+
+    x_original_shape = shape_op(x)
+    indices_original_shape = shape_op(indices)
+
+    # Broadcast the static shapes first, but not for the `axis` dimension.
+    x_static_shape = list(x.shape)
+    indices_static_shape = list(indices.shape)
+    x_static_shape[axis] = 1
+    indices_static_shape[axis] = 1
+    broadcast_shape = operation_utils.broadcast_shapes(
+        x_static_shape, indices_static_shape
+    )
+
+    if None in broadcast_shape:
+        # Dynamic broadcast case. Note that `tf.broadcast_dynamic_shape` is
+        # not always XLA-compilable with dynamic dimensions, so we replace
+        # the `None`s with the dynamic dimensions instead. `maximum` is the
+        # correct formula only when the shapes are broadcastable; we rely on
+        # the broadcast itself to fail in the incorrect case rather than
+        # make expensive dynamic checks here.
+        broadcast_shape = [
+            tf.maximum(x_original_shape[i], indices_original_shape[i])
+            if dim is None
+            else dim
+            for i, dim in enumerate(broadcast_shape)
+        ]
+
+    x_shape = list(broadcast_shape)
+    x_shape[axis] = x_original_shape[axis]
+    indices_shape = list(broadcast_shape)
+    indices_shape[axis] = indices_original_shape[axis]
     x = tf.broadcast_to(x, x_shape)
     indices = tf.broadcast_to(indices, indices_shape)
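For context, the broadcasting this hunk implements mirrors NumPy's `take_along_axis` semantics: `x` and `indices` broadcast against each other in every dimension except `axis`, which is why both the static and dynamic shapes are pinned to 1 at `axis` before the broadcast shape is computed. A minimal NumPy sketch of those semantics (the shapes are illustrative, not taken from the PR):

    import numpy as np

    # x and indices may differ freely along `axis` (here the last axis);
    # every other dimension is broadcast: (2, 1) vs. (1, 3) -> (2, 3).
    x = np.arange(8).reshape(2, 1, 4)              # shape (2, 1, 4)
    indices = np.zeros((1, 3, 2), dtype=np.int64)  # shape (1, 3, 2)

    out = np.take_along_axis(x, indices, axis=-1)
    print(out.shape)  # (2, 3, 2): broadcast dims (2, 3), axis dim from indices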

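The `tf.maximum` trick in the dynamic branch deserves a note: for two shapes already known to be broadcastable, each pair of dimensions is either equal or contains a 1, so the broadcast dimension is simply the elementwise maximum, and the XLA-unfriendly `tf.broadcast_dynamic_shape` can be avoided. A rough standalone sketch of the idea (`broadcast_shape_via_maximum` is a hypothetical helper, not part of Keras):

    import tensorflow as tf

    def broadcast_shape_via_maximum(shape_a, shape_b):
        # Valid only for equal-rank, broadcastable shapes: per dimension,
        # max(d, d) == d and max(d, 1) == d, so the elementwise maximum
        # reproduces the broadcast shape without tf.broadcast_dynamic_shape.
        return tf.maximum(shape_a, shape_b)

    a = tf.constant([2, 1, 4], tf.int64)
    b = tf.constant([1, 3, 4], tf.int64)
    print(broadcast_shape_via_maximum(a, b))  # [2 3 4]

    # For non-broadcastable shapes such as [2, 3] vs. [2, 4] this silently
    # yields [2, 4]; as the comment in the diff says, the subsequent
    # tf.broadcast_to is left to raise the error in that case.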