@@ -39,12 +39,24 @@ def _ortho_init(shape, dtype, partition_info=None):
39
39
return (scale * q [:shape [0 ], :shape [1 ]]).astype (np .float32 )
40
40
return _ortho_init
41
41
42
- def conv (x , scope , * , nf , rf , stride , pad = 'VALID' , init_scale = 1.0 ):
42
+ def conv (x , scope , * , nf , rf , stride , pad = 'VALID' , init_scale = 1.0 , data_format = 'NHWC' ):
43
+ if data_format == 'NHWC' :
44
+ channel_ax = 3
45
+ strides = [1 , stride , stride , 1 ]
46
+ bshape = [1 , 1 , 1 , nf ]
47
+ elif data_format == 'NCHW' :
48
+ channel_ax = 1
49
+ strides = [1 , 1 , stride , stride ]
50
+ bshape = [1 , nf , 1 , 1 ]
51
+ else :
52
+ raise NotImplementedError
53
+ nin = x .get_shape ()[channel_ax ].value
54
+ wshape = [rf , rf , nin , nf ]
43
55
with tf .variable_scope (scope ):
44
- nin = x . get_shape ()[ 3 ]. value
45
- w = tf .get_variable ("w " , [rf , rf , nin , nf ], initializer = ortho_init ( init_scale ))
46
- b = tf .get_variable ( "b" , [ nf ], initializer = tf . constant_initializer ( 0.0 ) )
47
- return tf .nn .conv2d (x , w , strides = [ 1 , stride , stride , 1 ], padding = pad ) + b
56
+ w = tf . get_variable ( "w" , wshape , initializer = ortho_init ( init_scale ))
57
+ b = tf .get_variable ("b " , [1 , nf , 1 , 1 ], initializer = tf . constant_initializer ( 0.0 ))
58
+ if data_format == 'NHWC' : b = tf .reshape ( b , bshape )
59
+ return b + tf .nn .conv2d (x , w , strides = strides , padding = pad , data_format = data_format )
48
60
49
61
def fc (x , scope , nh , * , init_scale = 1.0 , init_bias = 0.0 ):
50
62
with tf .variable_scope (scope ):
0 commit comments