@@ -32,6 +32,7 @@ def setUp(self):
32
32
self .init_dtype ()
33
33
self .init_attrs ()
34
34
self .init_input_output ()
35
+ self .set_input_output ()
35
36
36
37
def init_dtype (self ):
37
38
self .dtype = np .float64
@@ -41,47 +42,61 @@ def init_attrs(self):
41
42
self .padding_value = 0.0
42
43
43
44
def init_input_output (self ):
44
- x = np .random .rand (10 , 10 ).astype (self .dtype )
45
- out = np .diag (x , self .offset )
45
+ self . x = np .random .rand (10 , 10 ).astype (self .dtype )
46
+ self . out = np .diag (self . x , self .offset )
46
47
48
+ def set_input_output (self ):
47
49
self .attrs = {
48
50
'offset' : self .offset ,
49
51
'padding_value' : self .padding_value ,
50
52
}
51
- self .inputs = {'X' : x }
52
- self .outputs = {'Out' : out }
53
+
54
+ self .inputs = {'X' : self .x }
55
+ self .outputs = {'Out' : self .out }
53
56
54
57
def test_check_output (self ):
55
58
paddle .enable_static ()
56
- self .check_output (check_pir = True , check_prim_pir = True )
59
+ if self .dtype == np .complex64 or self .dtype == np .complex128 :
60
+ self .check_output (check_pir = True )
61
+ else :
62
+ self .check_output (check_pir = True , check_prim_pir = True )
57
63
58
64
def test_check_grad (self ):
59
65
paddle .enable_static ()
60
- self .check_grad (['X' ], 'Out' , check_pir = True , check_prim_pir = True )
66
+ if self .dtype == np .complex64 or self .dtype == np .complex128 :
67
+ self .check_grad (['X' ], 'Out' , check_pir = True )
68
+ else :
69
+ self .check_grad (['X' ], 'Out' , check_pir = True , check_prim_pir = True )
61
70
62
71
63
72
class TestDiagV2OpCase1 (TestDiagV2Op ):
64
- def init_config (self ):
73
+ def init_attrs (self ):
74
+ super ().init_attrs ()
65
75
self .offset = 1
66
- self .out = np .diag (self .x , self .offset )
67
76
68
77
69
78
class TestDiagV2OpCase2 (TestDiagV2Op ):
70
- def init_config (self ):
79
+ def init_attrs (self ):
80
+ super ().init_attrs ()
71
81
self .offset = - 1
72
- self .out = np .diag (self .x , self .offset )
73
82
74
83
75
84
class TestDiagV2OpCase3 (TestDiagV2Op ):
76
- def init_config (self ):
77
- self .x = np .random .randint (- 10 , 10 , size = (10 , 10 )).astype ("float64" )
85
+ def init_input_output (self ):
86
+ self .x = np .random .randint (- 10 , 10 , size = (10 , 10 )).astype (self . dtype )
78
87
self .out = np .diag (self .x , self .offset )
79
88
80
89
81
90
class TestDiagV2OpCase4 (TestDiagV2Op ):
82
- def init_config (self ):
83
- self .x = np .random .rand (100 )
91
+ def init_dtype (self ):
92
+ self .dtype = np .float32
93
+
94
+ def init_attrs (self ):
95
+ super ().init_attrs ()
84
96
self .padding_value = 2
97
+
98
+ def init_input_output (self ):
99
+ self .x = np .random .rand (100 ).astype (self .dtype )
85
100
n = self .x .size
86
101
self .out = (
87
102
self .padding_value * np .ones ((n , n ))
@@ -353,21 +368,35 @@ def test_check_grad(self):
353
368
)
354
369
355
370
371
+ @unittest .skipIf (
372
+ core .is_compiled_with_xpu (),
373
+ "xpu does not support complex64" ,
374
+ )
356
375
class TestDiagV2Complex64OP (TestDiagV2Op ):
357
- def init_config (self ):
376
+ def init_dtype (self ):
377
+ self .dtype = np .complex64
378
+
379
+ def init_input_output (self ):
358
380
self .x = (
359
381
np .random .randint (- 10 , 10 , size = (10 , 10 ))
360
382
+ 1j * np .random .randint (- 10 , 10 , size = (10 , 10 ))
361
- ).astype ("complex64" )
383
+ ).astype (self . dtype )
362
384
self .out = np .diag (self .x , self .offset )
363
385
364
386
387
+ @unittest .skipIf (
388
+ core .is_compiled_with_xpu (),
389
+ "xpu does not support complex128" ,
390
+ )
365
391
class TestDiagV2Complex128OP (TestDiagV2Op ):
392
+ def init_dtype (self ):
393
+ self .dtype = np .complex128
394
+
366
395
def init_config (self ):
367
396
self .x = (
368
397
np .random .randint (- 10 , 10 , size = (10 , 10 ))
369
398
+ 1j * np .random .randint (- 10 , 10 , size = (10 , 10 ))
370
- ).astype ("complex128" )
399
+ ).astype (self . dtype )
371
400
self .out = np .diag (self .x , self .offset )
372
401
373
402
0 commit comments