
Commit

Add multilayer perceptron.
trekhleb committed Dec 20, 2018
1 parent 776f77f commit 80c5449
Showing 4 changed files with 124 additions and 165 deletions.
133 changes: 113 additions & 20 deletions homemade/neural_network/multilayer_perceptron.py
@@ -1,7 +1,7 @@
import numpy as np
from scipy.optimize import minimize
from ..utils.features import prepare_for_training
from ..utils.hypothesis import sigmoid
from ..utils.hypothesis import sigmoid, sigmoid_gradient


class MultilayerPerceptron:
@@ -76,7 +76,7 @@ def gradient_descent(data, labels, initial_theta, layers, regularization_param,
method='CG',
# Function that will help to calculate gradient direction on each step.
jac=lambda current_theta: MultilayerPerceptron.gradient_step(
data, labels, current_theta, regularization_param
data, labels, current_theta, layers, regularization_param
),
# Record gradient descent progress for debugging.
callback=lambda current_theta: cost_history.append(MultilayerPerceptron.cost_function(
@@ -94,20 +94,28 @@ def gradient_descent(data, labels, initial_theta, layers, regularization_param,
return optimized_theta, cost_history
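
The hunk above wires scipy.optimize.minimize with a jac callback for the gradient and a callback that records cost history. A minimal sketch of the same wiring pattern on a toy quadratic objective (the objective, its gradient, and the starting point below are stand-ins for illustration, not the network's cost function):

import numpy as np
from scipy.optimize import minimize

cost_history = []

# Toy objective f(x) = x.x with analytic gradient 2x, optimized with conjugate gradients.
result = minimize(
    lambda x: float(x @ x),
    np.array([3.0, -2.0]),
    method='CG',
    # Gradient callback, analogous to gradient_step above.
    jac=lambda x: 2 * x,
    # Record progress for debugging, analogous to the cost_history callback above.
    callback=lambda x: cost_history.append(float(x @ x)),
)

print(result.x, cost_history[-1])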

@staticmethod
def gradient_step(unrolled_thetas, layers):
def gradient_step(data, labels, unrolled_thetas, layers, regularization_param):
"""Gradient step function.
Computes the gradient of the neural network for unrolled theta parameters.
:param unrolled_thetas: flat vector of model parameters
:param layers: model layers configuration
:param data: training set.
:param labels: training set labels.
:param unrolled_thetas: flat vector of model parameters.
:param layers: model layers configuration.
:param regularization_param: parameter that fights model over-fitting.
"""

# Reshape the unrolled thetas back into the per-layer matrix parameters.
thetas = MultilayerPerceptron.thetas_roll(unrolled_thetas, layers)

# Do backpropagation.
MultilayerPerceptron.back_propagation()
thetas_rolled_gradients = MultilayerPerceptron.back_propagation(
data, labels, thetas, layers, regularization_param
)

# Unroll thetas gradients.
return MultilayerPerceptron.thetas_unroll(thetas_rolled_gradients)

@staticmethod
def cost_function(data, labels, thetas, layers, regularization_param):
@@ -169,22 +177,107 @@ def feedforward_propagation(data, thetas, layers):
num_examples = data.shape[0]

# Input layer (l=1)
layer_in = data
in_layer_activation = data

# Propagate to hidden layers.
for layer_index in range(num_layers - 1):
theta = thetas[layer_index]
layer_out = sigmoid(layer_in @ theta.T)
out_layer_activation = sigmoid(in_layer_activation @ theta.T)
# Add bias units.
layer_out = np.hstack((np.ones((num_examples, 1)), layer_out))
layer_in = layer_out
out_layer_activation = np.hstack((np.ones((num_examples, 1)), out_layer_activation))
in_layer_activation = out_layer_activation

# Output layer should not contain bias units.
return layer_in[:, 1:]
return in_layer_activation[:, 1:]
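
A minimal, self-contained sketch of the same vectorized feedforward pass (the layer sizes, example count, and random thetas below are made up for illustration):

import numpy as np

def sigmoid(z):
    return 1 / (1 + np.exp(-z))

# Hypothetical network: 3 input features, 5 hidden units, 2 output labels.
num_examples = 4
data = np.hstack((np.ones((num_examples, 1)), np.random.rand(num_examples, 3)))  # bias column prepended
thetas = {
    0: np.random.rand(5, 4) * 0.2 - 0.1,  # hidden x (input + bias)
    1: np.random.rand(2, 6) * 0.2 - 0.1,  # output x (hidden + bias)
}

in_layer_activation = data
for layer_index in range(len(thetas)):
    # Activate the next layer and prepend bias units to its output.
    out_layer_activation = sigmoid(in_layer_activation @ thetas[layer_index].T)
    out_layer_activation = np.hstack((np.ones((num_examples, 1)), out_layer_activation))
    in_layer_activation = out_layer_activation

# The output layer should not contain bias units.
predictions = in_layer_activation[:, 1:]
print(predictions.shape)  # (4, 2)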

@staticmethod
def back_propagation():
pass
def back_propagation(data, labels, thetas, layers, regularization_param):
"""Backpropagation function"""

# Get total number of layers.
num_layers = len(layers)

# Get total number of training examples and features.
(num_examples, num_features) = data.shape

# Get the number of possible output labels.
num_label_types = layers[-1]

# Initialize big deltas - aggregated delta values over all training examples that will
# indicate how exactly the thetas need to be changed.
deltas = {}
for layer_index in range(num_layers - 1):
in_count = layers[layer_index]
out_count = layers[layer_index + 1]
deltas[layer_index] = np.zeros((out_count, in_count + 1))

# Let's go through all examples.
for example_index in range(num_examples):
# We will store the layers' inputs and activations in order to re-use them later.
layers_inputs = {}
layers_activations = {}

# Setup input layer activations.
layer_activation = data[example_index, :].reshape((num_features, 1))
layers_activations[0] = layer_activation

# Perform a feedforward pass for current training example.
for layer_index in range(num_layers - 1):
layer_theta = thetas[layer_index]
layer_input = layer_theta @ layer_activation
layer_activation = np.vstack((np.array([[1]]), sigmoid(layer_input)))

layers_inputs[layer_index + 1] = layer_input
layers_activations[layer_index + 1] = layer_activation

# Remove bias units from the output activations.
output_layer_activation = layer_activation[1:, :]

# Calculate deltas.

# We don't calculate a delta for the input layer because we do not
# associate an error with the input.
delta = {}

# Convert the label from a number to a vector (i.e. 5 to [0; 0; 0; 0; 0; 1; 0; 0; 0; 0]).
bitwise_label = np.zeros((num_label_types, 1))
bitwise_label[labels[example_index][0]] = 1

# Calculate deltas for the output layer for current training example.
delta[num_layers - 1] = output_layer_activation - bitwise_label

# Calculate small deltas for the hidden layers for the current training example.
# The loop goes over the hidden layers L-1, L-2, ..., 2.
for layer_index in range(num_layers - 2, 0, -1):
layer_theta = thetas[layer_index]
next_delta = delta[layer_index + 1]
layer_input = layers_inputs[layer_index]

# Add bias row to the layer input.
layer_input = np.vstack((np.array([[1]]), layer_input))

# Calculate the raw delta and take the bias row off of it.
delta[layer_index] = (layer_theta.T @ next_delta) * sigmoid_gradient(layer_input)
delta[layer_index] = delta[layer_index][1:, :]

# Accumulate the gradient (update big deltas).
for layer_index in range(num_layers - 1):
layer_delta = delta[layer_index + 1] @ layers_activations[layer_index].T
deltas[layer_index] = deltas[layer_index] + layer_delta

# Obtain the regularized gradient for the neural network cost function.
for layer_index in range(num_layers - 1):
# Remember that we should NOT be regularizing the first column of theta.
current_delta = deltas[layer_index]
current_delta = np.hstack((np.zeros((current_delta.shape[0], 1)), current_delta[:, 1:]))

# Calculate regularization.
regularization = (regularization_param / num_examples) * current_delta

# Regularize deltas.
deltas[layer_index] = (1 / num_examples) * deltas[layer_index] + regularization

return deltas
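
A common way to sanity-check a backpropagation implementation like the one above is numerical gradient checking: approximate each partial derivative with central differences and compare it against the analytic gradient. A small generic sketch (the quadratic cost below is a stand-in, not the class's cost_function):

import numpy as np

def numerical_gradient(cost, theta, epsilon=1e-4):
    """Approximate d(cost)/d(theta) element-wise with central differences."""
    grad = np.zeros_like(theta)
    for index in range(theta.size):
        shift = np.zeros_like(theta)
        shift.flat[index] = epsilon
        grad.flat[index] = (cost(theta + shift) - cost(theta - shift)) / (2 * epsilon)
    return grad

theta = np.array([0.5, -1.0, 2.0])
analytic_gradient = 2 * theta  # analytic gradient of sum(theta ** 2)
numeric_gradient = numerical_gradient(lambda t: np.sum(t ** 2), theta)
print(np.max(np.abs(analytic_gradient - numeric_gradient)))  # should be close to zero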

@staticmethod
def thetas_init(layers, epsilon):
@@ -208,9 +301,9 @@ def thetas_init(layers, epsilon):
# Generate Thetas only for input and hidden layers.
# There is no need to generate Thetas for the output layer.
for layer_index in range(num_layers - 1):
layers_in = layers[layer_index]
layers_out = layers[layer_index + 1]
thetas[layer_index] = np.random.rand(layers_out, layers_in + 1) * 2 * epsilon - epsilon
in_count = layers[layer_index]
out_count = layers[layer_index + 1]
thetas[layer_index] = np.random.rand(out_count, in_count + 1) * 2 * epsilon - epsilon

return thetas
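
The initialization above draws every weight uniformly from [-epsilon, epsilon] to break symmetry between units. A quick illustration with made-up layer sizes:

import numpy as np

epsilon = 0.12
in_count, out_count = 25, 10  # hypothetical sizes of two consecutive layers
theta = np.random.rand(out_count, in_count + 1) * 2 * epsilon - epsilon  # +1 for the bias unit
print(theta.shape, theta.min() >= -epsilon, theta.max() <= epsilon)  # (10, 26) True True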

@@ -238,11 +331,11 @@ def thetas_roll(unrolled_thetas, layers):
unrolled_shift = 0

for layer_index in range(num_layers - 1):
layers_in = layers[layer_index]
layers_out = layers[layer_index + 1]
in_count = layers[layer_index]
out_count = layers[layer_index + 1]

thetas_width = layers_in + 1 # We need to remember about bias unit.
thetas_height = layers_out
thetas_width = in_count + 1 # We need to remember about bias unit.
thetas_height = out_count
thetas_volume = thetas_width * thetas_height

# We need to remember about bias units when rolling up params.
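
thetas_unroll and thetas_roll are the glue between the per-layer theta matrices and the flat parameter vector that scipy.optimize.minimize works with. A minimal sketch of that round trip using simplified standalone helpers (not the class methods themselves) and hypothetical layer sizes:

import numpy as np

layers = [3, 5, 2]  # hypothetical layer configuration

def unroll(thetas):
    # Concatenate all theta matrices into one flat vector.
    return np.hstack([theta.flatten() for theta in thetas.values()])

def roll(unrolled, layers):
    # Cut the flat vector back into per-layer matrices, remembering the bias column.
    thetas, shift = {}, 0
    for layer_index in range(len(layers) - 1):
        height = layers[layer_index + 1]
        width = layers[layer_index] + 1  # +1 for the bias unit
        volume = height * width
        thetas[layer_index] = unrolled[shift:shift + volume].reshape(height, width)
        shift += volume
    return thetas

thetas = {0: np.random.rand(5, 4), 1: np.random.rand(2, 6)}
restored = roll(unroll(thetas), layers)
print(all(np.array_equal(thetas[i], restored[i]) for i in thetas))  # True
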
1 change: 1 addition & 0 deletions homemade/utils/hypothesis/__init__.py
@@ -1 +1,2 @@
from .sigmoid import sigmoid
from .sigmoid_gradient import sigmoid_gradient
7 changes: 7 additions & 0 deletions homemade/utils/hypothesis/sigmoid_gradient.py
@@ -0,0 +1,7 @@
from .sigmoid import sigmoid


def sigmoid_gradient(z):
"""Computes the gradient of the sigmoid function evaluated at z."""

return sigmoid(z) * (1 - sigmoid(z))
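
A quick sanity check of the identity used here, sigmoid'(z) = sigmoid(z) * (1 - sigmoid(z)), against a central-difference approximation (standalone copies of the helpers, for illustration only):

import numpy as np

def sigmoid(z):
    return 1 / (1 + np.exp(-z))

def sigmoid_gradient(z):
    return sigmoid(z) * (1 - sigmoid(z))

z = np.linspace(-5, 5, 11)
eps = 1e-6
numeric = (sigmoid(z + eps) - sigmoid(z - eps)) / (2 * eps)
print(np.max(np.abs(numeric - sigmoid_gradient(z))))  # should be tiny (about 1e-11)
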
148 changes: 3 additions & 145 deletions notebooks/neural_network/multilayer_perceptron_demo.ipynb
@@ -489,157 +489,15 @@
},
{
"cell_type": "code",
"execution_count": 45,
"execution_count": 125,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{0: array([[-0.0758517 , -0.10922462, 0.03927368, ..., -0.10116138,\n",
" -0.06562225, 0.08434671],\n",
" [ 0.0605631 , 0.07263726, -0.02723178, ..., -0.11961522,\n",
" 0.11282753, -0.02994306],\n",
" [ 0.11871606, 0.07256746, -0.01404285, ..., 0.10602606,\n",
" 0.10491185, -0.05867995],\n",
" ...,\n",
" [ 0.0919954 , -0.04510364, 0.03600806, ..., 0.11729077,\n",
" -0.05894698, 0.06732849],\n",
" [-0.11695147, 0.10688355, 0.0685417 , ..., -0.03359375,\n",
" -0.0944231 , 0.05190055],\n",
" [-0.01826086, 0.09777763, 0.08879194, ..., 0.0213223 ,\n",
" -0.01198016, 0.08821151]]), 1: array([[-0.0758517 , -0.10922462, 0.03927368, 0.04616654, -0.03550035,\n",
" -0.04301836, -0.05111002, -0.01779515, 0.01221125, 0.01600122,\n",
" -0.07813955, -0.04786295, 0.08427782, 0.09664291, 0.02393845,\n",
" -0.11691394, -0.09485863, 0.10624724, 0.04490283, -0.10627072,\n",
" 0.10434709, 0.09732701, 0.00282358, -0.00734285, -0.09995941,\n",
" -0.11459936],\n",
" [-0.06447097, 0.08244544, 0.1055696 , 0.07628717, 0.07190959,\n",
" -0.03122521, 0.11683788, -0.11039357, -0.06716706, 0.01594479,\n",
" -0.00133989, 0.03575698, 0.05225359, 0.00083552, 0.02347394,\n",
" -0.06132896, 0.05638961, -0.0044371 , -0.02010712, -0.10999934,\n",
" -0.04849349, 0.04404446, 0.07483567, -0.06039278, 0.04542441,\n",
" 0.08904211],\n",
" [-0.04615776, -0.07739299, -0.09182918, -0.02440465, -0.05477409,\n",
" 0.03588698, 0.02053208, -0.04363991, 0.04534481, 0.05309741,\n",
" -0.07172521, -0.01942062, -0.06344989, 0.09863689, -0.11781185,\n",
" 0.02971791, -0.02973962, 0.06304532, -0.07716626, -0.03389946,\n",
" -0.04426616, 0.03890041, -0.07181278, 0.02769418, 0.00820932,\n",
" 0.10949384],\n",
" [-0.0584883 , -0.05273799, -0.04073093, -0.08154635, 0.08999456,\n",
" -0.09110997, 0.09805592, 0.02330922, 0.07835466, -0.11295456,\n",
" -0.05768334, 0.00250513, -0.00909849, -0.00671458, 0.06267393,\n",
" 0.07735554, 0.05565781, -0.06221527, 0.10644233, 0.03333939,\n",
" 0.02334794, -0.01852243, 0.03946706, 0.11171577, 0.00829028,\n",
" -0.05008512],\n",
" [-0.00948572, 0.00763778, 0.07092984, 0.03798784, -0.07694375,\n",
" 0.05564401, 0.11472868, 0.11388296, 0.08657028, -0.01318174,\n",
" 0.02493628, 0.01862749, 0.01416905, -0.10815415, 0.08573075,\n",
" 0.02036101, 0.06934405, 0.11281956, 0.02856743, -0.06820671,\n",
" -0.08479958, 0.02668589, -0.05561203, 0.05716293, 0.11849236,\n",
" 0.05245313],\n",
" [ 0.11395225, -0.07448341, -0.11355455, 0.07997803, -0.02016351,\n",
" 0.02623673, -0.09786482, -0.08886998, -0.02424251, 0.06848556,\n",
" -0.11399175, 0.01630017, 0.00199946, 0.00148151, -0.03053501,\n",
" 0.05940618, -0.05865 , 0.06081712, -0.06157728, -0.11024059,\n",
" -0.0677528 , 0.06100844, 0.02996631, -0.03733193, 0.09442967,\n",
" 0.0271904 ],\n",
" [ 0.03050059, -0.03451764, 0.07158443, 0.0541165 , 0.01873904,\n",
" -0.05535262, 0.00458515, -0.09848468, -0.01277639, -0.10496153,\n",
" -0.06116952, -0.0284652 , 0.0300631 , -0.02659276, 0.09268343,\n",
" -0.08086429, -0.07301074, -0.03411321, 0.1054892 , 0.0424244 ,\n",
" 0.09827251, 0.03980845, -0.09431661, -0.0580831 , -0.04872072,\n",
" 0.106885 ],\n",
" [ 0.08076684, -0.00780408, 0.06917175, 0.10370648, -0.00244977,\n",
" -0.09103661, -0.03319441, -0.10700324, 0.03875014, -0.02056288,\n",
" -0.01949595, -0.05121848, 0.10714613, -0.00404258, 0.0173522 ,\n",
" -0.05759117, -0.08206716, 0.08263817, -0.00864864, -0.08316974,\n",
" 0.08279706, 0.04957311, 0.03934321, 0.05675562, 0.04299622,\n",
" 0.04064601],\n",
" [ 0.00825281, -0.07706374, -0.00922871, 0.05605853, 0.00982105,\n",
" -0.05653799, -0.06617444, -0.08152387, 0.09066151, 0.00207551,\n",
" -0.03963645, 0.09282233, 0.02758925, 0.01784172, 0.11217704,\n",
" 0.05094281, 0.08854876, -0.09565834, 0.00443037, -0.01511557,\n",
" 0.10326956, -0.06927156, -0.0166677 , 0.0913672 , 0.06746135,\n",
" -0.04688244],\n",
" [ 0.02260412, 0.00678681, 0.00549161, -0.11994145, 0.04870088,\n",
" -0.05051432, -0.1141186 , 0.06037819, 0.04170217, -0.0586402 ,\n",
" -0.10248884, 0.01742958, -0.01947546, 0.06129252, 0.07150439,\n",
" -0.06523626, 0.09166035, 0.09504693, -0.03253129, -0.06043063,\n",
" -0.0926532 , -0.11705144, 0.0379782 , -0.05661604, -0.11245252,\n",
" -0.1087203 ]])}\n",
"{0: array([[-0.0758517 , -0.10922462, 0.03927368, ..., -0.10116138,\n",
" -0.06562225, 0.08434671],\n",
" [ 0.0605631 , 0.07263726, -0.02723178, ..., -0.11961522,\n",
" 0.11282753, -0.02994306],\n",
" [ 0.11871606, 0.07256746, -0.01404285, ..., 0.10602606,\n",
" 0.10491185, -0.05867995],\n",
" ...,\n",
" [ 0.0919954 , -0.04510364, 0.03600806, ..., 0.11729077,\n",
" -0.05894698, 0.06732849],\n",
" [-0.11695147, 0.10688355, 0.0685417 , ..., -0.03359375,\n",
" -0.0944231 , 0.05190055],\n",
" [-0.01826086, 0.09777763, 0.08879194, ..., 0.0213223 ,\n",
" -0.01198016, 0.08821151]]), 1: array([[-0.0758517 , -0.10922462, 0.03927368, 0.04616654, -0.03550035,\n",
" -0.04301836, -0.05111002, -0.01779515, 0.01221125, 0.01600122,\n",
" -0.07813955, -0.04786295, 0.08427782, 0.09664291, 0.02393845,\n",
" -0.11691394, -0.09485863, 0.10624724, 0.04490283, -0.10627072,\n",
" 0.10434709, 0.09732701, 0.00282358, -0.00734285, -0.09995941,\n",
" -0.11459936],\n",
" [-0.06447097, 0.08244544, 0.1055696 , 0.07628717, 0.07190959,\n",
" -0.03122521, 0.11683788, -0.11039357, -0.06716706, 0.01594479,\n",
" -0.00133989, 0.03575698, 0.05225359, 0.00083552, 0.02347394,\n",
" -0.06132896, 0.05638961, -0.0044371 , -0.02010712, -0.10999934,\n",
" -0.04849349, 0.04404446, 0.07483567, -0.06039278, 0.04542441,\n",
" 0.08904211],\n",
" [-0.04615776, -0.07739299, -0.09182918, -0.02440465, -0.05477409,\n",
" 0.03588698, 0.02053208, -0.04363991, 0.04534481, 0.05309741,\n",
" -0.07172521, -0.01942062, -0.06344989, 0.09863689, -0.11781185,\n",
" 0.02971791, -0.02973962, 0.06304532, -0.07716626, -0.03389946,\n",
" -0.04426616, 0.03890041, -0.07181278, 0.02769418, 0.00820932,\n",
" 0.10949384],\n",
" [-0.0584883 , -0.05273799, -0.04073093, -0.08154635, 0.08999456,\n",
" -0.09110997, 0.09805592, 0.02330922, 0.07835466, -0.11295456,\n",
" -0.05768334, 0.00250513, -0.00909849, -0.00671458, 0.06267393,\n",
" 0.07735554, 0.05565781, -0.06221527, 0.10644233, 0.03333939,\n",
" 0.02334794, -0.01852243, 0.03946706, 0.11171577, 0.00829028,\n",
" -0.05008512],\n",
" [-0.00948572, 0.00763778, 0.07092984, 0.03798784, -0.07694375,\n",
" 0.05564401, 0.11472868, 0.11388296, 0.08657028, -0.01318174,\n",
" 0.02493628, 0.01862749, 0.01416905, -0.10815415, 0.08573075,\n",
" 0.02036101, 0.06934405, 0.11281956, 0.02856743, -0.06820671,\n",
" -0.08479958, 0.02668589, -0.05561203, 0.05716293, 0.11849236,\n",
" 0.05245313],\n",
" [ 0.11395225, -0.07448341, -0.11355455, 0.07997803, -0.02016351,\n",
" 0.02623673, -0.09786482, -0.08886998, -0.02424251, 0.06848556,\n",
" -0.11399175, 0.01630017, 0.00199946, 0.00148151, -0.03053501,\n",
" 0.05940618, -0.05865 , 0.06081712, -0.06157728, -0.11024059,\n",
" -0.0677528 , 0.06100844, 0.02996631, -0.03733193, 0.09442967,\n",
" 0.0271904 ],\n",
" [ 0.03050059, -0.03451764, 0.07158443, 0.0541165 , 0.01873904,\n",
" -0.05535262, 0.00458515, -0.09848468, -0.01277639, -0.10496153,\n",
" -0.06116952, -0.0284652 , 0.0300631 , -0.02659276, 0.09268343,\n",
" -0.08086429, -0.07301074, -0.03411321, 0.1054892 , 0.0424244 ,\n",
" 0.09827251, 0.03980845, -0.09431661, -0.0580831 , -0.04872072,\n",
" 0.106885 ],\n",
" [ 0.08076684, -0.00780408, 0.06917175, 0.10370648, -0.00244977,\n",
" -0.09103661, -0.03319441, -0.10700324, 0.03875014, -0.02056288,\n",
" -0.01949595, -0.05121848, 0.10714613, -0.00404258, 0.0173522 ,\n",
" -0.05759117, -0.08206716, 0.08263817, -0.00864864, -0.08316974,\n",
" 0.08279706, 0.04957311, 0.03934321, 0.05675562, 0.04299622,\n",
" 0.04064601],\n",
" [ 0.00825281, -0.07706374, -0.00922871, 0.05605853, 0.00982105,\n",
" -0.05653799, -0.06617444, -0.08152387, 0.09066151, 0.00207551,\n",
" -0.03963645, 0.09282233, 0.02758925, 0.01784172, 0.11217704,\n",
" 0.05094281, 0.08854876, -0.09565834, 0.00443037, -0.01511557,\n",
" 0.10326956, -0.06927156, -0.0166677 , 0.0913672 , 0.06746135,\n",
" -0.04688244],\n",
" [ 0.02260412, 0.00678681, 0.00549161, -0.11994145, 0.04870088,\n",
" -0.05051432, -0.1141186 , 0.06037819, 0.04170217, -0.0586402 ,\n",
" -0.10248884, 0.01742958, -0.01947546, 0.06129252, 0.07150439,\n",
" -0.06523626, 0.09166035, 0.09504693, -0.03253129, -0.06043063,\n",
" -0.0926532 , -0.11705144, 0.0379782 , -0.05661604, -0.11245252,\n",
" -0.1087203 ]])}\n"
"[1.54168371e-06 0.00000000e+00 0.00000000e+00 ... 2.03102558e-01\n",
" 1.36934643e-01 1.77471528e-01]\n"
]
}
],
