Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dev type as module #5349

Merged
merged 43 commits into from
Jun 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
eb694ab
refine and add test case
Flowingsun007 Jun 10, 2021
813f379
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 10, 2021
0aae06b
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 10, 2021
4a38ccd
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 10, 2021
b58b849
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 11, 2021
b5e151a
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 11, 2021
0ff1651
support ellipsis type slice
Flowingsun007 Jun 11, 2021
1276e65
refine
Flowingsun007 Jun 11, 2021
b9066f9
refine
Flowingsun007 Jun 11, 2021
8f81967
support slice assign ellipsis type
Flowingsun007 Jun 11, 2021
8f8cee2
refine
Flowingsun007 Jun 11, 2021
a79dcdf
Merge branch 'master' into dev_fix_slice_bug
Flowingsun007 Jun 12, 2021
81a400e
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 12, 2021
e9e2fa4
Merge branch 'master' into dev_fix_slice_bug
Flowingsun007 Jun 13, 2021
dbdcf18
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 13, 2021
b11243f
Merge branch 'master' into dev_fix_slice_bug
Flowingsun007 Jun 13, 2021
9c7185b
Merge branch 'master' into dev_fix_slice_bug
oneflow-ci-bot Jun 13, 2021
c8c78fb
register fn to localtensor
Flowingsun007 Jun 13, 2021
f565929
Merge branch 'dev_fix_slice_bug' of https://github.com/Oneflow-Inc/on…
Flowingsun007 Jun 13, 2021
ebc25f0
Merge branch 'dev_fix_slice_bug'
Flowingsun007 Jun 13, 2021
475bbff
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 13, 2021
b69f554
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 15, 2021
e8cd9e3
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 16, 2021
a500a6d
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 17, 2021
5387b8f
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 17, 2021
756b0ed
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 17, 2021
34e9fd5
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 17, 2021
a5d67ac
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 21, 2021
e547b4b
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 21, 2021
756e537
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 24, 2021
a39271b
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 25, 2021
d5ecb51
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 27, 2021
db1b536
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 28, 2021
75cc02b
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 28, 2021
634b968
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 29, 2021
6be7d0b
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 29, 2021
c3708d3
add type_as module
Flowingsun007 Jun 29, 2021
a4cef85
add long module
Flowingsun007 Jun 29, 2021
5278022
add more test
Flowingsun007 Jun 29, 2021
b2013d0
Merge branch 'master' into dev_type_as_module
oneflow-ci-bot Jun 29, 2021
9f78d24
Merge branch 'master' into dev_type_as_module
oneflow-ci-bot Jun 29, 2021
a2727f0
Merge branch 'master' into dev_type_as_module
oneflow-ci-bot Jun 29, 2021
ca6bced
Merge branch 'master' into dev_type_as_module
oneflow-ci-bot Jun 29, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/experimental.rst
Original file line number Diff line number Diff line change
Expand Up @@ -220,3 +220,5 @@ Experimental features
.. autofunction:: oneflow.experimental.nn.ZeroPad2d
.. autofunction:: oneflow.experimental.tensor_buffer_to_tensor
.. autofunction:: oneflow.experimental.tensor_to_tensor_buffer
.. autofunction:: oneflow.experimental.Tensor.type_as
.. autofunction:: oneflow.experimental.Tensor.long
94 changes: 94 additions & 0 deletions oneflow/python/nn/modules/tensor_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import oneflow as flow
from oneflow.python.oneflow_export import experimental_api
from oneflow.python.nn.module import Module
from oneflow.python.framework.tensor import register_tensor_op


class TypeAs(Module):
def __init__(self):
super().__init__()

def forward(self, input, target):
return input.to(dtype=target.dtype)


@register_tensor_op("type_as")
@experimental_api
def type_as_op(input, target):
r"""Returns this tensor cast to the type of the given tensor.
This is a no-op if the tensor is already of the correct type.

Args:
input (Tensor): the input tensor.
target (Tensor): the tensor which has the desired type.

For example:

.. code-block:: python

>>> import oneflow.experimental as flow
>>> import numpy as np
>>> flow.enable_eager_execution()

>>> input = flow.Tensor(np.random.randn(1, 2, 3), dtype=flow.float32)
>>> target = flow.Tensor(np.random.randn(4, 5, 6), dtype = flow.int32)
>>> input = input.type_as(target)
>>> input.dtype
oneflow.int32

"""
return TypeAs()(input, target)


class Long(Module):
def __init__(self):
super().__init__()

def forward(self, input):
return input.to(dtype=flow.int64)


@register_tensor_op("long")
@experimental_api
def long_op(input):
r"""`Tensor.long()` is equivalent to `Tensor.to(flow.int64)`. See to().

Args:
input (Tensor): the input tensor.

For example:

.. code-block:: python

>>> import oneflow.experimental as flow
>>> import numpy as np
>>> flow.enable_eager_execution()

>>> input = flow.Tensor(np.random.randn(1, 2, 3), dtype=flow.float32)
>>> input = input.long()
>>> input.dtype
oneflow.int64

"""
return Long()(input)


if __name__ == "__main__":
import doctest

doctest.testmod(raise_on_error=True)
64 changes: 64 additions & 0 deletions oneflow/python/test/modules/test_tensor_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from collections import OrderedDict

import numpy as np

import oneflow.experimental as flow
from test_util import GenArgList


def _test_type_as(test_case, shape, device, src_dtype, tgt_dtype):
np_input = np.random.rand(*shape)
input = flow.tensor(np_input, dtype=src_dtype, device=device)
target = flow.tensor(np_input, dtype=tgt_dtype, device=device)
input = input.type_as(target)
test_case.assertEqual(input.dtype, target.dtype)


def _test_long(test_case, shape, device, dtype):
np_input = np.random.rand(*shape)
input = flow.tensor(np_input, dtype=dtype, device=device)
input = input.long()
test_case.assertEqual(input.dtype, flow.int64)


@unittest.skipIf(
not flow.unittest.env.eager_execution_enabled(),
".numpy() doesn't work in lazy mode",
)
class TestTensorOps(flow.unittest.TestCase):
def test_type_as(test_case):
arg_dict = OrderedDict()
arg_dict["shape"] = [(1, 2), (3, 4, 5), (2, 3, 4, 5)]
arg_dict["device"] = ["cpu", "cuda"]
arg_dict["src_dtype"] = [flow.int64, flow.int32, flow.float32, flow.float64]
arg_dict["tgt_dtype"] = [flow.int64, flow.int32, flow.float32, flow.float64]
for arg in GenArgList(arg_dict):
_test_type_as(test_case, *arg)

def test_long(test_case):
arg_dict = OrderedDict()
arg_dict["shape"] = [(1, 2), (3, 4, 5), (2, 3, 4, 5)]
arg_dict["device"] = ["cpu", "cuda"]
arg_dict["dtype"] = [flow.int64, flow.int32, flow.float32, flow.float64]
for arg in GenArgList(arg_dict):
_test_long(test_case, *arg)


if __name__ == "__main__":
unittest.main()