Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@
from keras.src.ops.numpy import heaviside as heaviside
from keras.src.ops.numpy import histogram as histogram
from keras.src.ops.numpy import hstack as hstack
from keras.src.ops.numpy import hypot as hypot
from keras.src.ops.numpy import identity as identity
from keras.src.ops.numpy import imag as imag
from keras.src.ops.numpy import inner as inner
Expand Down
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
from keras.src.ops.numpy import heaviside as heaviside
from keras.src.ops.numpy import histogram as histogram
from keras.src.ops.numpy import hstack as hstack
from keras.src.ops.numpy import hypot as hypot
from keras.src.ops.numpy import identity as identity
from keras.src.ops.numpy import imag as imag
from keras.src.ops.numpy import inner as inner
Expand Down
1 change: 1 addition & 0 deletions keras/api/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@
from keras.src.ops.numpy import heaviside as heaviside
from keras.src.ops.numpy import histogram as histogram
from keras.src.ops.numpy import hstack as hstack
from keras.src.ops.numpy import hypot as hypot
from keras.src.ops.numpy import identity as identity
from keras.src.ops.numpy import imag as imag
from keras.src.ops.numpy import inner as inner
Expand Down
1 change: 1 addition & 0 deletions keras/api/ops/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
from keras.src.ops.numpy import heaviside as heaviside
from keras.src.ops.numpy import histogram as histogram
from keras.src.ops.numpy import hstack as hstack
from keras.src.ops.numpy import hypot as hypot
from keras.src.ops.numpy import identity as identity
from keras.src.ops.numpy import imag as imag
from keras.src.ops.numpy import inner as inner
Expand Down
6 changes: 6 additions & 0 deletions keras/src/backend/jax/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@ def heaviside(x1, x2):
return jnp.heaviside(x1, x2)


def hypot(x1, x2):
x1 = convert_to_tensor(x1)
x2 = convert_to_tensor(x2)
return jnp.hypot(x1, x2)


def kaiser(x, beta):
x = convert_to_tensor(x)
return jnp.kaiser(x, beta)
Expand Down
13 changes: 13 additions & 0 deletions keras/src/backend/numpy/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,6 +665,19 @@ def hstack(xs):
return np.hstack(xs)


def hypot(x1, x2):
x1 = convert_to_tensor(x1)
x2 = convert_to_tensor(x2)

dtype = dtypes.result_type(x1.dtype, x2.dtype)
if dtype in ["int8", "int16", "int32", "uint8", "uint16", "uint32"]:
dtype = config.floatx()
elif dtype in ["int64"]:
dtype = "float64"

return np.hypot(x1, x2).astype(dtype)


def identity(n, dtype=None):
dtype = dtype or config.floatx()
return np.identity(n, dtype=dtype)
Expand Down
4 changes: 4 additions & 0 deletions keras/src/backend/openvino/excluded_concrete_tests.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ NumpyDtypeTest::test_blackman
NumpyDtypeTest::test_hamming
NumpyDtypeTest::test_hanning
NumpyDtypeTest::test_heaviside
NumpyDtypeTest::test_hypot
NumpyDtypeTest::test_kaiser
NumpyDtypeTest::test_bitwise
NumpyDtypeTest::test_cbrt
Expand Down Expand Up @@ -142,6 +143,7 @@ NumpyTwoInputOpsCorrectnessTest::test_digitize
NumpyTwoInputOpsCorrectnessTest::test_divide_no_nan
NumpyTwoInputOpsCorrectnessTest::test_einsum
NumpyTwoInputOpsCorrectnessTest::test_heaviside
NumpyTwoInputOpsCorrectnessTest::test_hypot
NumpyTwoInputOpsCorrectnessTest::test_inner
NumpyTwoInputOpsCorrectnessTest::test_isin
NumpyTwoInputOpsCorrectnessTest::test_linspace
Expand All @@ -164,8 +166,10 @@ NumpyOneInputOpsStaticShapeTest::test_cbrt
NumpyOneInputOpsStaticShapeTest::test_isneginf
NumpyOneInputOpsStaticShapeTest::test_isposinf
NumpyTwoInputOpsDynamicShapeTest::test_heaviside
NumpyTwoInputOpsDynamicShapeTest::test_hypot
NumpyTwoInputOpsDynamicShapeTest::test_isin
NumpyTwoInputOpsStaticShapeTest::test_heaviside
NumpyTwoInputOpsStaticShapeTest::test_hypot
NumpyTwoInputOpsStaticShapeTest::test_isin
CoreOpsBehaviorTests::test_associative_scan_invalid_arguments
CoreOpsBehaviorTests::test_scan_invalid_arguments
Expand Down
4 changes: 4 additions & 0 deletions keras/src/backend/openvino/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -902,6 +902,10 @@ def hstack(xs):
return OpenVINOKerasTensor(ov_opset.concat(elems, axis).output(0))


def hypot(x1, x2):
raise NotImplementedError("`hypot` is not supported with openvino backend")


def identity(n, dtype=None):
n = get_ov_output(n)
dtype = Type.f32 if dtype is None else dtype
Expand Down
22 changes: 22 additions & 0 deletions keras/src/backend/tensorflow/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1559,6 +1559,28 @@ def hstack(xs):
return tf.concat(xs, axis=1)


def hypot(x1, x2):
x1 = convert_to_tensor(x1)
x2 = convert_to_tensor(x2)

dtype = dtypes.result_type(x1.dtype, x2.dtype)
if dtype in ["int8", "int16", "int32", "uint8", "uint16", "uint32"]:
dtype = config.floatx()
elif dtype in ["int64"]:
dtype = "float64"

x1 = tf.cast(x1, dtype)
x2 = tf.cast(x2, dtype)

x1_abs = tf.abs(x1)
x2_abs = tf.abs(x2)
max_val = tf.maximum(x1_abs, x2_abs)
min_val = tf.minimum(x1_abs, x2_abs)

ratio = tf.math.divide_no_nan(min_val, max_val)
return max_val * tf.sqrt(1.0 + tf.square(ratio))


def identity(n, dtype=None):
return eye(N=n, M=n, dtype=dtype)

Expand Down
16 changes: 16 additions & 0 deletions keras/src/backend/torch/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -854,6 +854,22 @@ def hstack(xs):
return torch.hstack(xs)


def hypot(x1, x2):
x1 = convert_to_tensor(x1)
x2 = convert_to_tensor(x2)

dtype = dtypes.result_type(x1.dtype, x2.dtype)
if dtype in ["int8", "int16", "int32", "uint8", "uint16", "uint32"]:
dtype = config.floatx()
elif dtype == "int64":
dtype = "float64"

x1 = cast(x1, dtype)
x2 = cast(x2, dtype)

return torch.hypot(x1, x2)


def identity(n, dtype=None):
dtype = to_torch_dtype(dtype or config.floatx())

Expand Down
44 changes: 44 additions & 0 deletions keras/src/ops/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3514,6 +3514,50 @@ def hstack(xs):
return backend.numpy.hstack(xs)


class Hypot(Operation):
def call(self, x1, x2):
return backend.numpy.hypot(x1, x2)

def compute_output_spec(self, x1, x2):
dtype = dtypes.result_type(x1.dtype, x2.dtype)
if dtype in ["int8", "int16", "int32", "uint8", "uint16", "uint32"]:
dtype = backend.floatx()
elif dtype == "int64":
dtype = "float64"
return KerasTensor(broadcast_shapes(x1.shape, x2.shape), dtype=dtype)


@keras_export(["keras.ops.hypot", "keras.ops.numpy.hypot"])
def hypot(x1, x2):
"""Element-wise hypotenuse of right triangles with legs `x1` and `x2`.

This is equivalent to computing `sqrt(x1**2 + x2**2)` element-wise,
with shape determined by broadcasting.

Args:
x1: A tensor, representing the first leg of the right triangle.
x2: A tensor, representing the second leg of the right triangle.

Returns:
A tensor with a shape determined by broadcasting `x1` and `x2`.

Example:
>>> x1 = keras.ops.convert_to_tensor([3.0, 4.0, 5.0])
>>> x2 = keras.ops.convert_to_tensor([4.0, 3.0, 12.0])
>>> keras.ops.hypot(x1, x2)
array([5., 5., 13.], dtype=float32)

>>> x1 = keras.ops.convert_to_tensor([[1, 2], [3, 4]])
>>> x2 = keras.ops.convert_to_tensor([1, 1])
>>> keras.ops.hypot(x1, x2)
array([[1.41421356 2.23606798],
[3.16227766 4.12310563]], dtype=float32)
"""
if any_symbolic_tensors((x1, x2)):
return Hypot().symbolic_call(x1, x2)
return backend.numpy.hypot(x1, x2)


@keras_export(["keras.ops.identity", "keras.ops.numpy.identity"])
def identity(n, dtype=None):
"""Return the identity tensor.
Expand Down
42 changes: 42 additions & 0 deletions keras/src/ops/numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,11 @@ def test_heaviside(self):
y = KerasTensor((None, 3))
self.assertEqual(knp.heaviside(x, y).shape, (None, 3))

def test_hypot(self):
x = KerasTensor((None, 3))
y = KerasTensor((None, 3))
self.assertEqual(knp.hypot(x, y).shape, (None, 3))

def test_subtract(self):
x = KerasTensor((None, 3))
y = KerasTensor((2, None))
Expand Down Expand Up @@ -520,6 +525,11 @@ def test_heaviside(self):
y = KerasTensor((1, 3))
self.assertEqual(knp.heaviside(x, y).shape, (2, 3))

def test_hypot(self):
x = KerasTensor((2, 3))
y = KerasTensor((2, 3))
self.assertEqual(knp.hypot(x, y).shape, (2, 3))

def test_subtract(self):
x = KerasTensor((2, 3))
y = KerasTensor((2, 3))
Expand Down Expand Up @@ -2400,6 +2410,17 @@ def test_heaviside(self):
self.assertAllClose(knp.heaviside(x, y), np.heaviside(x, y))
self.assertAllClose(knp.Heaviside()(x, y), np.heaviside(x, y))

def test_hypot(self):
x = np.array([[1, 2, 3], [3, 2, 1]])
y = np.array([[4, 5, 6], [6, 5, 4]])
self.assertAllClose(knp.hypot(x, y), np.hypot(x, y))
self.assertAllClose(knp.Hypot()(x, y), np.hypot(x, y))

x = np.array([[1, 2, 3], [3, 2, 1]])
y = np.array(4)
self.assertAllClose(knp.hypot(x, y), np.hypot(x, y))
self.assertAllClose(knp.Hypot()(x, y), np.hypot(x, y))

def test_subtract(self):
x = np.array([[1, 2, 3], [3, 2, 1]])
y = np.array([[4, 5, 6], [3, 2, 1]])
Expand Down Expand Up @@ -7501,6 +7522,27 @@ def test_hstack(self, dtypes):
expected_dtype,
)

@parameterized.named_parameters(
named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))
)
def test_hypot(self, dtypes):
import jax.numpy as jnp

dtype1, dtype2 = dtypes
x1 = knp.ones((1, 1), dtype=dtype1)
x2 = knp.ones((1, 1), dtype=dtype2)
x1_jax = jnp.ones((1, 1), dtype=dtype1)
x2_jax = jnp.ones((1, 1), dtype=dtype2)
expected_dtype = standardize_dtype(jnp.hypot(x1_jax, x2_jax).dtype)

self.assertEqual(
standardize_dtype(knp.hypot(x1, x2).dtype), expected_dtype
)
self.assertEqual(
standardize_dtype(knp.Hypot().symbolic_call(x1, x2).dtype),
expected_dtype,
)

@parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
def test_identity(self, dtype):
import jax.numpy as jnp
Expand Down