diff --git a/lib/axon.ex b/lib/axon.ex
index 740bae5a..0beac914 100644
--- a/lib/axon.ex
+++ b/lib/axon.ex
@@ -615,7 +615,7 @@ defmodule Axon do
   end
 
   @activation_layers [:celu, :elu, :exp, :gelu, :hard_sigmoid, :hard_silu, :hard_tanh] ++
-                       [:leaky_relu, :linear, :log_sigmoid, :relu, :relu6] ++
+                       [:leaky_relu, :linear, :log_sigmoid, :mish, :relu, :relu6] ++
                        [:sigmoid, :silu, :selu, :softmax, :softplus, :softsign, :tanh]
 
   @doc """
diff --git a/lib/axon/activations.ex b/lib/axon/activations.ex
index 8f6f0a20..82b15768 100644
--- a/lib/axon/activations.ex
+++ b/lib/axon/activations.ex
@@ -377,6 +377,32 @@ defmodule Axon.Activations do
   """
   defn log_sigmoid(x), do: -softplus(-x)
 
+  @doc ~S"""
+  Mish activation.
+
+  $$f(x_i) = x_i \cdot \tanh(\log(1 + e^{x_i}))$$
+
+  ## Examples
+
+      iex> Axon.Activations.mish(Nx.tensor([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0], type: {:f, 32}, names: [:data]))
+      #Nx.Tensor<
+        f32[data: 7]
+        [-0.14564745128154755, -0.2525014877319336, -0.30340147018432617, 0.0, 0.8650984168052673, 1.9439589977264404, 2.98653507232666]
+      >
+
+      iex> Axon.Activations.mish(Nx.tensor([[-1.0, -2.0, -3.0], [1.0, 2.0, 3.0]], type: {:bf, 16}, names: [:batch, :data]))
+      #Nx.Tensor<
+        bf16[batch: 2][data: 3]
+        [
+          [-0.30078125, -0.25, -0.1435546875],
+          [0.86328125, 1.9375, 2.96875]
+        ]
+      >
+  """
+  defn mish(x) do
+    x * tanh(softplus(x))
+  end
+
   @doc ~S"""
   Rectified linear unit activation.
 
diff --git a/lib/axon/compiler.ex b/lib/axon/compiler.ex
index 69727612..4450b83e 100644
--- a/lib/axon/compiler.ex
+++ b/lib/axon/compiler.ex
@@ -217,7 +217,7 @@ defmodule Axon.Compiler do
 
   ## Activation Layers
   @activation_layers [:celu, :elu, :exp, :gelu, :hard_sigmoid, :hard_silu, :hard_tanh] ++
-                       [:leaky_relu, :linear, :log_sigmoid, :relu, :relu6] ++
+                       [:leaky_relu, :linear, :log_sigmoid, :mish, :relu, :relu6] ++
                        [:sigmoid, :silu, :selu, :softmax, :softplus, :softsign, :tanh]
 
   defp recur_predict_fun(%Axon{op: op, parent: parent}, cache, param_map, input_map)
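Not part of the patch itself, but a quick usage sketch for reviewers: because `:mish` is added to `@activation_layers` in both `Axon` and `Axon.Compiler`, the activation becomes reachable as the element-wise `Axon.Activations.mish/1`, as the generated `Axon.mish/2` layer, and via the `:activation` option on layers such as `Axon.dense/3`. The input name, `:shape` option, and layer sizes below are illustrative only, and the `Axon.input/2` spelling shown may differ on older Axon releases.

```elixir
# Element-wise application of the new activation to a plain tensor.
Axon.Activations.mish(Nx.tensor([-1.0, 0.0, 1.0]))

# Using :mish inside a model. The generated Axon.mish/2 layer and the
# :activation option both route through the same activation function.
# Input name, shape, and layer sizes here are illustrative only.
model =
  Axon.input("features", shape: {nil, 16})
  |> Axon.dense(32)
  |> Axon.mish()
  |> Axon.dense(1, activation: :sigmoid)
```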