@@ -66,7 +66,7 @@ def info(self):
66
66
class ConvLayer (Layer ):
67
67
LayerTiming = Timing ()
68
68
69
- def __init__ (self , shape , stride = 1 , padding = "SAME" , parent = None ):
69
+ def __init__ (self , shape , stride = 1 , padding = None , parent = None ):
70
70
"""
71
71
:param shape: shape[0] = shape of previous layer c x h x w
72
72
shape[1] = shape of current layer's weight f x h x w
@@ -79,15 +79,19 @@ def __init__(self, shape, stride=1, padding="SAME", parent=None):
79
79
shape = _parent .shape
80
80
Layer .__init__ (self , shape )
81
81
self .stride = stride
82
+ if padding is None :
83
+ padding = "SAME"
82
84
if isinstance (padding , str ):
83
85
if padding .upper () == "VALID" :
84
86
self .padding = 0
85
87
self .pad_flag = "VALID"
86
88
else :
87
89
self .padding = self .pad_flag = "SAME"
88
- else :
89
- self .padding = int ( padding )
90
+ elif isinstance ( padding , int ) :
91
+ self .padding = padding
90
92
self .pad_flag = "VALID"
93
+ else :
94
+ raise ValueError ("Padding should be 'SAME' or 'VALID' or integer" )
91
95
self .parent = parent
92
96
if len (shape ) == 1 :
93
97
self .n_channels = self .n_filters = self .out_h = self .out_w = None
@@ -136,7 +140,7 @@ def __init__(self, shape, stride=1, padding="SAME"):
136
140
conv_layer .__init__ (self , shape , stride , padding )
137
141
138
142
def _conv (self , x , w ):
139
- return tf .nn .conv2d (x , w , strides = [self .stride ] * 4 , padding = self .pad_flag )
143
+ return tf .nn .conv2d (x , w , strides = [1 , self .stride , self . stride , 1 ] , padding = self .pad_flag )
140
144
141
145
def _activate (self , x , w , bias , predict ):
142
146
res = self ._conv (x , w ) + bias
@@ -380,25 +384,21 @@ def get_layer_by_name(self, name, parent, current_dimension, *args, **kwargs):
380
384
return _layer , (_current , _next )
381
385
382
386
if __name__ == '__main__' :
383
- with tf .Session ().as_default ():
387
+ with tf .Session ().as_default () as sess :
384
388
# NN Process
385
389
nn_x = np .array ([
386
390
[ 0 , 1 , 2 , 1 , 0 ],
387
391
[- 1 , - 2 , 0 , 2 , 1 ],
388
392
[ 0 , 1 , - 2 , - 1 , 2 ],
389
393
[ 1 , 2 , - 1 , 0 , - 2 ]
390
- ])
394
+ ], dtype = np . float32 )
391
395
nn_w = np .array ([
392
396
[- 2 , - 1 , 0 , 1 , 2 ],
393
397
[ 2 , 1 , 0 , - 1 , - 2 ]
394
- ]).T
395
- nn_b = 1
396
- nn_id = Identical ([nn_x .shape [1 ], 1 ])
397
- nn_r1 = nn_id .activate (nn_x , nn_w , nn_b )
398
- # nn_norm = Normalize(nn_id, [None, 2])
399
- # nn_norm.activate(nn_r1, None)
400
- print (nn_r1 .eval ())
401
-
398
+ ], dtype = np .float32 ).T
399
+ nn_b = 1.
400
+ nn_id = Identical ([nn_x .shape [1 ], 2 ])
401
+ print (nn_id .activate (nn_x , nn_w , nn_b ).eval ())
402
402
# CNN Process
403
403
conv_x = np .array ([
404
404
[
@@ -408,17 +408,17 @@ def get_layer_by_name(self, name, parent, current_dimension, *args, **kwargs):
408
408
[- 2 , 1 , - 1 , 0 ]
409
409
]
410
410
], dtype = np .float32 ).reshape (1 , 4 , 4 , 1 )
411
- # Using "VALID" Padding -> out_h = out_w = 2
412
- conv_id = ConvIdentical ([(conv_x .shape [1 :], [2 , 3 , 3 ])], padding = "VALID" )
413
411
conv_w = np .array ([
414
412
[[ 1 , 0 , 1 ],
415
413
[- 1 , 0 , 1 ],
416
414
[ 1 , 0 , - 1 ]],
417
415
[[0 , 1 , 0 ],
418
416
[1 , 0 , - 1 ],
419
417
[0 , - 1 , 1 ]]
420
- ]).transpose ([1 , 2 , 0 ])[..., None , :]
421
- conv_b = np .array ([1 , - 1 ])
418
+ ], dtype = np .float32 ).transpose ([1 , 2 , 0 ])[..., None , :]
419
+ conv_b = np .array ([1 , - 1 ], dtype = np .float32 )
420
+ # Using "VALID" Padding -> out_h = out_w = 2
421
+ conv_id = ConvIdentical ([(conv_x .shape [1 :], [2 , 3 , 3 ])], padding = "VALID" )
422
422
print (conv_id .activate (conv_x , conv_w , conv_b ).eval ())
423
423
conv_x = np .array ([
424
424
[
@@ -435,5 +435,5 @@ def get_layer_by_name(self, name, parent, current_dimension, *args, **kwargs):
435
435
[ 0 1 -1 2 0 ]
436
436
[ 0 0 0 0 0 ] ]
437
437
"""
438
- conv_id = ConvIdentical ([(conv_x .shape [1 :], [2 , 3 , 3 ])], padding = "SAME" )
438
+ conv_id = ConvIdentical ([(conv_x .shape [1 :], [2 , 3 , 3 ])], padding = 1 , stride = 2 )
439
439
print (conv_id .activate (conv_x , conv_w , conv_b ).eval ())
0 commit comments