Skip to content

Commit 11d6555

Browse files
committed
after comments
1 parent bf13e18 commit 11d6555

File tree

1 file changed

+48
-0
lines changed

1 file changed

+48
-0
lines changed

1808_Neural_Networks/utils.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import numpy as np
2+
import matplotlib.pyplot as plt
3+
4+
5+
def draw_neural_net(ax, left, right, bottom, top, layer_sizes):
6+
'''
7+
Draw a neural network cartoon using matplotilb.
8+
9+
:usage:
10+
>>> fig = plt.figure(figsize=(12, 12))
11+
>>> draw_neural_net(fig.gca(), .1, .9, .1, .9, [4, 7, 2])
12+
13+
:parameters:
14+
- ax : matplotlib.axes.AxesSubplot
15+
The axes on which to plot the cartoon (get e.g. by plt.gca())
16+
- left : float
17+
The center of the leftmost node(s) will be placed here
18+
- right : float
19+
The center of the rightmost node(s) will be placed here
20+
- bottom : float
21+
The center of the bottommost node(s) will be placed here
22+
- top : float
23+
The center of the topmost node(s) will be placed here
24+
- layer_sizes : list of int
25+
List of layer sizes, including input and output dimensionality
26+
'''
27+
n_layers = len(layer_sizes)
28+
v_spacing = (top - bottom)/float(max(layer_sizes))
29+
h_spacing = (right - left)/float(len(layer_sizes) - 1)
30+
# Nodes
31+
for n, layer_size in enumerate(layer_sizes):
32+
layer_top = v_spacing*(layer_size - 1)/2. + (top + bottom)/2.
33+
for m in range(layer_size):
34+
circle = plt.Circle((n*h_spacing + left, layer_top - m*v_spacing), v_spacing/4.,
35+
color='w', ec='k', zorder=4)
36+
ax.add_artist(circle)
37+
# Edges
38+
for n, (layer_size_a, layer_size_b) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
39+
layer_top_a = v_spacing*(layer_size_a - 1)/2. + (top + bottom)/2.
40+
layer_top_b = v_spacing*(layer_size_b - 1)/2. + (top + bottom)/2.
41+
for m in range(layer_size_a):
42+
for o in range(layer_size_b):
43+
line = plt.Line2D([n*h_spacing + left, (n + 1)*h_spacing + left],
44+
[layer_top_a - m*v_spacing, layer_top_b - o*v_spacing], c='k')
45+
ax.add_artist(line)
46+
47+
return ax
48+

0 commit comments

Comments
 (0)