
argmax returns incorrect result for input containing -0.0 (Keras using TensorFlow backend) #20350

@LilyDong0127

Description

When keras.backend.argmax is applied to a float32 array that contains -0.0, it returns the wrong index. For the input [-1.0, -0.0, 1.401298464324817e-45], it returns 1 (the index of -0.0) as the position of the maximum, although the actual maximum, 1.401298464324817e-45, is at index 2.

The same incorrect index (1) is returned by TensorFlow (tf.math.argmax) and JAX (jnp.argmax). PyTorch (torch.argmax), by contrast, returns the expected index 2.
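For context on the inputs, the snippet below (standard library only; float32_bits is a hypothetical helper name used for illustration) prints the IEEE 754 binary32 bit pattern of each value. It shows that 1.401298464324817e-45 is stored as 0x00000001, i.e. the smallest positive float32 subnormal (2**-149), and that -0.0 differs from 0.0 only in the sign bit while still comparing equal to it.

```python
import struct


def float32_bits(x: float) -> int:
    """Hypothetical helper: IEEE 754 binary32 bit pattern of x as an unsigned int."""
    # Pack as a little-endian float32, then reinterpret the same 4 bytes as uint32.
    return struct.unpack("<I", struct.pack("<f", x))[0]


values = [-1.0, -0.0, 1.401298464324817e-45]

# Expected patterns: 0xbf800000, 0x80000000 (sign bit only), 0x00000001.
for v in values:
    print(f"{v!r:>24}  bits=0x{float32_bits(v):08x}")

# -0.0 and +0.0 compare equal under IEEE 754 ...
print(-0.0 == 0.0)       # True
# ... but the value at index 2 is strictly greater than -0.0.
print(-0.0 < values[2])  # True
```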

Expected Behavior:
keras.backend.argmax should return 2, as the value at index 2 (1.401298464324817e-45) is greater than both -1.0 and -0.0.
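A minimal sketch of the expected semantics, assuming plain IEEE 754 comparisons and first-occurrence tie-breaking (the convention NumPy documents for np.argmax). reference_argmax is a hypothetical helper for illustration, not part of Keras, TensorFlow, or JAX.

```python
from typing import Sequence


def reference_argmax(xs: Sequence[float]) -> int:
    """Hypothetical reference argmax: plain IEEE 754 comparisons, ties go to
    the first index."""
    if len(xs) == 0:
        raise ValueError("argmax of an empty sequence")
    best = 0
    for i in range(1, len(xs)):
        # Strict '>' keeps the earlier index when two values compare equal,
        # e.g. -0.0 and 0.0.
        if xs[i] > xs[best]:
            best = i
    return best


# Python floats are float64, where 1.401298464324817e-45 is a normal number,
# so it is compared at its true (positive) value.
print(reference_argmax([-1.0, -0.0, 1.401298464324817e-45]))  # 2
print(reference_argmax([-1.0, -0.0, 0.0]))                    # 1 (tie)
```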

```python
import numpy as np
import torch
import tensorflow as tf
import jax.numpy as jnp
from tensorflow import keras

def test_argmax():
    # Input: -1.0, negative zero, and 1.401298464324817e-45 (the smallest
    # positive float32 subnormal, 2**-149). The expected argmax is index 2.
    input_data = np.array([-1.0, -0.0, 1.401298464324817e-45], dtype=np.float32)

    # PyTorch argmax
    pytorch_result = torch.argmax(torch.tensor(input_data, dtype=torch.float32)).item()
    print(f"PyTorch argmax result: {pytorch_result}")

    # TensorFlow argmax
    tensorflow_result = tf.math.argmax(input_data).numpy()
    print(f"TensorFlow argmax result: {tensorflow_result}")

    # Keras argmax (with the TensorFlow backend this dispatches to TensorFlow,
    # so the result is expected to match tf.math.argmax)
    keras_result = keras.backend.argmax(input_data).numpy()
    print(f"Keras argmax result: {keras_result}")

    # JAX argmax
    jax_result = jnp.argmax(input_data)
    print(f"JAX argmax result: {jax_result}")

if __name__ == "__main__":
    test_argmax()
```

Output:

```
PyTorch argmax result: 2
TensorFlow argmax result: 1
Keras argmax result: 1
JAX argmax result: 1
```
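One plausible explanation, not confirmed in this issue: 1.401298464324817e-45 is a float32 subnormal, and CPU kernels that run with flush-to-zero / denormals-are-zero enabled treat it as 0.0. Because IEEE 754 considers -0.0 and 0.0 equal, a first-occurrence argmax then reports index 1, which would make the subnormal, rather than -0.0, the trigger. PyTorch does not flush subnormals by default (torch.set_flush_denormal turns that on), which would be consistent with its result of 2. The sketch below models this hypothesis in plain Python; flush_subnormals and first_argmax are hypothetical helpers, not TensorFlow or JAX APIs.

```python
import math
import struct

# Smallest positive normal float32 (2**-126); nonzero float32 values with a
# smaller magnitude are subnormal.
FLT32_MIN_NORMAL = 2.0 ** -126


def to_float32(x: float) -> float:
    """Round a Python float to the nearest float32 value via a struct round trip."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def flush_subnormals(xs: list[float]) -> list[float]:
    """Hypothetical model of flush-to-zero / denormals-are-zero (not a TF or
    JAX API): float32 subnormals become a zero of the same sign."""
    out = []
    for x in map(to_float32, xs):
        if x != 0.0 and abs(x) < FLT32_MIN_NORMAL:
            x = math.copysign(0.0, x)
        out.append(x)
    return out


def first_argmax(xs: list[float]) -> int:
    """Hypothetical helper: index of the maximum, ties resolved to the first index."""
    # max() returns the first maximal item it encounters.
    return max(range(len(xs)), key=xs.__getitem__)


data = [-1.0, -0.0, 1.401298464324817e-45]
flushed = flush_subnormals(data)

print(flushed)                # [-1.0, -0.0, 0.0]
print(first_argmax(data))     # 2: true values, matches PyTorch
print(first_argmax(flushed))  # 1: -0.0 ties with 0.0, first index wins
```

If this hypothesis holds, replacing 1.401298464324817e-45 in the reproduction with a normal float32 such as 1e-30 should make all four libraries return 2.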
