Skip to content

Commit ae11334

Browse files
committed
modify as review
1 parent 760b5a5 commit ae11334

File tree

2 files changed

+34
-31
lines changed

2 files changed

+34
-31
lines changed

python/paddle/fluid/tests/unittests/test_determinant_op.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,16 @@ def test_check_grad(self):
4040
pass
4141

4242
def init_data(self):
43-
self.case = np.random.randn(3, 3, 3, 3, 3).astype('float64')
43+
np.random.seed(0)
44+
self.case = np.random.rand(3, 3, 3, 3, 3).astype('float64')
4445
self.inputs = {'Input': self.case}
4546
self.target = np.linalg.det(self.case)
4647

4748

4849
class TestDeterminantOpCase1(TestDeterminantOp):
4950
def init_data(self):
50-
self.case = np.random.randn(3, 3, 3, 3).astype('float32')
51+
np.random.seed(0)
52+
self.case = np.random.rand(3, 3, 3, 3).astype(np.float32)
5153
self.inputs = {'Input': self.case}
5254
self.target = np.linalg.det(self.case)
5355

@@ -57,7 +59,8 @@ def test_check_grad(self):
5759

5860
class TestDeterminantOpCase2(TestDeterminantOp):
5961
def init_data(self):
60-
self.case = np.random.randint(0, 2, (4, 2, 4, 4)).astype('float64')
62+
np.random.seed(0)
63+
self.case = np.random.rand(4, 2, 4, 4).astype('float64')
6164
self.inputs = {'Input': self.case}
6265
self.target = np.linalg.det(self.case)
6366

@@ -68,7 +71,8 @@ def test_check_grad(self):
6871
class TestDeterminantAPI(unittest.TestCase):
6972
def setUp(self):
7073
self.shape = [3, 3, 3, 3]
71-
self.x = np.random.random((3, 3, 3, 3)).astype(np.float32)
74+
np.random.seed(0)
75+
self.x = np.random.rand(3, 3, 3, 3).astype(np.float32)
7276
self.place = paddle.CPUPlace()
7377

7478
def test_api_static(self):
@@ -106,22 +110,25 @@ def test_check_grad(self):
106110
pass
107111

108112
def init_data(self):
109-
self.case = np.random.randn(3, 3, 3, 3).astype('float64')
113+
np.random.seed(0)
114+
self.case = np.random.rand(3, 3, 3, 3).astype('float64')
110115
self.inputs = {'Input': self.case}
111116
self.target = np.array(np.linalg.slogdet(self.case))
112117

113118

114119
class TestSlogDeterminantOpCase1(TestSlogDeterminantOp):
115120
def init_data(self):
116-
self.case = np.random.randn(3, 3, 3, 3).astype('float32')
121+
np.random.seed(0)
122+
self.case = np.random.rand(2, 2, 5, 5).astype(np.float32)
117123
self.inputs = {'Input': self.case}
118124
self.target = np.array(np.linalg.slogdet(self.case))
119125

120126

121127
class TestSlogDeterminantAPI(unittest.TestCase):
122128
def setUp(self):
123129
self.shape = [3, 3, 3, 3]
124-
self.x = np.random.random((3, 3, 3, 3)).astype(np.float32)
130+
np.random.seed(0)
131+
self.x = np.random.rand(3, 3, 3, 3).astype(np.float32)
125132
self.place = paddle.CPUPlace()
126133

127134
def test_api_static(self):

python/paddle/tensor/linalg.py

Lines changed: 20 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1378,21 +1378,19 @@ def det(x):
13781378
if in_dygraph_mode():
13791379
return core.ops.determinant(x)
13801380

1381-
def _check_input(input):
1382-
check_dtype(x.dtype, 'Input', ['float32', 'float64'], 'det')
1381+
check_dtype(x.dtype, 'Input', ['float32', 'float64'], 'det')
13831382

1384-
input_shape = list(x.shape)
1385-
assert len(input_shape) >= 2, \
1386-
"The x must be at least 2-dimensional, " \
1387-
"but received Input x's dimensional: %s.\n" % \
1388-
len(input_shape)
1383+
input_shape = list(x.shape)
1384+
assert len(input_shape) >= 2, \
1385+
"The x must be at least 2-dimensional, " \
1386+
"but received Input x's dimensional: %s.\n" % \
1387+
len(input_shape)
13891388

1390-
assert (input_shape[-1] == input_shape[-2]), \
1391-
"Expect squared input," \
1392-
"but received %s by %s matrix.\n" \
1393-
%(input_shape[-2], input_shape[-1]) \
1389+
assert (input_shape[-1] == input_shape[-2]), \
1390+
"Expect squared input," \
1391+
"but received %s by %s matrix.\n" \
1392+
%(input_shape[-2], input_shape[-1]) \
13941393

1395-
_check_input(input)
13961394
helper = LayerHelper('determinant', **locals())
13971395
out = helper.create_variable_for_type_inference(dtype=x.dtype)
13981396

@@ -1435,21 +1433,19 @@ def slogdet(x):
14351433
if in_dygraph_mode():
14361434
return core.ops.slogdeterminant(x)
14371435

1438-
def _check_input(x):
1439-
check_dtype(x.dtype, 'Input', ['float32', 'float64'], 'slogdet')
1436+
check_dtype(x.dtype, 'Input', ['float32', 'float64'], 'slogdet')
14401437

1441-
input_shape = list(x.shape)
1442-
assert len(input_shape) >= 2, \
1443-
"The x must be at least 2-dimensional, " \
1444-
"but received Input x's dimensional: %s.\n" % \
1445-
len(input_shape)
1438+
input_shape = list(x.shape)
1439+
assert len(input_shape) >= 2, \
1440+
"The x must be at least 2-dimensional, " \
1441+
"but received Input x's dimensional: %s.\n" % \
1442+
len(input_shape)
14461443

1447-
assert (input_shape[-1] == input_shape[-2]), \
1448-
"Expect squared input," \
1449-
"but received %s by %s matrix.\n" \
1450-
%(input_shape[-2], input_shape[-1]) \
1444+
assert (input_shape[-1] == input_shape[-2]), \
1445+
"Expect squared input," \
1446+
"but received %s by %s matrix.\n" \
1447+
%(input_shape[-2], input_shape[-1]) \
14511448

1452-
_check_input(x)
14531449
helper = LayerHelper('slogdeterminant', **locals())
14541450
out = helper.create_variable_for_type_inference(dtype=x.dtype)
14551451

0 commit comments

Comments
 (0)