Updated model_utils.py with Telu #114

Merged · 1 commit merged on Jun 5, 2025
32 changes: 32 additions & 0 deletions ngclearn/utils/model_utils.py
@@ -112,6 +112,9 @@ def create_function(fun_name, args=None):
elif fun_name == "gelu":
fx = gelu
dfx = d_gelu
elif fun_name == "telu":
fx = telu
dfx = d_telu
elif fun_name == "softplus":
fx = softplus
dfx = d_softplus
@@ -294,6 +297,35 @@ def d_relu(x):
    return (x >= 0.).astype(jnp.float32)

@jit
def telu(x):
    """
    TeLU activation: f(x) = x * tanh(e^x)

    Proposed by Fernandez and Mali (2024); see https://arxiv.org/abs/2412.20269
    and https://arxiv.org/abs/2402.02790

    Args:
        x: input (tensor) value

    Returns:
        output (tensor) value
    """
    return x * jnp.tanh(jnp.exp(x))

@jit
def d_telu(x):
    """
    Derivative of TeLU: f'(x) = tanh(e^x) + x * e^x * (1 - tanh^2(e^x))

    Args:
        x: input (tensor) value

    Returns:
        output (tensor) derivative value (with respect to input)
    """
    ex = jnp.exp(x)
    tanh_ex = jnp.tanh(ex)
    return tanh_ex + x * ex * (1.0 - tanh_ex ** 2)

@jit
def sine(x, omega_0=30):
"""
f(x) = sin(x * omega_0).
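
Below is a minimal usage sketch (not part of the diff) that exercises the new TeLU code: it compares the analytic derivative d_telu against a central finite-difference estimate of telu's slope. It assumes ngclearn is installed and that telu and d_telu are importable from ngclearn.utils.model_utils as added in this PR; the sample points and step size are illustrative only.

import jax.numpy as jnp
from ngclearn.utils.model_utils import telu, d_telu

# sample points spanning negative and positive inputs (illustrative choice)
x = jnp.linspace(-4.0, 4.0, 9)
eps = 1e-4  # step size for the central finite difference

# central finite-difference estimate of telu's slope
numeric = (telu(x + eps) - telu(x - eps)) / (2.0 * eps)
# analytic derivative from the new d_telu routine
analytic = d_telu(x)

# the two should agree closely, up to finite-difference and float32 round-off error
print(jnp.max(jnp.abs(numeric - analytic)))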