@@ -39,13 +39,14 @@ def output_hist(out, lam, a, b):
3939
4040
4141class TestPoissonOp1 (OpTest ):
42+
4243 def setUp (self ):
4344 self .op_type = "poisson"
4445 self .config ()
4546
4647 self .attrs = {}
47- self .inputs = {'X' : np .full ([1024 , 1024 ], self .lam , dtype = self .dtype )}
48- self .outputs = {'Out' : np .ones ([1024 , 1024 ], dtype = self .dtype )}
48+ self .inputs = {'X' : np .full ([2048 , 1024 ], self .lam , dtype = self .dtype )}
49+ self .outputs = {'Out' : np .ones ([2048 , 1024 ], dtype = self .dtype )}
4950
5051 def config (self ):
5152 self .lam = 10
@@ -55,10 +56,8 @@ def config(self):
5556
5657 def verify_output (self , outs ):
5758 hist , prob = output_hist (np .array (outs [0 ]), self .lam , self .a , self .b )
58- self .assertTrue (
59- np .allclose (
60- hist , prob , rtol = 0.01 ),
61- "actual: {}, expected: {}" .format (hist , prob ))
59+ self .assertTrue (np .allclose (hist , prob , rtol = 0.01 ),
60+ "actual: {}, expected: {}" .format (hist , prob ))
6261
6362 def test_check_output (self ):
6463 self .check_output_customized (self .verify_output )
@@ -67,22 +66,23 @@ def test_check_grad_normal(self):
6766 self .check_grad (
6867 ['X' ],
6968 'Out' ,
70- user_defined_grads = [np .zeros (
71- [1024 , 1024 ], dtype = self .dtype )],
69+ user_defined_grads = [np .zeros ([2048 , 1024 ], dtype = self .dtype )],
7270 user_defined_grad_outputs = [
73- np .random .rand (1024 , 1024 ).astype (self .dtype )
71+ np .random .rand (2048 , 1024 ).astype (self .dtype )
7472 ])
7573
7674
7775class TestPoissonOp2 (TestPoissonOp1 ):
76+
7877 def config (self ):
7978 self .lam = 5
8079 self .a = 1
81- self .b = 9
80+ self .b = 8
8281 self .dtype = "float32"
8382
8483
8584class TestPoissonAPI (unittest .TestCase ):
85+
8686 def test_static (self ):
8787 with paddle .static .program_guard (paddle .static .Program (),
8888 paddle .static .Program ()):
0 commit comments