1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/activations/__init__.py
@@ -15,6 +15,7 @@
from keras.src.activations.activations import hard_sigmoid
from keras.src.activations.activations import hard_silu
from keras.src.activations.activations import hard_silu as hard_swish
from keras.src.activations.activations import hard_tanh
from keras.src.activations.activations import leaky_relu
from keras.src.activations.activations import linear
from keras.src.activations.activations import log_softmax
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/__init__.py
@@ -75,6 +75,7 @@
from keras.src.ops.nn import hard_sigmoid
from keras.src.ops.nn import hard_silu
from keras.src.ops.nn import hard_silu as hard_swish
from keras.src.ops.nn import hard_tanh
from keras.src.ops.nn import leaky_relu
from keras.src.ops.nn import log_sigmoid
from keras.src.ops.nn import log_softmax
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/nn/__init__.py
@@ -21,6 +21,7 @@
from keras.src.ops.nn import hard_sigmoid
from keras.src.ops.nn import hard_silu
from keras.src.ops.nn import hard_silu as hard_swish
from keras.src.ops.nn import hard_tanh
from keras.src.ops.nn import leaky_relu
from keras.src.ops.nn import log_sigmoid
from keras.src.ops.nn import log_softmax
1 change: 1 addition & 0 deletions keras/api/activations/__init__.py
@@ -15,6 +15,7 @@
from keras.src.activations.activations import hard_sigmoid
from keras.src.activations.activations import hard_silu
from keras.src.activations.activations import hard_silu as hard_swish
from keras.src.activations.activations import hard_tanh
from keras.src.activations.activations import leaky_relu
from keras.src.activations.activations import linear
from keras.src.activations.activations import log_softmax
1 change: 1 addition & 0 deletions keras/api/ops/__init__.py
@@ -75,6 +75,7 @@
from keras.src.ops.nn import hard_sigmoid
from keras.src.ops.nn import hard_silu
from keras.src.ops.nn import hard_silu as hard_swish
from keras.src.ops.nn import hard_tanh
from keras.src.ops.nn import leaky_relu
from keras.src.ops.nn import log_sigmoid
from keras.src.ops.nn import log_softmax
1 change: 1 addition & 0 deletions keras/api/ops/nn/__init__.py
@@ -21,6 +21,7 @@
from keras.src.ops.nn import hard_sigmoid
from keras.src.ops.nn import hard_silu
from keras.src.ops.nn import hard_silu as hard_swish
from keras.src.ops.nn import hard_tanh
from keras.src.ops.nn import leaky_relu
from keras.src.ops.nn import log_sigmoid
from keras.src.ops.nn import log_softmax
2 changes: 2 additions & 0 deletions keras/src/activations/__init__.py
@@ -7,6 +7,7 @@
from keras.src.activations.activations import glu
from keras.src.activations.activations import hard_sigmoid
from keras.src.activations.activations import hard_silu
from keras.src.activations.activations import hard_tanh
from keras.src.activations.activations import leaky_relu
from keras.src.activations.activations import linear
from keras.src.activations.activations import log_softmax
@@ -42,6 +43,7 @@
exponential,
hard_sigmoid,
hard_silu,
hard_tanh,
linear,
mish,
log_softmax,
15 changes: 15 additions & 0 deletions keras/src/activations/activations.py
@@ -358,6 +358,21 @@ def tanh(x):
return ops.tanh(x)


@keras_export("keras.activations.hard_tanh")
def hard_tanh(x):
"""HardTanh activation function.

It is defined as:
`hard_tanh(x) = -1 for x < -1`,
`hard_tanh(x) = x for -1 <= x <= 1`,
`hard_tanh(x) = 1 for x > 1`.

Args:
x: Input tensor.
"""
return ops.hard_tanh(x)


@keras_export("keras.activations.sigmoid")
def sigmoid(x):
"""Sigmoid activation function.
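A minimal usage sketch of the new activation (not part of the diff), assuming the string lookup resolves via the registry update in keras/src/activations/__init__.py shown above:

import numpy as np
import keras

# Direct call: values are clamped element-wise to [-1, 1].
x = np.array([-2.0, -0.5, 0.0, 0.5, 2.0])
print(keras.activations.hard_tanh(x))  # approx. [-1., -0.5, 0., 0.5, 1.]

# String lookup should also work once "hard_tanh" is registered above.
layer = keras.layers.Dense(4, activation="hard_tanh")
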
9 changes: 9 additions & 0 deletions keras/src/activations/activations_test.py
@@ -613,6 +613,15 @@ def glu(x, axis=-1):
expected = glu(x, axis=-2)
self.assertAllClose(result, expected, rtol=1e-05)

def test_hard_tanh(self):
def hard_tanh(x):
return np.clip(x, -1.0, 1.0)

x = np.random.random((2, 5))
result = activations.hard_tanh(x[np.newaxis, :])[0]
expected = hard_tanh(x)
self.assertAllClose(result, expected, rtol=1e-05)

def test_elu(self):
x = np.random.random((2, 5))
result = activations.elu(x[np.newaxis, :])[0]
5 changes: 5 additions & 0 deletions keras/src/backend/jax/nn.py
@@ -95,6 +95,11 @@ def glu(x, axis=-1):
return jnn.glu(x, axis=axis)


def hard_tanh(x):
x = convert_to_tensor(x)
return jnn.hard_tanh(x)


def softmax(x, axis=-1):
x = convert_to_tensor(x)
return jnn.softmax(x, axis=axis)
7 changes: 7 additions & 0 deletions keras/src/backend/numpy/nn.py
@@ -132,6 +132,13 @@ def glu(x, axis=-1):
return x1 * (1 / (1 + np.exp(-x2)))


def hard_tanh(x):
x = convert_to_tensor(x)
min_val = np.asarray(-1.0, x.dtype)
max_val = np.asarray(1.0, x.dtype)
return np.array(np.clip(x, min_val, max_val), dtype=x.dtype)


def softmax(x, axis=None):
exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
return exp_x / np.sum(exp_x, axis=axis, keepdims=True)
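A standalone NumPy check (not from the PR) that the clip-based backend implementation above matches the piecewise definition in the docstring:

import numpy as np

def hard_tanh_reference(x):
    # Piecewise form: -1 for x < -1, x for -1 <= x <= 1, 1 for x > 1.
    return np.where(x < -1.0, -1.0, np.where(x > 1.0, 1.0, x))

x = np.linspace(-3.0, 3.0, 13)
np.testing.assert_allclose(np.clip(x, -1.0, 1.0), hard_tanh_reference(x))
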
4 changes: 4 additions & 0 deletions keras/src/backend/tensorflow/nn.py
@@ -92,6 +92,10 @@ def glu(x, axis=-1):
return x1 * tf.sigmoid(x2)


def hard_tanh(x):
return tf.clip_by_value(x, clip_value_min=-1.0, clip_value_max=1.0)


def softmax(x, axis=-1):
logits = x
if axis is None:
5 changes: 5 additions & 0 deletions keras/src/backend/torch/nn.py
@@ -98,6 +98,11 @@ def glu(x, axis=-1):
return tnn.glu(x, dim=axis)


def hard_tanh(x):
x = convert_to_tensor(x)
return tnn.hardtanh(x, min_val=-1.0, max_val=1.0)


def softmax(x, axis=-1):
x = convert_to_tensor(x)
dtype = backend.standardize_dtype(x.dtype)
39 changes: 39 additions & 0 deletions keras/src/ops/nn.py
@@ -579,6 +579,45 @@ def glu(x, axis=-1):
return backend.nn.glu(x, axis=axis)


class HardTanh(Operation):
def __init__(self):
super().__init__()

def call(self, x):
return backend.nn.hard_tanh(x)

def compute_output_spec(self, x):
return KerasTensor(x.shape, dtype=x.dtype)


@keras_export(["keras.ops.hard_tanh", "keras.ops.nn.hard_tanh"])
def hard_tanh(x):
"""Applies the HardTanh function element-wise.

It is defined as:

`f(x) = -1 for x < -1`, `f(x) = x for -1 <= x <= 1`, `f(x) = 1 for x > 1`.

Args:
x: Input tensor.

Returns:
Output tensor of the same shape as `x`,
with values clamped between -1 and 1.

Example:

>>> x = np.array([-2., -1., 0., 1., 2.])
>>> keras.ops.hard_tanh(x)
array([-1., -1., 0., 1., 1.], shape=(5,), dtype=float64)

"""
if any_symbolic_tensors((x,)):
return HardTanh().symbolic_call(x)
return backend.nn.hard_tanh(x)


class Softmax(Operation):
def __init__(self, axis=-1):
super().__init__()
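A short sketch (not part of the diff) of the two dispatch paths in keras.ops.hard_tanh: a concrete tensor goes straight to the backend, while a KerasTensor takes the HardTanh().symbolic_call branch and only propagates shape and dtype:

import numpy as np
import keras
from keras import ops

# Eager path: backend.nn.hard_tanh is called directly.
print(ops.hard_tanh(np.array([-2.0, 0.3, 5.0])))  # approx. [-1., 0.3, 1.]

# Symbolic path: no values are computed, only shape/dtype are inferred.
x = keras.KerasTensor(shape=(None, 8), dtype="float32")
y = ops.hard_tanh(x)
print(y.shape, y.dtype)  # (None, 8) float32
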
33 changes: 33 additions & 0 deletions keras/src/ops/nn_test.py
@@ -149,6 +149,10 @@ def test_glu(self):
x = KerasTensor([None, 2, 3])
self.assertEqual(knn.glu(x).shape, (None, 2, 3))

def test_hard_tanh(self):
x = KerasTensor([None, 2, 3])
self.assertEqual(knn.hard_tanh(x).shape, (None, 2, 3))

def test_softmax(self):
x = KerasTensor([None, 2, 3])
self.assertEqual(knn.softmax(x).shape, (None, 2, 3))
@@ -802,6 +806,10 @@ def test_glu(self):
x = KerasTensor([1, 2, 3])
self.assertEqual(knn.glu(x).shape, (1, 2, 3))

def test_hard_tanh(self):
x = KerasTensor([1, 2, 3])
self.assertEqual(knn.hard_tanh(x).shape, (1, 2, 3))

def test_softmax(self):
x = KerasTensor([1, 2, 3])
self.assertEqual(knn.softmax(x).shape, (1, 2, 3))
@@ -1322,6 +1330,13 @@ def test_glu(self):
[-0.8807971, 0.0, 0.98201376],
)

def test_hard_tanh(self):
x = np.array([-1, 0, 1, 2, 3], dtype=np.float32)
self.assertAllClose(
knn.hard_tanh(x),
[-1.0, 0.0, 1.0, 1.0, 1.0],
)

def test_softmax(self):
x = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.float32)
self.assertAllClose(
@@ -2411,6 +2426,24 @@ def test_celu(self, dtype):
expected_dtype,
)

@parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))
def test_hard_tanh(self, dtype):
import jax.nn as jnn
import jax.numpy as jnp

x = knp.ones((), dtype=dtype)
x_jax = jnp.ones((), dtype=dtype)
expected_dtype = standardize_dtype(jnn.hard_tanh(x_jax).dtype)

self.assertEqual(
standardize_dtype(knn.hard_tanh(x).dtype),
expected_dtype,
)
self.assertEqual(
standardize_dtype(knn.HardTanh().symbolic_call(x).dtype),
expected_dtype,
)

@parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))
def test_glu(self, dtype):
import jax.nn as jnn