Skip to content

Commit

Permalink
add exp_tanh_gelu module (#4751)
Browse files Browse the repository at this point in the history
* add exp_tanh_gelu module

* fix comment

* fix comment

* fix comment

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
  • Loading branch information
BBuf and oneflow-ci-bot authored Apr 28, 2021
1 parent dfcd1d7 commit 4b876e4
Show file tree
Hide file tree
Showing 4 changed files with 309 additions and 0 deletions.
98 changes: 98 additions & 0 deletions oneflow/python/nn/modules/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import oneflow._oneflow_internal
from oneflow.python.nn.module import Module
from oneflow.python.oneflow_export import oneflow_export
from oneflow.python.framework.tensor import register_tensor_op


@oneflow_export("nn.Sigmoid")
Expand All @@ -39,3 +40,100 @@ def __init__(self):
def forward(self, x):
res = self._op(x)[0]
return res


@oneflow_export("nn.Tanh")
class Tanh(Module):
r"""This operator computes the hyperbolic tangent value of Tensor.
The equation is:
.. math::
out = \frac{e^x-e^{-x}}{e^x+e^{-x}}
Args:
x (oneflow.Tensor): A Tensor
name (Optional[str], optional): The name for the operation. Defaults to None.
Returns:
oneflow.Tensor: The result Tensor
For example:
.. code-block:: python
import oneflow as flow
import numpy as np
x = np.array([-1, 0, 1]).astype(np.float32)
input = flow.Tensor(x)
tanh = flow.nn.Tanh()
out = tanh(input).numpy()
# out [-0.7615942 0. 0.7615942]
"""

def __init__(self):
super().__init__()
self._op = flow.builtin_op("tanh").Input("x").Output("y").Build()

def forward(self, x):
res = self._op(x)[0]
return res


@oneflow_export("tanh")
@register_tensor_op("tanh")
def tanh_op(tensor):
return Tanh()(tensor)


@oneflow_export("nn.GELU")
class GELU(Module):
r"""Gelu activation operator.
The equation is:
.. math::
out = 0.5 * x * (1 + tanh(\sqrt{\frac{2}{\pi}} * (x + 0.044715x^{3})))
Args:
x (oneflow.Tensor): Input Tensor
name (Optional[str], optional): The name for the operation. Defaults to None.
Returns:
oneflow.Tensor: A Tensor.
For example:
.. code-block:: python
import oneflow as flow
import numpy as np
import oneflow.typing as tp
x = np.array([-0.5, 0, 0.5]).astype(np.float32)
input = flow.Tensor(x)
gelu = flow.nn.GELU()
out = gelu(input)
# out [-0.15426877, 0., 0.34573123]
"""

def __init__(self):
super().__init__()
self._op = flow.builtin_op("gelu").Input("in").Output("out").Build()

def forward(self, x):
res = self._op(x)[0]
return res


@oneflow_export("gelu")
@register_tensor_op("gelu")
def gelu_op(tensor):
return GELU()(tensor)
62 changes: 62 additions & 0 deletions oneflow/python/nn/modules/exp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import oneflow as flow
from oneflow.python.nn.module import Module
from oneflow.python.oneflow_export import oneflow_export
from oneflow.python.framework.tensor import register_tensor_op


class Exp(Module):
"""This operator computes the exponential of Tensor.
The equation is:
.. math::
out = e^x
Args:
x (oneflow.Tensor): A Tensor
Returns:
oneflow.Tensor: The result Tensor
For example:
.. code-block:: python
import numpy as np
import oneflow as flow
x = flow.Tensor(np.array([1, 2, 3]).astype(np.float32))
y = x.exp().numpy()
# y [ 2.7182817 7.389056 20.085537 ]
"""

def __init__(self) -> None:
super().__init__()
self._op = flow.builtin_op("exp").Input("x").Output("y").Build()

def forward(self, x):
return self._op(x)[0]


@oneflow_export("exp")
@register_tensor_op("exp")
def exp_op(tensor):
return Exp()(tensor)
108 changes: 108 additions & 0 deletions oneflow/python/test/modules/test_activation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import numpy as np
import oneflow as flow


@unittest.skipIf(
not flow.unittest.env.eager_execution_enabled(),
".numpy() doesn't work in lazy mode",
)
class TestTanhModule(flow.unittest.TestCase):
def _test_body_tanh(test_case, input_arr):
x = flow.Tensor(input_arr)

tanh = flow.nn.Tanh()
y = tanh(x)
z = np.tanh(input_arr)

test_case.assertTrue(np.allclose(y.numpy(), z, rtol=1e-4, atol=1e-4))

def _test_ones_body_tanh(self, shape):
x = np.ones(shape, dtype=np.float32)
self._test_body_tanh(x)

def _test_random_body_tanh(self, shape):
x = np.random.random(shape).astype(np.float32)
self._test_body_tanh(x)

def test_ones_input_tanh(self):
self._test_ones_body_tanh((1))
self._test_ones_body_tanh((1, 10))
self._test_ones_body_tanh((2, 10, 2))
self._test_ones_body_tanh((2, 5, 2, 2))

def test_random_input_tanh(self):
self._test_random_body_tanh((1))
self._test_random_body_tanh((1, 10))
self._test_random_body_tanh((2, 10, 2))
self._test_random_body_tanh((2, 5, 2, 2))

def _test_body_tanh_v2(test_case, input_arr):
x = flow.Tensor(input_arr)

y = flow.tanh(x)
z = np.tanh(input_arr)

test_case.assertTrue(np.allclose(y.numpy(), z, rtol=1e-4, atol=1e-4))

def _test_body_tanh_v3(test_case, input_arr):
x = flow.Tensor(input_arr)

y = x.tanh()
z = np.tanh(input_arr)

test_case.assertTrue(np.allclose(y.numpy(), z, rtol=1e-4, atol=1e-4))


@unittest.skipIf(
not flow.unittest.env.eager_execution_enabled(),
".numpy() doesn't work in lazy mode",
)
class TestGeLU(flow.unittest.TestCase):
def test_gelu_v1(test_case):
input_arr = np.array([-0.5, 0, 0.5]).astype(np.float32)
x = flow.Tensor(input_arr)

gelu = flow.nn.GELU()
y = gelu(x)
z = np.array([-0.15426877, 0.0, 0.34573123])

test_case.assertTrue(np.allclose(y.numpy(), z, rtol=1e-4, atol=1e-4))

def test_gelu_v2(test_case):
input_arr = np.array([-0.5, 0, 0.5]).astype(np.float32)
x = flow.Tensor(input_arr)

y = flow.gelu(x)
z = np.array([-0.15426877, 0.0, 0.34573123])

test_case.assertTrue(np.allclose(y.numpy(), z, rtol=1e-4, atol=1e-4))

def test_gelu_v3(test_case):
input_arr = np.array([-0.5, 0, 0.5]).astype(np.float32)
x = flow.Tensor(input_arr)

y = x.gelu()

z = np.array([-0.15426877, 0.0, 0.34573123])

test_case.assertTrue(np.allclose(y.numpy(), z, rtol=1e-4, atol=1e-4))


if __name__ == "__main__":
unittest.main()
41 changes: 41 additions & 0 deletions oneflow/python/test/modules/test_exp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest

import numpy as np
import oneflow as flow


@unittest.skipIf(
not flow.unittest.env.eager_execution_enabled(),
".numpy() doesn't work in lazy mode",
)
class TestExp(flow.unittest.TestCase):
def test_exp_v1(test_case):
input = flow.Tensor(np.random.randn(2, 6, 5, 3), dtype=flow.float32)
of_out = flow.exp(input)
np_out = np.exp(input.numpy())
test_case.assertTrue(np.allclose(of_out.numpy(), np_out))

def test_exp_v2(test_case):
input = flow.Tensor(np.random.randn(2, 6, 5, 3), dtype=flow.float32)
of_out = input.exp()
np_out = np.exp(input.numpy())
test_case.assertTrue(np.allclose(of_out.numpy(), np_out))


if __name__ == "__main__":
unittest.main()

0 comments on commit 4b876e4

Please sign in to comment.