44
55
66def conv3dtranspose_forward_naive (input_ , filter_ , conv3dtranspose_param ):
7- # [2, 3, 5, 5, 5]
87 in_n , in_c , in_d , in_h , in_w = input_ .shape
9- # [3, 6, 3, 3, 3]
108 f_c , out_c , f_d , f_h , f_w = filter_ .shape
119 assert in_c == f_c
1210
1311 stride , pad = conv3dtranspose_param ['stride' ], conv3dtranspose_param ['pad' ]
1412 out_d = (in_d - 1 ) * stride [0 ] + f_d
1513 out_h = (in_h - 1 ) * stride [1 ] + f_h
1614 out_w = (in_w - 1 ) * stride [2 ] + f_w
17-
1815 out = np .zeros ((in_n , out_c , out_d , out_h , out_w ))
1916
2017 for n in range (in_n ):
@@ -33,23 +30,22 @@ def conv3dtranspose_forward_naive(input_, filter_, conv3dtranspose_param):
3330 j1 , j2 = j * stride [2 ], j * stride [2 ] + f_w
3431 out [n , k , d1 :d2 , i1 :i2 , j1 :j2 ] += tmp_out
3532
33+ out = out [:, :, pad [0 ]:out_d - pad [0 ], pad [1 ]:out_h - pad [1 ], pad [2 ]:out_w -
34+ pad [2 ]]
3635 return out
3736
3837
3938class TestConv3dTransposeOp (OpTest ):
4039 def setUp (self ):
4140 # init as conv transpose
4241 self .init_op_type ()
43-
44- # [2, 3, 5, 5, 5] -> kernel [3, 6, 3, 3, 3] -> output [2, 6, 7, 7, 7]
4542 self .init_test_case ()
4643
4744 conv3dtranspose_param = {'stride' : self .stride , 'pad' : self .pad }
4845 input_ = np .random .random (self .input_size ).astype ("float32" )
4946 filter_ = np .random .random (self .filter_size ).astype ("float32" )
5047 output = conv3dtranspose_forward_naive (
5148 input_ , filter_ , conv3dtranspose_param ).astype ("float32" )
52- # print 'deconv output py', output, output.shape
5349
5450 self .inputs = {'Input' : input_ , 'Filter' : filter_ }
5551 self .attrs = {
@@ -60,7 +56,6 @@ def setUp(self):
6056 self .outputs = {'Output' : output }
6157
6258 def test_check_output (self ):
63- print 'check output here'
6459 self .check_output ()
6560
6661 def test_check_grad (self ):
@@ -85,13 +80,33 @@ def init_test_case(self):
8580 self .pad = [0 , 0 , 0 ]
8681 self .stride = [1 , 1 , 1 ]
8782 self .dilations = [1 , 1 , 1 ]
88- self .input_size = [2 , 3 , 5 , 5 , 5 ] # NCHW
83+ self .input_size = [2 , 3 , 5 , 5 , 5 ] # NCDHW
8984 f_c = self .input_size [1 ]
9085 self .filter_size = [f_c , 6 , 3 , 3 , 3 ]
9186
9287 def init_op_type (self ):
9388 self .op_type = "conv3d_transpose"
9489
9590
91+ class TestWithPad (TestConv3dTransposeOp ):
92+ def init_test_case (self ):
93+ self .pad = [1 , 1 , 1 ]
94+ self .stride = [1 , 1 , 1 ]
95+ self .dilations = [1 , 1 , 1 ]
96+ self .input_size = [2 , 3 , 5 , 5 , 5 ] # NCDHW
97+ f_c = self .input_size [1 ]
98+ self .filter_size = [f_c , 6 , 3 , 3 , 3 ]
99+
100+
101+ class TestWithStride (TestConv3dTransposeOp ):
102+ def init_test_case (self ):
103+ self .pad = [1 , 1 , 1 ]
104+ self .stride = [2 , 2 , 2 ]
105+ self .dilations = [1 , 1 , 1 ]
106+ self .input_size = [2 , 3 , 5 , 5 , 5 ] # NCDHW
107+ f_c = self .input_size [1 ]
108+ self .filter_size = [f_c , 6 , 3 , 3 , 3 ]
109+
110+
96111if __name__ == '__main__' :
97112 unittest .main ()
0 commit comments