Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

some op test #6095

Merged
merged 8 commits into from
Sep 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions python/oneflow/test/modules/test_cat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import unittest
from collections import OrderedDict

import numpy as np
from test_util import GenArgList

import oneflow as flow
import oneflow.unittest
from automated_test_util import *


@flow.unittest.skip_unless_1n1d()
class TestCat(flow.unittest.TestCase):
    """Autotest comparison of torch.cat between PyTorch and OneFlow."""

    @autotest()
    def test_cat_with_random_data(test_case):
        # Concatenate a random 2-D tensor with itself along a random dim.
        dev = random_device()
        tensor = random_pytorch_tensor(ndim=2, dim0=random(), dim1=random())
        tensor = tensor.to(dev)
        dim = random(0, 2).to(int)
        return torch.cat((tensor, tensor, tensor), dim)


if __name__ == "__main__":
    # Allow running this test module directly.
    unittest.main()
138 changes: 13 additions & 125 deletions python/oneflow/test/modules/test_diag.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,134 +22,22 @@

import oneflow as flow
import oneflow.unittest


def _test_diag_forward(test_case, shape, diagonal, device):
    """Compare diag forward output against numpy for both call styles."""
    tensor = flow.Tensor(np.random.randn(*shape), device=flow.device(device))
    expected = np.diag(tensor.numpy(), diagonal)
    # Exercise the free function and the Tensor method forms.
    for actual in (flow.diag(tensor, diagonal), tensor.diag(diagonal=diagonal)):
        test_case.assertTrue(
            np.allclose(actual.numpy(), expected, 1e-05, 1e-05, equal_nan=True)
        )


def _test_diag_one_dim_backward(test_case, diagonal, device):
    """Check the backward pass of diag on a 1-D tensor of length 3.

    For a 1-D input, the gradient of ``diag(x).sum()`` w.r.t. ``x`` is all
    ones regardless of ``diagonal``. The original body was duplicated
    verbatim for the two call styles; the duplication is folded into a loop.
    """
    np_grad = np.ones(shape=3)  # same expectation for both call styles
    for apply_diag in (
        lambda t: flow.diag(t, diagonal),
        lambda t: t.diag(diagonal=diagonal),
    ):
        input = flow.Tensor(
            np.random.randn(3), device=flow.device(device), requires_grad=True
        )
        apply_diag(input).sum().backward()
        test_case.assertTrue(
            np.allclose(input.grad.numpy(), np_grad, 1e-05, 1e-05, equal_nan=True)
        )


def _test_diag_other_dim_backward(test_case, diagonal, device):
    """Check the backward pass of diag on a square (3, 3) tensor.

    The gradient of ``diag(x).sum()`` is 1 on the selected diagonal and 0
    elsewhere. The expected gradient depends only on ``diagonal``, so it is
    computed once instead of twice (the original duplicated the whole body
    for the two call styles).
    """
    if diagonal > 0:
        np_grad = np.array([[0, 1, 0], [0, 0, 1], [0, 0, 0]])
    elif diagonal < 0:
        np_grad = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0]])
    else:
        np_grad = np.identity(3)
    for apply_diag in (
        lambda t: flow.diag(t, diagonal),
        lambda t: t.diag(diagonal=diagonal),
    ):
        input = flow.Tensor(
            np.random.randn(3, 3), device=flow.device(device), requires_grad=True
        )
        apply_diag(input).sum().backward()
        test_case.assertTrue(
            np.allclose(input.grad.numpy(), np_grad, 1e-05, 1e-05, equal_nan=True)
        )


def _test_diag_other_dim_non_square_backward(test_case, diagonal, device):
    """Check the backward pass of diag on a non-square (3, 4) tensor.

    The gradient of ``diag(x).sum()`` is 1 on the selected diagonal and 0
    elsewhere; it depends only on ``diagonal`` and is now computed once
    (the original duplicated the expectation and the whole body for the
    two call styles).
    """
    if diagonal > 0:
        # Diagonal shifted one column right: zeros column, then identity.
        np_grad = np.hstack((np.zeros([3, 1]), np.identity(3)))
    elif diagonal < 0:
        np_grad = np.array([[0, 0, 0, 0], [1, 0, 0, 0], [0, 1, 0, 0]])
    else:
        # Main diagonal of a 3x4: identity padded with a zeros column.
        np_grad = np.hstack((np.identity(3), np.zeros([3, 1])))
    for apply_diag in (
        lambda t: flow.diag(t, diagonal),
        lambda t: t.diag(diagonal=diagonal),
    ):
        input = flow.Tensor(
            np.random.randn(3, 4), device=flow.device(device), requires_grad=True
        )
        apply_diag(input).sum().backward()
        test_case.assertTrue(
            np.allclose(input.grad.numpy(), np_grad, 1e-05, 1e-05, equal_nan=True)
        )
from automated_test_util import *


@flow.unittest.skip_unless_1n1d()
class TestDiag(flow.unittest.TestCase):
    """Parameterized forward/backward checks for flow.diag."""

    def test_diag_forward(test_case):
        # Grid of shapes x diagonals x devices.
        arg_dict = OrderedDict()
        arg_dict["shape"] = [(3,), (3, 3), (3, 4)]
        arg_dict["diagonal"] = [1, 0, -1]
        arg_dict["device"] = ["cpu", "cuda"]
        for args in GenArgList(arg_dict):
            _test_diag_forward(test_case, *args)

    def test_diag_backward(test_case):
        # First element of each arg tuple is the helper to dispatch to.
        arg_dict = OrderedDict()
        arg_dict["test_fun"] = [
            _test_diag_one_dim_backward,
            _test_diag_other_dim_backward,
            _test_diag_other_dim_non_square_backward,
        ]
        arg_dict["diagonal"] = [1, 0, -1]
        arg_dict["device"] = ["cpu", "cuda"]
        for args in GenArgList(arg_dict):
            test_fun, *rest = args
            test_fun(test_case, *rest)
class Test_Diag_module(flow.unittest.TestCase):
    """Autotest comparison of torch.diag between PyTorch and OneFlow."""

    @autotest()
    def test_diag_one_dim(test_case):
        dev = random_device()
        vec = random_pytorch_tensor(ndim=1, dim0=random())
        return torch.diag(vec.to(dev))

    @autotest()
    def test_diag_other_dim(test_case):
        dev = random_device()
        mat = random_pytorch_tensor(ndim=2, dim0=random(), dim1=random())
        return torch.diag(mat.to(dev))


if __name__ == "__main__":
Expand Down
103 changes: 22 additions & 81 deletions python/oneflow/test/modules/test_dropout.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,91 +22,32 @@

import oneflow as flow
import oneflow.unittest


def _test_dropout(test_case, shape, device):
    """Dropout with p=0 must be the identity mapping."""
    arr = np.random.randn(*shape)
    layer = flow.nn.Dropout(p=0)
    out = layer(flow.Tensor(arr, device=flow.device(device)))
    test_case.assertTrue(np.allclose(out.numpy(), arr))


def _test_dropout_p1(test_case, shape, device):
    """Dropout with p=1.0 must zero every element."""
    arr = np.random.randn(*shape)
    layer = flow.nn.Dropout(p=1.0)
    out = layer(flow.Tensor(arr, device=flow.device(device)))
    expected = np.zeros(arr.shape, dtype=np.float32)
    test_case.assertTrue(np.allclose(out.numpy(), expected))


def _test_dropout_backward_p0(test_case, shape, device):
    """With p=0, the gradient of dropout(x).sum() w.r.t. x is all ones."""
    arr = np.random.randn(*shape)
    layer = flow.nn.Dropout(p=0)
    x = flow.Tensor(arr, device=flow.device(device), requires_grad=True)
    layer(x).sum().backward()
    expected = np.ones(arr.shape, dtype=np.float32)
    test_case.assertTrue(np.allclose(x.grad.numpy(), expected, 1e-05, 1e-05))


def _test_dropout_backward_p1(test_case, shape, device):
    """With p=1, every element is dropped, so the gradient is all zeros."""
    arr = np.random.randn(*shape)
    layer = flow.nn.Dropout(p=1)
    x = flow.Tensor(arr, device=flow.device(device), requires_grad=True)
    layer(x).sum().backward()
    expected = np.zeros(arr.shape, dtype=np.float32)
    test_case.assertTrue(np.allclose(x.grad.numpy(), expected, 1e-05, 1e-05))


def _test_dropout_eval(test_case, shape, device):
    """In eval mode dropout is the identity even when p=1."""
    arr = np.random.randn(*shape)
    layer = flow.nn.Dropout(p=1)
    layer.eval()  # disables dropping entirely
    out = layer(flow.Tensor(arr, device=flow.device(device)))
    test_case.assertTrue(np.allclose(out.numpy(), arr))


def _test_dropout_with_generator(test_case, shape, device):
    """Reseeding the generator must reproduce the same dropout mask."""
    generator = flow.Generator()
    generator.manual_seed(0)
    layer = flow.nn.Dropout(p=0.5, generator=generator)
    x = flow.Tensor(np.random.randn(*shape), device=flow.device(device))
    first = layer(x)
    first.numpy()  # force evaluation before reseeding
    generator.manual_seed(0)
    second = layer(x)
    test_case.assertTrue(np.allclose(first.numpy(), second.numpy()))
from automated_test_util import *


@flow.unittest.skip_unless_1n1d()
class TestDropout(flow.unittest.TestCase):
    """Parameterized checks for flow.nn.Dropout."""

    # NOTE: this method was named ``test_transpose`` — a copy/paste slip from
    # another test module; renamed to describe what it actually exercises.
    def test_dropout_cases(test_case):
        """Run every dropout helper over a grid of shapes and devices."""
        arg_dict = OrderedDict()
        arg_dict["test_functions"] = [
            _test_dropout,
            _test_dropout_p1,
            _test_dropout_backward_p0,
            _test_dropout_backward_p1,
            _test_dropout_eval,
            _test_dropout_with_generator,
        ]
        arg_dict["shape"] = [(2, 3), (2, 3, 4), (2, 3, 4, 5)]
        arg_dict["device"] = ["cpu", "cuda"]
        for arg in GenArgList(arg_dict):
            arg[0](test_case, *arg[1:])
@autotest()
def test_dropout(test_case):
    """p=0 dropout must match PyTorch (identity) on random tensors."""
    dev = random_device()
    inp = random_pytorch_tensor(ndim=random(), dim0=random()).to(dev)
    return torch.nn.Dropout(p=0)(inp)

@autotest()
def test_dropout_p1(test_case):
    """p=1.0 dropout must match PyTorch (all zeros) on random tensors.

    NOTE(review): dropout drops elements at random, so PyTorch and OneFlow
    cannot be forced to drop the same elements; only the deterministic
    cases p=0 and p=1 are comparable in an autotest.
    """
    device = random_device()
    x = random_pytorch_tensor(ndim=random(), dim0=random()).to(device)
    m = torch.nn.Dropout(p=1.0)
    return m(x)

@autotest()
def test_dropout_eval(test_case):
    """eval() disables dropout, so even p=1.0 behaves as identity."""
    dev = random_device()
    inp = random_pytorch_tensor(ndim=random(), dim0=random()).to(dev)
    layer = torch.nn.Dropout(p=1.0)
    layer.eval()
    return layer(inp)


if __name__ == "__main__":
Expand Down
14 changes: 14 additions & 0 deletions python/oneflow/test/tensor/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1369,6 +1369,20 @@ def test_round_tensor_with_random_data(test_case):
x = random_pytorch_tensor().to(device)
return x.round()

@flow.unittest.skip_unless_1n1d()
@autotest()
def test_tensor_diag_one_dim(test_case):
    """Tensor.diag() on a random 1-D tensor matches PyTorch."""
    dev = random_device()
    vec = random_pytorch_tensor(ndim=1, dim0=random())
    return vec.to(dev).diag()

@flow.unittest.skip_unless_1n1d()
@autotest()
def test_tensor_diag_other_dim(test_case):
    """Tensor.diag() on a random 2-D tensor matches PyTorch."""
    dev = random_device()
    mat = random_pytorch_tensor(ndim=2, dim0=random(), dim1=random())
    return mat.to(dev).diag()


if __name__ == "__main__":
    # Allow running this test module directly.
    unittest.main()