Skip to content

Commit

Permalink
pep8
Browse files Browse the repository at this point in the history
  • Loading branch information
gwding committed Nov 13, 2017
1 parent 11eca2c commit 54606eb
Showing 1 changed file with 17 additions and 13 deletions.
30 changes: 17 additions & 13 deletions draw_convnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
Black = 0.


def add_layer(patches, colors, size=(24,24), num=5,
def add_layer(patches, colors, size=(24, 24), num=5,
top_left=[0, 0],
loc_diff=[3, -3],
):
Expand All @@ -69,16 +69,18 @@ def add_mapping(patches, colors, start_ratio, end_ratio, patch_size, ind_bgn,

start_loc = top_left_list[ind_bgn] \
+ (num_show_list[ind_bgn] - 1) * np.array(loc_diff_list[ind_bgn]) \
+ np.array([start_ratio[0] * (size_list[ind_bgn][1] - patch_size[1]),\
- start_ratio[1] * (size_list[ind_bgn][0] - patch_size[0])])
+ np.array([start_ratio[0] * (size_list[ind_bgn][1] - patch_size[1]),
- start_ratio[1] * (size_list[ind_bgn][0] - patch_size[0])]
)




end_loc = top_left_list[ind_bgn + 1] \
+ (num_show_list[ind_bgn + 1] - 1) * np.array(loc_diff_list[ind_bgn + 1]) \
+ np.array([end_ratio[0] * size_list[ind_bgn+1][1],\
- end_ratio[1] * size_list[ind_bgn+1][0]])
+ (num_show_list[ind_bgn + 1] - 1) * np.array(
loc_diff_list[ind_bgn + 1]) \
+ np.array([end_ratio[0] * size_list[ind_bgn + 1][1],
- end_ratio[1] * size_list[ind_bgn + 1][0]])


patches.append(Rectangle(start_loc, patch_size[1], -patch_size[0]))
Expand Down Expand Up @@ -116,7 +118,7 @@ def label(xy, text, xy_off=[0, 4]):

############################
# conv layers
size_list = [(32, 32), (18, 18), (10,10), (6,6), (4,4)]
size_list = [(32, 32), (18, 18), (10, 10), (6, 6), (4, 4)]
num_list = [3, 32, 32, 48, 48]
x_diff_list = [0, layer_width, layer_width, layer_width, layer_width]
text_list = ['Inputs'] + ['Feature\nmaps'] * (len(size_list) - 1)
Expand All @@ -136,21 +138,23 @@ def label(xy, text, xy_off=[0, 4]):
# in between layers
start_ratio_list = [[0.4, 0.5], [0.4, 0.8], [0.4, 0.5], [0.4, 0.8]]
end_ratio_list = [[0.4, 0.5], [0.4, 0.8], [0.4, 0.5], [0.4, 0.8]]
patch_size_list = [(5,5), (2,2), (5,5), (2,2)]
patch_size_list = [(5, 5), (2, 2), (5, 5), (2, 2)]
ind_bgn_list = range(len(patch_size_list))
text_list = ['Convolution', 'Max-pooling', 'Convolution', 'Max-pooling']

for ind in range(len(patch_size_list)):
add_mapping(patches, colors, start_ratio_list[ind], end_ratio_list[ind],
patch_size_list[ind], ind,
top_left_list, loc_diff_list, num_show_list, size_list)
add_mapping(
patches, colors, start_ratio_list[ind], end_ratio_list[ind],
patch_size_list[ind], ind,
top_left_list, loc_diff_list, num_show_list, size_list)
label(top_left_list[ind], text_list[ind] + '\n{}x{} kernel'.format(
patch_size_list[ind][0], patch_size_list[ind][1]), xy_off=[26, -65])
patch_size_list[ind][0], patch_size_list[ind][1]), xy_off=[26, -65]
)


############################
# fully connected layers
size_list = [(fc_unit_size,fc_unit_size)]*3
size_list = [(fc_unit_size, fc_unit_size)] * 3
num_list = [768, 500, 2]
num_show_list = list(map(min, num_list, [NumFcMax] * len(num_list)))
x_diff_list = [sum(x_diff_list) + layer_width, layer_width, layer_width]
Expand Down

0 comments on commit 54606eb

Please sign in to comment.