feat(activations): Implement all activation functions from TFA
1 parent e7f2d7b · commit 23c01ea
Showing 9 changed files with 291 additions and 6 deletions.
kca/activations/__init__.py
@@ -1,10 +1,20 @@
 # activations/__init__.py

 __all__ = [
-    "glu", "reglu", "geglu", "swiglu", "seglu",
-    "sparsemax", "differentiable_binary"
+    "glu", "reglu", "geglu", "swiglu", "seglu", "rrelu",
+    "hardshrink", "softshrink", "tanhshrink",
+    "lisht", "mish", "snake", "sparsemax", "differentiable_binary"
 ]

-from kca.activations.glu import glu, reglu, geglu, swiglu, seglu
-from kca.activations.sparsemax import sparsemax
-from kca.activations.differentiable_binary import differentiable_binary
+from .glu import glu, reglu, geglu, swiglu, seglu
+from .rrelu import rrelu
+
+from .hardshrink import hardshrink
+from .softshrink import softshrink
+from .tanhshrink import tanhshrink
+
+from .lisht import lisht
+from .mish import mish
+from .snake import snake
+from .sparsemax import sparsemax
+from .differentiable_binary import differentiable_binary
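For orientation, a minimal usage sketch of the expanded public API (an illustration, not part of the commit; it assumes the `kca` package from this commit is installed and that the activations accept NumPy inputs, as keras_core ops generally do):

import numpy as np
from kca.activations import hardshrink, mish

x = np.array([-1.0, 0.0, 1.0], dtype="float32")
print(mish(x))        # approx. [-0.30340147  0.          0.86509836]
print(hardshrink(x))  # [-1.  0.  1.]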
kca/activations/hardshrink.py
@@ -0,0 +1,43 @@
from keras_core import ops

from kca.utils.types import TensorLike, Number


def hardshrink(x: TensorLike, lower: Number = -0.5, upper: Number = 0.5) -> TensorLike:
    r"""Hard shrink function.

    Computes hard shrink function:

    $$
    \mathrm{hardshrink}(x) =
    \begin{cases}
        x & \text{if } x < \mathrm{lower} \\
        x & \text{if } x > \mathrm{upper} \\
        0 & \text{otherwise}
    \end{cases}.
    $$

    Usage:

    >>> x = tf.constant([1.0, 0.0, 1.0])
    >>> kca.activations.hardshrink(x)
    <tf.Tensor: shape=(3,), dtype=float32, numpy=array([1., 0., 1.], dtype=float32)>

    Args:
        x: A `Tensor`. Must be one of the following types:
            `bfloat16`, `float16`, `float32`, `float64`.
        lower: `float`, lower bound for setting values to zeros.
        upper: `float`, upper bound for setting values to zeros.
    Returns:
        A `Tensor`. Has the same type as `x`.
    """
    if lower > upper:
        raise ValueError(
            "The value of lower ({}) should not be greater than the value of upper ({}).".format(lower, upper)
        )
    # Keep values outside (lower, upper); zero out the band in between.
    mask_lower = x < lower
    mask_upper = upper < x
    mask = ops.logical_or(mask_lower, mask_upper)
    mask = ops.cast(mask, x.dtype)
    return x * mask
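As a quick sanity check, the same piecewise rule is easy to mirror in plain NumPy (a reference sketch for illustration; `hardshrink_ref` is a hypothetical helper, not part of the commit):

import numpy as np

def hardshrink_ref(x, lower=-0.5, upper=0.5):
    # Pass values through outside (lower, upper); zero out the band in between.
    return np.where((x < lower) | (x > upper), x, 0.0)

print(hardshrink_ref(np.array([-1.0, -0.25, 0.0, 0.25, 1.0])))
# [-1.  0.  0.  0.  1.]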
kca/activations/lisht.py
@@ -0,0 +1,28 @@
from keras_core import activations
from kca.utils.types import TensorLike


def lisht(x: TensorLike) -> TensorLike:
    r"""LiSHT: Non-Parametric Linearly Scaled Hyperbolic Tangent Activation Function.

    Computes linearly scaled hyperbolic tangent (LiSHT):

    $$
    \mathrm{lisht}(x) = x * \tanh(x).
    $$

    See [LiSHT: Non-Parametric Linearly Scaled Hyperbolic Tangent Activation Function for Neural Networks](https://arxiv.org/abs/1901.05894).

    Usage:

    >>> x = tf.constant([1.0, 0.0, 1.0])
    >>> kca.activations.lisht(x)
    <tf.Tensor: shape=(3,), dtype=float32, numpy=array([0.7615942, 0.       , 0.7615942], dtype=float32)>

    Args:
        x: A `Tensor`. Must be one of the following types:
            `bfloat16`, `float16`, `float32`, `float64`.
    Returns:
        A `Tensor`. Has the same type as `x`.
    """
    return x * activations.tanh(x)
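Since x * tanh(x) is an even function, LiSHT is non-negative and symmetric about zero, unlike tanh itself. A one-line NumPy check of the docstring values (illustrative only):

import numpy as np

x = np.array([-1.0, 0.0, 1.0])
print(x * np.tanh(x))  # approx. [0.76159416 0.         0.76159416]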
kca/activations/mish.py
@@ -0,0 +1,28 @@
from keras_core import activations
from kca.utils.types import TensorLike


def mish(x: TensorLike) -> TensorLike:
    r"""Mish: A Self Regularized Non-Monotonic Neural Activation Function.

    Computes mish activation:

    $$
    \mathrm{mish}(x) = x \cdot \tanh(\mathrm{softplus}(x)).
    $$

    See [Mish: A Self Regularized Non-Monotonic Neural Activation Function](https://arxiv.org/abs/1908.08681).

    Usage:

    >>> x = tf.constant([1.0, 0.0, 1.0])
    >>> kca.activations.mish(x)
    <tf.Tensor: shape=(3,), dtype=float32, numpy=array([0.865098..., 0.        , 0.865098...], dtype=float32)>

    Args:
        x: A `Tensor`. Must be one of the following types:
            `bfloat16`, `float16`, `float32`, `float64`.
    Returns:
        A `Tensor`. Has the same type as `x`.
    """
    return x * activations.tanh(activations.softplus(x))
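The same composition can be checked in NumPy, using log1p(exp(x)) for softplus (an illustrative sketch; `mish_ref` is a hypothetical reference helper, not part of the commit):

import numpy as np

def mish_ref(x):
    softplus = np.log1p(np.exp(x))  # softplus(x) = ln(1 + e^x)
    return x * np.tanh(softplus)

print(mish_ref(np.array([-1.0, 0.0, 1.0])))
# approx. [-0.30340147  0.          0.86509839]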
kca/activations/rrelu.py
@@ -0,0 +1,75 @@
from keras_core import ops, random
from typing import Optional
from kca.utils.types import TensorLike, Number, Generator


def rrelu(x: TensorLike, lower: Number = 0.125, upper: Number = 0.3333333333333333, training: bool = False,
          seed: Optional[int] = None, rng: Optional[Generator] = None) -> TensorLike:
    r"""Randomized leaky rectified linear unit function.

    Computes rrelu function:

    $$
    \mathrm{rrelu}(x) =
    \begin{cases}
        x & \text{if } x > 0 \\
        a x & \text{otherwise}
    \end{cases},
    $$

    where

    $$
    a \sim \mathcal{U}(\mathrm{lower}, \mathrm{upper})
    $$

    when `training` is `True`; or

    $$
    a = \frac{\mathrm{lower} + \mathrm{upper}}{2}
    $$

    when `training` is `False`.

    See [Empirical Evaluation of Rectified Activations in Convolutional Network](https://arxiv.org/abs/1505.00853).

    Usage:

    >>> x = tf.constant([-1.0, 0.0, 1.0])
    >>> kca.activations.rrelu(x, training=False)
    <tf.Tensor: shape=(3,), dtype=float32, numpy=array([-0.22916667,  0.        ,  1.        ], dtype=float32)>
    >>> kca.activations.rrelu(x, training=True, seed=2020)
    <tf.Tensor: shape=(3,), dtype=float32, numpy=array([-0.22631127,  0.        ,  1.        ], dtype=float32)>
    >>> generator = tf.random.Generator.from_seed(2021)
    >>> kca.activations.rrelu(x, training=True, rng=generator)
    <tf.Tensor: shape=(3,), dtype=float32, numpy=array([-0.16031083,  0.        ,  1.        ], dtype=float32)>

    Args:
        x: A `Tensor`. Must be one of the following types:
            `bfloat16`, `float16`, `float32`, `float64`.
        lower: `float`, lower bound for the random alpha.
        upper: `float`, upper bound for the random alpha.
        training: `bool`, indicating whether the call is meant
            for training or inference.
        seed: `int`, this sets the operation-level seed.
        rng: A `tf.random.Generator`.
    Returns:
        result: A `Tensor`. Has the same type as `x`.
    """
    lower = ops.cast(lower, x.dtype)
    upper = ops.cast(upper, x.dtype)

    def random_a():
        if rng is not None and seed is not None:
            raise ValueError("Specify either `seed` or `rng`, not both.")

        if rng is not None:
            return rng.uniform(ops.shape(x), minval=lower, maxval=upper, dtype=x.dtype)

        return random.uniform(
            ops.shape(x), minval=lower, maxval=upper, dtype=x.dtype, seed=seed
        )

    # Random slope per element during training; fixed mean slope at inference.
    a = random_a() if training else ((lower + upper) / 2)

    return ops.where(x >= 0, x, a * x)
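A short sketch contrasting the two modes (assumes keras_core on the TensorFlow backend, since `rng` must be a `tf.random.Generator`; training-mode outputs vary with the seed):

import numpy as np
import tensorflow as tf
from kca.activations import rrelu

x = np.array([-1.0, 0.0, 1.0], dtype="float32")

# Inference: deterministic slope a = (lower + upper) / 2, about 0.22917.
print(rrelu(x, training=False))

# Training: the negative-side slope is drawn per element from U(lower, upper).
print(rrelu(x, training=True, rng=tf.random.Generator.from_seed(2021)))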
kca/activations/snake.py
@@ -0,0 +1,30 @@
from keras_core import ops, activations
from kca.utils.types import TensorLike, Number


def snake(x: TensorLike, frequency: Number = 1) -> TensorLike:
    r"""Snake activation to learn periodic functions.

    Computes snake activation:

    $$
    \mathrm{snake}(x) = x + \frac{1 - \cos(2 \cdot \mathrm{frequency} \cdot x)}{2 \cdot \mathrm{frequency}}.
    $$

    See [Neural Networks Fail to Learn Periodic Functions and How to Fix It](https://arxiv.org/abs/2006.08195).

    Usage:

    >>> x = tf.constant([-1.0, 0.0, 1.0])
    >>> kca.activations.snake(x)
    <tf.Tensor: shape=(3,), dtype=float32, numpy=array([-0.29192656,  0.        ,  1.7080734 ], dtype=float32)>

    Args:
        x: A `Tensor`.
        frequency: A scalar, frequency of the periodic part.
    Returns:
        A `Tensor`. Has the same type as `x`.
    """
    frequency = ops.cast(frequency, x.dtype)

    return x + (1 - ops.cos(2 * frequency * x)) / (2 * frequency)
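In NumPy form the role of `frequency` is easy to see: the ripple has period pi/frequency and peak height 1/frequency, so larger values give faster, smaller oscillations around the identity line (`snake_ref` is a hypothetical reference helper, not part of the commit):

import numpy as np

def snake_ref(x, frequency=1.0):
    return x + (1.0 - np.cos(2.0 * frequency * x)) / (2.0 * frequency)

x = np.array([-1.0, 0.0, 1.0])
print(snake_ref(x))                 # approx. [-0.29192658  0.          1.70807342]
print(snake_ref(x, frequency=5.0))  # faster, smaller ripple around y = x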
kca/activations/softshrink.py
@@ -0,0 +1,41 @@
from keras_core import ops

from kca.utils.types import TensorLike, Number


def softshrink(x: TensorLike, lower: Number = -0.5, upper: Number = 0.5) -> TensorLike:
    r"""Soft shrink function.

    Computes soft shrink function:

    $$
    \mathrm{softshrink}(x) =
    \begin{cases}
        x - \mathrm{lower} & \text{if } x < \mathrm{lower} \\
        x - \mathrm{upper} & \text{if } x > \mathrm{upper} \\
        0 & \text{otherwise}
    \end{cases}.
    $$

    Usage:

    >>> x = tf.constant([-1.0, 0.0, 1.0])
    >>> kca.activations.softshrink(x)
    <tf.Tensor: shape=(3,), dtype=float32, numpy=array([-0.5,  0. ,  0.5], dtype=float32)>

    Args:
        x: A `Tensor`. Must be one of the following types:
            `bfloat16`, `float16`, `float32`, `float64`.
        lower: `float`, lower bound for setting values to zeros.
        upper: `float`, upper bound for setting values to zeros.
    Returns:
        A `Tensor`. Has the same type as `x`.
    """
    if lower > upper:
        raise ValueError(
            "The value of lower ({}) should not be greater than the value of upper ({}).".format(lower, upper)
        )
    # Shift out-of-band values toward zero; in-band values become zero.
    values_below_lower = ops.where(x < lower, x - lower, 0)
    values_above_upper = ops.where(upper < x, x - upper, 0)
    return values_below_lower + values_above_upper
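A NumPy mirror makes the shift-toward-zero behavior explicit (illustrative reference only; `softshrink_ref` is hypothetical):

import numpy as np

def softshrink_ref(x, lower=-0.5, upper=0.5):
    # Out-of-band values are shifted toward zero; in-band values become zero.
    return np.where(x < lower, x - lower, 0.0) + np.where(x > upper, x - upper, 0.0)

print(softshrink_ref(np.array([-1.0, 0.0, 1.0])))
# [-0.5  0.   0.5]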
kca/activations/tanhshrink.py
@@ -0,0 +1,26 @@
from keras_core import activations
from kca.utils.types import TensorLike


def tanhshrink(x: TensorLike) -> TensorLike:
    r"""Tanh shrink function.

    Applies the element-wise function:

    $$
    \mathrm{tanhshrink}(x) = x - \tanh(x).
    $$

    Usage:

    >>> x = tf.constant([-1.0, 0.0, 1.0])
    >>> kca.activations.tanhshrink(x)
    <tf.Tensor: shape=(3,), dtype=float32, numpy=array([-0.23840582,  0.        ,  0.23840582], dtype=float32)>

    Args:
        x: A `Tensor`. Must be one of the following types:
            `bfloat16`, `float16`, `float32`, `float64`.
    Returns:
        A `Tensor`. Has the same type as `x`.
    """
    return x - activations.tanh(x)
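By construction tanhshrink(x) + tanh(x) == x, which gives a one-line NumPy check (illustrative only):

import numpy as np

x = np.array([-1.0, 0.0, 1.0])
y = x - np.tanh(x)  # tanhshrink
print(y)                               # approx. [-0.23840584  0.          0.23840584]
print(np.allclose(y + np.tanh(x), x))  # True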
kca/utils/types.py
@@ -61,4 +61,8 @@
 OptimizerType = Union[
     keras.optimizers.Optimizer,
     str
 ]
+
+Generator = Union[
+    tf.random.Generator
+]
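Caller-side sketch of the new alias (hypothetical usage; a `Union` with a single member is currently equivalent to `tf.random.Generator` itself, but it keeps the door open for backend-specific generators later):

import tensorflow as tf
from kca.utils.types import Generator

rng: Generator = tf.random.Generator.from_seed(42)
# Suitable as the `rng` argument of `rrelu` above.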