Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add floor module and the corresponding testcases #4964

Merged
merged 49 commits into from
Jun 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
65ca0f0
first commit of floor
liamjxu May 24, 2021
98f93f4
first commit of floor
liamjxu May 24, 2021
3cf0429
Added function floor_op_tensor
liamjxu May 24, 2021
377a090
fixed a copy paste bug
liamjxu May 24, 2021
2032080
added test for backward_cpu
liamjxu May 24, 2021
7a81d98
Added more test cases, added doctest
liamjxu Jun 4, 2021
4fd64ee
fix conflicts
liamjxu Jun 7, 2021
aface00
fix conflicts
liamjxu Jun 7, 2021
ce3e948
changed doctest testmod, eliminated the print commands in doctest
liamjxu Jun 7, 2021
a2c41bd
Merge branch 'master' into dev_floor
liamjxu Jun 7, 2021
496121e
Merge branch 'master' into dev_floor
liamjxu Jun 7, 2021
e245905
made sure the math_op.py is the same as in the latest master branch
liamjxu Jun 8, 2021
3402bd3
Merge branch 'dev_floor' of https://github.com/Oneflow-Inc/oneflow in…
liamjxu Jun 8, 2021
77fe596
added an ending newline
liamjxu Jun 8, 2021
8db45ee
Merge branch 'master' into dev_floor
liamjxu Jun 8, 2021
dca3fa5
Merge branch 'master' into dev_floor
oneflow-ci-bot Jun 8, 2021
cf51964
auto format by CI
oneflow-ci-bot Jun 8, 2021
94c3ba0
Merge branch 'master' into dev_floor
oneflow-ci-bot Jun 8, 2021
f69b317
fix test_tensor python format
liamjxu Jun 8, 2021
3e2b70a
auto format by CI
oneflow-ci-bot Jun 8, 2021
705195f
fix license
liamjxu Jun 8, 2021
e27eb61
Merge branch 'master' into dev_floor
oneflow-ci-bot Jun 8, 2021
4157e5b
Merge branch 'master' into dev_floor
liamjxu Jun 9, 2021
9027fe0
Merge branch 'master' into dev_floor
oneflow-ci-bot Jun 9, 2021
08276ab
Merge branch 'master' into dev_floor
oneflow-ci-bot Jun 9, 2021
26d1bc2
Merge branch 'master' into dev_floor
oneflow-ci-bot Jun 9, 2021
c466751
fix docstring
liamjxu Jun 9, 2021
772ea2b
Merge branch 'dev_floor' of https://github.com/Oneflow-Inc/oneflow in…
liamjxu Jun 9, 2021
2fc2a62
Merge branch 'master' into dev_floor
oneflow-ci-bot Jun 9, 2021
7efcb18
Merge branch 'master' into dev_floor
oneflow-ci-bot Jun 9, 2021
009c5f0
Merge branch 'master' into dev_floor
oneflow-ci-bot Jun 9, 2021
64936d1
Merge branch 'master' into dev_floor
oneflow-ci-bot Jun 9, 2021
d701aa6
Merge branch 'master' into dev_floor
oneflow-ci-bot Jun 9, 2021
559cb2a
Merge branch 'master' into dev_floor
oneflow-ci-bot Jun 9, 2021
da13e92
Merge branch 'master' into dev_floor
oneflow-ci-bot Jun 9, 2021
67bbda7
Merge branch 'master' into dev_floor
oneflow-ci-bot Jun 9, 2021
090cc5a
Merge branch 'master' into dev_floor
oneflow-ci-bot Jun 10, 2021
d9bfc21
fix doc test format
liamjxu Jun 10, 2021
c47dab7
Merge branch 'dev_floor' of https://github.com/Oneflow-Inc/oneflow in…
liamjxu Jun 10, 2021
4810999
Merge branch 'master' into dev_floor
oneflow-ci-bot Jun 10, 2021
6faa606
Merge branch 'master' into dev_floor
oneflow-ci-bot Jun 10, 2021
971c675
Merge branch 'master' into dev_floor
oneflow-ci-bot Jun 10, 2021
a5e41a3
Merge branch 'master' into dev_floor
oneflow-ci-bot Jun 10, 2021
daf719e
Merge branch 'master' into dev_floor
oneflow-ci-bot Jun 10, 2021
d245377
Merge branch 'master' into dev_floor
oneflow-ci-bot Jun 10, 2021
ba2862b
Merge branch 'master' into dev_floor
oneflow-ci-bot Jun 10, 2021
6093c26
Merge branch 'master' into dev_floor
liamjxu Jun 14, 2021
4f2495b
auto format by CI
oneflow-ci-bot Jun 14, 2021
5e86416
Merge branch 'master' into dev_floor
liamjxu Jun 15, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/experimental.rst
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,8 @@ Experimental features
.. autofunction:: oneflow.experimental.nn.Upsample
.. autofunction:: oneflow.experimental.nn.UpsamplingNearest2d
.. autofunction:: oneflow.experimental.nn.UpsamplingBilinear2d
.. autofunction:: oneflow.experimental.floor
.. autofunction:: oneflow.experimental.Tensor.floor
.. autofunction:: oneflow.experimental.addmm
.. autofunction:: oneflow.experimental.Tensor.addmm
.. autofunction:: oneflow.experimental.clamp
Expand Down
88 changes: 88 additions & 0 deletions oneflow/python/nn/modules/floor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import collections
from typing import Optional, Sequence, Union

import oneflow as flow
from oneflow.python.oneflow_export import oneflow_export, experimental_api
from oneflow.python.nn.module import Module
from oneflow.python.framework.tensor import register_tensor_op
from oneflow.python.nn.modules.utils import _check_axis


class Floor(Module):
def __init__(self) -> None:
super().__init__()
self._op = flow.builtin_op("floor").Input("x").Output("y").Build()

def forward(self, x):
return self._op(x)[0]


@oneflow_export("floor")
@experimental_api
def floor_op(x):

r"""
Returns a new tensor with the arcsine of the elements of :attr:`input`.

.. math::
\text{out}_{i} = \lfloor \text{input}_{i} \rfloor

Args:
input (Tensor): the input tensor.

For example:

.. code-block:: python

>>> import oneflow.experimental as flow
>>> import numpy as np
>>> flow.enable_eager_execution()
>>> input = flow.Tensor(np.array([-0.5, 1.5, 0, 0.8]), dtype=flow.float32)
>>> output = flow.floor(input)
>>> output.shape
flow.Size([4])
>>> output.numpy()
array([-1., 1., 0., 0.], dtype=float32)

>>> input1 = flow.Tensor(np.array([[0.8, 1.0], [-0.6, 2.5]]), dtype=flow.float32)
>>> output1 = input1.floor()
>>> output1.shape
flow.Size([2, 2])
>>> output1.numpy()
array([[ 0., 1.],
[-1., 2.]], dtype=float32)

"""

return Floor()(x)


@register_tensor_op("floor")
@experimental_api
def floor_op_tensor(input):
r"""
See :func:`oneflow.experimental.floor`
"""
return Floor()(input)


if __name__ == "__main__":
import doctest

doctest.testmod(raise_on_error=True)
60 changes: 60 additions & 0 deletions oneflow/python/test/modules/test_floor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from collections import OrderedDict

import numpy as np

import oneflow.experimental as flow
from test_util import GenArgList


def _test_floor(test_case, shape, device):
np_input = np.random.randn(*shape)
of_input = flow.Tensor(
np_input, dtype=flow.float32, device=flow.device(device), requires_grad=True
)

of_out = flow.floor(of_input)
np_out = np.floor(np_input)
test_case.assertTrue(
np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5, equal_nan=True)
)

of_out = of_out.sum()
of_out.backward()
np_out_grad = np.zeros_like(of_out, dtype=np.float32)

test_case.assertTrue(
np.allclose(of_input.grad.numpy(), np_out_grad, 1e-4, 1e-4, equal_nan=True)
)


@unittest.skipIf(
not flow.unittest.env.eager_execution_enabled(),
".numpy() doesn't work in lazy mode",
)
class TestFloor(flow.unittest.TestCase):
def test_floor(test_case):
arg_dict = OrderedDict()
arg_dict["shape"] = [(2,), (2, 3), (2, 4, 5, 6)]
arg_dict["device"] = ["cpu", "cuda"]
for arg in GenArgList(arg_dict):
_test_floor(test_case, *arg)


if __name__ == "__main__":
unittest.main()
18 changes: 18 additions & 0 deletions oneflow/python/test/modules/test_math_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,24 @@ def _test_topk_original(test_case, device):
)


@unittest.skipIf(
not flow.unittest.env.eager_execution_enabled(),
".numpy() doesn't work in lazy mode",
)
class TestPow(flow.unittest.TestCase):
def test_pow(test_case):
input = flow.Tensor(np.array([1, 2, 3, 4, 5, 6]), dtype=flow.float32)
of_out = flow.pow(input, 2.1)
np_out = np.power(input.numpy(), 2.1)
test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5))

def test_pow_tensor_function(test_case):
input = flow.Tensor(np.array([1, 2, 3, 4, 5, 6]), dtype=flow.float32)
of_out = input.pow(2.1)
np_out = np.power(input.numpy(), 2.1)
test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5))


@unittest.skipIf(
not flow.unittest.env.eager_execution_enabled(),
".numpy() doesn't work in lazy mode",
Expand Down
12 changes: 12 additions & 0 deletions oneflow/python/test/tensor/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,6 +636,18 @@ def test_construct_small_tensor(test_case):
test_case.assertEqual(tensor.dtype, flow.float32)
test_case.assertTrue(np.allclose(tensor.numpy(), np.array(scalar), 1e-4, 1e-4))

@unittest.skipIf(
not flow.unittest.env.eager_execution_enabled(),
"numpy doesn't work in lazy mode",
)
def test_floor(test_case):
input = flow.Tensor(np.random.randn(4, 5, 6), dtype=flow.float32)
of_out = input.floor()
np_out = np.floor(input.numpy())
test_case.assertTrue(
np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5, equal_nan=True)
)

@unittest.skipIf(
not flow.unittest.env.eager_execution_enabled(),
"numpy doesn't work in lazy mode",
Expand Down