Closed
Labels: stat:awaiting keras-eng (Awaiting response from Keras engineer), type:Bug
Description:
When keras.backend.argmax is called on an input array that contains -0.0, it returns the wrong index. Specifically, it returns 1 (the index of -0.0) as the position of the maximum, while the actual maximum, 1.401298464324817e-45 (the smallest positive float32 value), is at index 2.
TensorFlow (tf.math.argmax) and JAX (jnp.argmax) return the same wrong index 1, presumably because they share similar backend logic for argmax. PyTorch (torch.argmax) correctly returns the expected index 2.
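As a framework-independent reference, the expected ordering can be checked with NumPy alone. This is a minimal sketch that only assumes NumPy is installed. It confirms that 1.401298464324817e-45 rounds to 2**-149, the smallest positive float32, which is a subnormal value (below np.finfo(np.float32).tiny, the smallest normal float32), and that under ordinary IEEE 754 comparison it is greater than -0.0.

```python
import numpy as np

# Same input as in the report.
x = np.array([-1.0, -0.0, 1.401298464324817e-45], dtype=np.float32)

# The value at index 2 is the smallest positive float32 (2**-149) ...
smallest_positive = np.nextafter(np.float32(0), np.float32(1))
assert x[2] == smallest_positive

# ... and it is subnormal: positive but below the smallest normal float32.
assert 0 < x[2] < np.finfo(np.float32).tiny

# With IEEE 754 comparison, -0.0 < 2**-149, so index 2 holds the maximum.
print(bool(x[1] < x[2]))  # True
print(int(np.argmax(x)))  # 2
```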
Expected Behavior:
keras.backend.argmax should return 2, as the value at index 2 (1.401298464324817e-45) is greater than both -1.0 and -0.0.
```python
import numpy as np
import torch
import tensorflow as tf
import jax.numpy as jnp
from tensorflow import keras


def test_argmax():
    # Input data
    input_data = np.array([-1.0, -0.0, 1.401298464324817e-45], dtype=np.float32)

    # PyTorch argmax
    pytorch_result = torch.argmax(torch.tensor(input_data, dtype=torch.float32)).item()
    print(f"PyTorch argmax result: {pytorch_result}")

    # TensorFlow argmax
    tensorflow_result = tf.math.argmax(input_data).numpy()
    print(f"TensorFlow argmax result: {tensorflow_result}")

    # Keras argmax (Keras internally uses TensorFlow, so should be the same)
    keras_result = keras.backend.argmax(input_data).numpy()
    print(f"Keras argmax result: {keras_result}")

    # JAX argmax
    jax_result = jnp.argmax(input_data)
    print(f"JAX argmax result: {jax_result}")


if __name__ == "__main__":
    test_argmax()
```
Output:

```
PyTorch argmax result: 2
TensorFlow argmax result: 1
Keras argmax result: 1
JAX argmax result: 1
```
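One possible explanation, which is an assumption and is not confirmed anywhere in this issue, is that the TensorFlow and JAX CPU kernels run with subnormals flushed to zero (flush-to-zero / denormals-are-zero). In that mode 2**-149 behaves as 0.0, which compares equal to -0.0, so the maximum is tied and argmax returns the first occurrence, index 1. The sketch below emulates that mode in NumPy; flush_subnormals is a hypothetical helper written for illustration, not part of any of these libraries.

```python
import numpy as np


def flush_subnormals(a: np.ndarray) -> np.ndarray:
    """Hypothetical helper: emulate flush-to-zero by replacing every
    subnormal element with a zero of the same sign."""
    smallest_normal = np.finfo(a.dtype).tiny
    signed_zero = np.copysign(np.zeros_like(a), a)
    # Values whose magnitude is below the smallest normal number are
    # either zero or subnormal; both become a signed zero.
    return np.where(np.abs(a) < smallest_normal, signed_zero, a)


x = np.array([-1.0, -0.0, 1.401298464324817e-45], dtype=np.float32)
flushed = flush_subnormals(x)

# Without flushing, the subnormal at index 2 is the strict maximum.
print(int(np.argmax(x)))        # 2
# After flushing, -0.0 == 0.0, so argmax returns the first tied index.
print(int(np.argmax(flushed)))  # 1
```

If this hypothesis holds, the reported result reflects the backend's floating-point mode rather than the argmax reduction itself, which would also fit PyTorch returning 2.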
harshaljanjani