Skip to content

Commit a55bbdf

Browse files
authored
Correct test cases that were never executed (#70206)
* Correct test cases that were never executed * change data type * skip the unit tests for complex types when using xpu
1 parent 13b97b7 commit a55bbdf

File tree

1 file changed

+46
-17
lines changed

1 file changed

+46
-17
lines changed

test/legacy_test/test_diag_v2.py

Lines changed: 46 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def setUp(self):
3232
self.init_dtype()
3333
self.init_attrs()
3434
self.init_input_output()
35+
self.set_input_output()
3536

3637
def init_dtype(self):
3738
self.dtype = np.float64
@@ -41,47 +42,61 @@ def init_attrs(self):
4142
self.padding_value = 0.0
4243

4344
def init_input_output(self):
44-
x = np.random.rand(10, 10).astype(self.dtype)
45-
out = np.diag(x, self.offset)
45+
self.x = np.random.rand(10, 10).astype(self.dtype)
46+
self.out = np.diag(self.x, self.offset)
4647

48+
def set_input_output(self):
4749
self.attrs = {
4850
'offset': self.offset,
4951
'padding_value': self.padding_value,
5052
}
51-
self.inputs = {'X': x}
52-
self.outputs = {'Out': out}
53+
54+
self.inputs = {'X': self.x}
55+
self.outputs = {'Out': self.out}
5356

5457
def test_check_output(self):
5558
paddle.enable_static()
56-
self.check_output(check_pir=True, check_prim_pir=True)
59+
if self.dtype == np.complex64 or self.dtype == np.complex128:
60+
self.check_output(check_pir=True)
61+
else:
62+
self.check_output(check_pir=True, check_prim_pir=True)
5763

5864
def test_check_grad(self):
5965
paddle.enable_static()
60-
self.check_grad(['X'], 'Out', check_pir=True, check_prim_pir=True)
66+
if self.dtype == np.complex64 or self.dtype == np.complex128:
67+
self.check_grad(['X'], 'Out', check_pir=True)
68+
else:
69+
self.check_grad(['X'], 'Out', check_pir=True, check_prim_pir=True)
6170

6271

6372
class TestDiagV2OpCase1(TestDiagV2Op):
64-
def init_config(self):
73+
def init_attrs(self):
74+
super().init_attrs()
6575
self.offset = 1
66-
self.out = np.diag(self.x, self.offset)
6776

6877

6978
class TestDiagV2OpCase2(TestDiagV2Op):
70-
def init_config(self):
79+
def init_attrs(self):
80+
super().init_attrs()
7181
self.offset = -1
72-
self.out = np.diag(self.x, self.offset)
7382

7483

7584
class TestDiagV2OpCase3(TestDiagV2Op):
76-
def init_config(self):
77-
self.x = np.random.randint(-10, 10, size=(10, 10)).astype("float64")
85+
def init_input_output(self):
86+
self.x = np.random.randint(-10, 10, size=(10, 10)).astype(self.dtype)
7887
self.out = np.diag(self.x, self.offset)
7988

8089

8190
class TestDiagV2OpCase4(TestDiagV2Op):
82-
def init_config(self):
83-
self.x = np.random.rand(100)
91+
def init_dtype(self):
92+
self.dtype = np.float32
93+
94+
def init_attrs(self):
95+
super().init_attrs()
8496
self.padding_value = 2
97+
98+
def init_input_output(self):
99+
self.x = np.random.rand(100).astype(self.dtype)
85100
n = self.x.size
86101
self.out = (
87102
self.padding_value * np.ones((n, n))
@@ -353,21 +368,35 @@ def test_check_grad(self):
353368
)
354369

355370

371+
@unittest.skipIf(
372+
core.is_compiled_with_xpu(),
373+
"xpu does not support complex64",
374+
)
356375
class TestDiagV2Complex64OP(TestDiagV2Op):
357-
def init_config(self):
376+
def init_dtype(self):
377+
self.dtype = np.complex64
378+
379+
def init_input_output(self):
358380
self.x = (
359381
np.random.randint(-10, 10, size=(10, 10))
360382
+ 1j * np.random.randint(-10, 10, size=(10, 10))
361-
).astype("complex64")
383+
).astype(self.dtype)
362384
self.out = np.diag(self.x, self.offset)
363385

364386

387+
@unittest.skipIf(
388+
core.is_compiled_with_xpu(),
389+
"xpu does not support complex128",
390+
)
365391
class TestDiagV2Complex128OP(TestDiagV2Op):
392+
def init_dtype(self):
393+
self.dtype = np.complex128
394+
366395
def init_config(self):
367396
self.x = (
368397
np.random.randint(-10, 10, size=(10, 10))
369398
+ 1j * np.random.randint(-10, 10, size=(10, 10))
370-
).astype("complex128")
399+
).astype(self.dtype)
371400
self.out = np.diag(self.x, self.offset)
372401

373402

0 commit comments

Comments
 (0)