Skip to content

Commit eecd34f

Browse files
authored
Fix: keras.ops.quantile works with tf graph execution (#21782)
A `deque` failing to be converted to a tensor in the transpose call in `keras.ops.quantile` caused errors when running in tf graph contexts. By constructing a list from the deque before passing it to the `transpose` call we avoid this error.
1 parent c2bc6cf commit eecd34f

File tree

2 files changed

+19
-1
lines changed

2 files changed

+19
-1
lines changed

keras/src/backend/tensorflow/numpy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2301,7 +2301,7 @@ def _get_indices(method):
23012301
return gathered_y
23022302
perm = collections.deque(range(ndims))
23032303
perm.rotate(shift_value_static)
2304-
return tf.transpose(a=gathered_y, perm=perm)
2304+
return tf.transpose(a=gathered_y, perm=list(perm))
23052305

23062306

23072307
def quantile(x, q, axis=None, method="linear", keepdims=False):

keras/src/ops/numpy_test.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3325,6 +3325,24 @@ def test_quantile(self):
33253325
np.quantile(x, q, axis=1, method=method),
33263326
)
33273327

3328+
@pytest.mark.skipif(
3329+
backend.backend() != "tensorflow",
3330+
reason="Only test tensorflow backend",
3331+
)
3332+
def test_quantile_in_tf_function(self):
3333+
import tensorflow as tf
3334+
3335+
x = knp.array([[1, 2, 3], [4, 5, 6]])
3336+
q = [0.5]
3337+
expected_output = np.array([[2, 5]])
3338+
3339+
@tf.function
3340+
def run_quantile(x, q, axis):
3341+
return knp.quantile(x, q, axis=axis)
3342+
3343+
result = run_quantile(x, q, axis=1)
3344+
self.assertAllClose(result, expected_output)
3345+
33283346
def test_take(self):
33293347
x = np.arange(24).reshape([1, 2, 3, 4])
33303348
indices = np.array([0, 1])

0 commit comments

Comments
 (0)