diff --git a/homemade/neural_network/multilayer_perceptron.py b/homemade/neural_network/multilayer_perceptron.py
index ae4beb9..c4d6943 100644
--- a/homemade/neural_network/multilayer_perceptron.py
+++ b/homemade/neural_network/multilayer_perceptron.py
@@ -26,81 +26,88 @@ def __init__(self, data, labels, layers, epsilon, normalize_data=False):
         self.epsilon = epsilon

         # Randomly initialize the weights for each neural network layer.
-        self.thetas = MultilayerPerceptron.init_layers_thetas(layers, epsilon)
+        self.thetas = MultilayerPerceptron.thetas_init(layers, epsilon)

     def train(self, regularization_param=0, max_iterations=1000):
         # Flatten model thetas for gradient descent.
-        unrolled_thetas = MultilayerPerceptron.unroll_thetas(self.thetas)
+        unrolled_thetas = MultilayerPerceptron.thetas_unroll(self.thetas)

         # Init cost history array.
         cost_histories = []

-        initial_cost = MultilayerPerceptron.cost_function(
+        # Run gradient descent.
+        (current_theta, cost_history) = MultilayerPerceptron.gradient_descent(
             self.data,
             self.labels,
-            self.thetas,
+            unrolled_thetas,
             self.layers,
-            regularization_param
+            regularization_param,
+            max_iterations,
         )

-        print(initial_cost)
+        return self.thetas, cost_history

-        # Run gradient descent.
-        # (current_theta, cost_history) = MultilayerPerceptron.gradient_descent(
-        #     self.data,
-        #     current_labels,
-        #     unrolled_thetas,
-        #     regularization_param,
-        #     max_iterations,
-        # )
+    @staticmethod
+    def gradient_descent(data, labels, initial_theta, layers, regularization_param, max_iteration):
+        """Gradient descent function.

-        return self.thetas, cost_histories
+        Iteratively optimizes theta model parameters.

-    # @staticmethod
-    # def gradient_descent(data, labels, initial_theta, lambda_param, max_iteration):
-    #     """Gradient descent function.
-    #
-    #     Iteratively optimizes theta model parameters.
-    #
-    #     :param data: the set of training or test data.
-    #     :param labels: training set outputs (0 or 1 that defines the class of an example).
-    #     :param initial_theta: initial model parameters.
-    #     :param lambda_param: regularization parameter.
-    #     :param max_iteration: maximum number of gradient descent steps.
-    #     """
-    #
-    #     # Initialize cost history list.
-    #     cost_history = []
-    #
-    #     # Launch gradient descent.
-    #     minification_result = minimize(
-    #         # Function that we're going to minimize.
-    #         lambda current_theta: MultilayerPerceptron.cost_function(
-    #             data, labels, current_theta.reshape((num_features, 1)), lambda_param
-    #         ),
-    #         # Initial values of model parameter.
-    #         initial_theta,
-    #         # We will use conjugate gradient algorithm.
-    #         method='CG',
-    #         # Function that will help to calculate gradient direction on each step.
-    #         jac=lambda current_theta: MultilayerPerceptron.gradient_step(
-    #             data, labels, current_theta.reshape((num_features, 1)), lambda_param
-    #         ),
-    #         # Record gradient descent progress for debugging.
-    #         callback=lambda current_theta: cost_history.append(MultilayerPerceptron.cost_function(
-    #             data, labels, current_theta.reshape((num_features, 1)), lambda_param
-    #         )),
-    #         options={'maxiter': max_iteration}
-    #     )
-    #
-    #     # Throw an error in case if gradient descent ended up with error.
-    #     if not minification_result.success:
-    #         raise ArithmeticError('Can not minimize cost function: ' + minification_result.message)
-    #
-    #     # Reshape the final version of model parameters.
-    #     optimized_theta = minification_result.x.reshape((num_features, 1))
-    #
-    #     return optimized_theta, cost_history
+        :param data: the set of training or test data.
+        :param labels: training set outputs (0 or 1 that defines the class of an example).
+        :param initial_theta: initial model parameters.
+        :param layers: model layers configuration.
+        :param regularization_param: regularization parameter.
+        :param max_iteration: maximum number of gradient descent steps.
+        """
+
+        # Initialize cost history list.
+        cost_history = []
+
+        # Launch gradient descent.
+        minification_result = minimize(
+            # Function that we're going to minimize.
+            lambda current_theta: MultilayerPerceptron.cost_function(
+                data, labels, current_theta, layers, regularization_param
+            ),
+            # Initial values of model parameters.
+            initial_theta,
+            # We will use the conjugate gradient algorithm.
+            method='CG',
+            # Function that helps to calculate the gradient direction on each step.
+            jac=lambda current_theta: MultilayerPerceptron.gradient_step(
+                current_theta, layers
+            ),
+            # Record gradient descent progress for debugging.
+            callback=lambda current_theta: cost_history.append(MultilayerPerceptron.cost_function(
+                data, labels, current_theta, layers, regularization_param
+            )),
+            options={'maxiter': max_iteration}
+        )
+
+        # Raise an error if gradient descent ended up with an error.
+        if not minification_result.success:
+            raise ArithmeticError('Can not minimize cost function: ' + minification_result.message)
+
+        optimized_theta = minification_result.x
+
+        return optimized_theta, cost_history
+
+    @staticmethod
+    def gradient_step(unrolled_thetas, layers):
+        """Gradient step function.
+
+        Computes the gradient of the neural network for the unrolled theta parameters.
+
+        :param unrolled_thetas: flat vector of model parameters
+        :param layers: model layers configuration
+        """
+
+        # Reshape the unrolled thetas back into per-layer theta matrices.
+        thetas = MultilayerPerceptron.thetas_roll(unrolled_thetas, layers)
+
+        # Do backpropagation.
+        MultilayerPerceptron.back_propagation()

     @staticmethod
     def cost_function(data, labels, thetas, layers, regularization_param):
@@ -176,7 +183,11 @@ def feedforward_propagation(data, thetas, layers):
         return layer_in[:, 1:]

     @staticmethod
-    def init_layers_thetas(layers, epsilon):
+    def back_propagation():
+        pass
+
+    @staticmethod
+    def thetas_init(layers, epsilon):
         """Randomly initialize the weights for each neural network layer

         Each layer will have its own theta matrix W with L_in incoming connections and L_out
@@ -204,13 +215,40 @@
         return thetas

     @staticmethod
-    def unroll_thetas(thetas):
+    def thetas_unroll(thetas):
         """Unrolls cells of theta matrices into one long vector."""

-        unrolled_thetas = []
+        unrolled_thetas = np.array([])
         num_theta_layers = len(thetas)
         for theta_layer_index in range(num_theta_layers):
             # Unroll cells into vector form.
-            unrolled_thetas.extend(thetas[theta_layer_index].flatten())
+            unrolled_thetas = np.hstack((unrolled_thetas, thetas[theta_layer_index].flatten()))

         return unrolled_thetas
+
+    @staticmethod
+    def thetas_roll(unrolled_thetas, layers):
+        """Rolls the flat vector of NN params back into the per-layer matrices."""
+
+        # Get total number of layers.
+        num_layers = len(layers)
+
+        # Init rolled thetas dictionary.
+        thetas = {}
+        unrolled_shift = 0
+
+        for layer_index in range(num_layers - 1):
+            layers_in = layers[layer_index]
+            layers_out = layers[layer_index + 1]
+
+            thetas_width = layers_in + 1  # We need to remember about the bias unit.
+            thetas_height = layers_out
+            thetas_volume = thetas_width * thetas_height
+
+            # We need to remember about bias units when rolling up params.
+ start_index = unrolled_shift + end_index = unrolled_shift + thetas_volume + layer_thetas_unrolled = unrolled_thetas[start_index:end_index] + thetas[layer_index] = layer_thetas_unrolled.reshape((thetas_height, thetas_width)) + + return thetas diff --git a/notebooks/neural_network/multilayer_perceptron_demo.ipynb b/notebooks/neural_network/multilayer_perceptron_demo.ipynb index b613794..85951a1 100644 --- a/notebooks/neural_network/multilayer_perceptron_demo.ipynb +++ b/notebooks/neural_network/multilayer_perceptron_demo.ipynb @@ -489,44 +489,157 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 45, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "[[ 102.31339018 -6.76123322 -4.57189726 14.10229705 -73.96248732\n", - " 16.94535105 51.70631812 6.76591175 2.10871692 -51.2213383 ]\n", - " [ 143.04364768 62.52334678 -35.13380338 -72.37149262 -67.58222084\n", - " 78.15070834 -52.43898597 57.55697067 -24.01268273 21.32550997]\n", - " [ 40.48651724 -15.16762216 25.12839925 -76.85728019 -66.48742905\n", - " -57.89230873 -78.90880348 1.56053988 48.02642381 -16.22590966]\n", - " [ 11.12371159 5.34899303 -22.10351971 -3.56414478 -93.85208458\n", - " 26.27924956 23.85632259 76.42730283 42.17351271 7.57397423]\n", - " [ 114.48011463 83.99486219 -3.04364227 -44.90837207 -116.05946947\n", - " -24.49775454 -33.44817482 124.24416151 44.96630732 24.16497151]\n", - " [ 80.75529319 -98.45010281 132.32959885 -42.31823661 -131.50819998\n", - " -83.16090932 -58.15866838 -29.01484542 1.55677062 44.90781698]\n", - " [ 217.97454053 27.81709144 67.03300985 -87.71948122 -171.12678829\n", - " 52.48712397 -156.14938905 70.47243533 65.30723852 66.87398418]\n", - " [ -43.96045999 82.34816008 13.46018172 -47.1095358 -86.70728225\n", - " -25.75972041 34.01783966 67.35552122 60.03389967 -12.97788424]\n", - " [ 23.22561986 53.42463592 1.56953998 22.60656285 -51.42582099\n", - " -20.61327855 -34.23162646 83.86376919 71.11336102 -17.63475097]\n", - " [ 59.73964492 -42.37896863 59.07231148 -28.72131355 -73.27614391\n", - " -57.64091183 -79.8853578 27.52059515 63.30527676 61.34695441]]\n", - "nan\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "../../homemade/neural_network/multilayer_perceptron.py:149: RuntimeWarning: invalid value encountered in log\n", - " \n", - "../../homemade/neural_network/multilayer_perceptron.py:150: RuntimeWarning: invalid value encountered in log\n", - " # Calculate the cost with regularization.\n" + "{0: array([[-0.0758517 , -0.10922462, 0.03927368, ..., -0.10116138,\n", + " -0.06562225, 0.08434671],\n", + " [ 0.0605631 , 0.07263726, -0.02723178, ..., -0.11961522,\n", + " 0.11282753, -0.02994306],\n", + " [ 0.11871606, 0.07256746, -0.01404285, ..., 0.10602606,\n", + " 0.10491185, -0.05867995],\n", + " ...,\n", + " [ 0.0919954 , -0.04510364, 0.03600806, ..., 0.11729077,\n", + " -0.05894698, 0.06732849],\n", + " [-0.11695147, 0.10688355, 0.0685417 , ..., -0.03359375,\n", + " -0.0944231 , 0.05190055],\n", + " [-0.01826086, 0.09777763, 0.08879194, ..., 0.0213223 ,\n", + " -0.01198016, 0.08821151]]), 1: array([[-0.0758517 , -0.10922462, 0.03927368, 0.04616654, -0.03550035,\n", + " -0.04301836, -0.05111002, -0.01779515, 0.01221125, 0.01600122,\n", + " -0.07813955, -0.04786295, 0.08427782, 0.09664291, 0.02393845,\n", + " -0.11691394, -0.09485863, 0.10624724, 0.04490283, -0.10627072,\n", + " 0.10434709, 0.09732701, 0.00282358, -0.00734285, -0.09995941,\n", + " -0.11459936],\n", + " [-0.06447097, 0.08244544, 0.1055696 
, 0.07628717, 0.07190959,\n", + " -0.03122521, 0.11683788, -0.11039357, -0.06716706, 0.01594479,\n", + " -0.00133989, 0.03575698, 0.05225359, 0.00083552, 0.02347394,\n", + " -0.06132896, 0.05638961, -0.0044371 , -0.02010712, -0.10999934,\n", + " -0.04849349, 0.04404446, 0.07483567, -0.06039278, 0.04542441,\n", + " 0.08904211],\n", + " [-0.04615776, -0.07739299, -0.09182918, -0.02440465, -0.05477409,\n", + " 0.03588698, 0.02053208, -0.04363991, 0.04534481, 0.05309741,\n", + " -0.07172521, -0.01942062, -0.06344989, 0.09863689, -0.11781185,\n", + " 0.02971791, -0.02973962, 0.06304532, -0.07716626, -0.03389946,\n", + " -0.04426616, 0.03890041, -0.07181278, 0.02769418, 0.00820932,\n", + " 0.10949384],\n", + " [-0.0584883 , -0.05273799, -0.04073093, -0.08154635, 0.08999456,\n", + " -0.09110997, 0.09805592, 0.02330922, 0.07835466, -0.11295456,\n", + " -0.05768334, 0.00250513, -0.00909849, -0.00671458, 0.06267393,\n", + " 0.07735554, 0.05565781, -0.06221527, 0.10644233, 0.03333939,\n", + " 0.02334794, -0.01852243, 0.03946706, 0.11171577, 0.00829028,\n", + " -0.05008512],\n", + " [-0.00948572, 0.00763778, 0.07092984, 0.03798784, -0.07694375,\n", + " 0.05564401, 0.11472868, 0.11388296, 0.08657028, -0.01318174,\n", + " 0.02493628, 0.01862749, 0.01416905, -0.10815415, 0.08573075,\n", + " 0.02036101, 0.06934405, 0.11281956, 0.02856743, -0.06820671,\n", + " -0.08479958, 0.02668589, -0.05561203, 0.05716293, 0.11849236,\n", + " 0.05245313],\n", + " [ 0.11395225, -0.07448341, -0.11355455, 0.07997803, -0.02016351,\n", + " 0.02623673, -0.09786482, -0.08886998, -0.02424251, 0.06848556,\n", + " -0.11399175, 0.01630017, 0.00199946, 0.00148151, -0.03053501,\n", + " 0.05940618, -0.05865 , 0.06081712, -0.06157728, -0.11024059,\n", + " -0.0677528 , 0.06100844, 0.02996631, -0.03733193, 0.09442967,\n", + " 0.0271904 ],\n", + " [ 0.03050059, -0.03451764, 0.07158443, 0.0541165 , 0.01873904,\n", + " -0.05535262, 0.00458515, -0.09848468, -0.01277639, -0.10496153,\n", + " -0.06116952, -0.0284652 , 0.0300631 , -0.02659276, 0.09268343,\n", + " -0.08086429, -0.07301074, -0.03411321, 0.1054892 , 0.0424244 ,\n", + " 0.09827251, 0.03980845, -0.09431661, -0.0580831 , -0.04872072,\n", + " 0.106885 ],\n", + " [ 0.08076684, -0.00780408, 0.06917175, 0.10370648, -0.00244977,\n", + " -0.09103661, -0.03319441, -0.10700324, 0.03875014, -0.02056288,\n", + " -0.01949595, -0.05121848, 0.10714613, -0.00404258, 0.0173522 ,\n", + " -0.05759117, -0.08206716, 0.08263817, -0.00864864, -0.08316974,\n", + " 0.08279706, 0.04957311, 0.03934321, 0.05675562, 0.04299622,\n", + " 0.04064601],\n", + " [ 0.00825281, -0.07706374, -0.00922871, 0.05605853, 0.00982105,\n", + " -0.05653799, -0.06617444, -0.08152387, 0.09066151, 0.00207551,\n", + " -0.03963645, 0.09282233, 0.02758925, 0.01784172, 0.11217704,\n", + " 0.05094281, 0.08854876, -0.09565834, 0.00443037, -0.01511557,\n", + " 0.10326956, -0.06927156, -0.0166677 , 0.0913672 , 0.06746135,\n", + " -0.04688244],\n", + " [ 0.02260412, 0.00678681, 0.00549161, -0.11994145, 0.04870088,\n", + " -0.05051432, -0.1141186 , 0.06037819, 0.04170217, -0.0586402 ,\n", + " -0.10248884, 0.01742958, -0.01947546, 0.06129252, 0.07150439,\n", + " -0.06523626, 0.09166035, 0.09504693, -0.03253129, -0.06043063,\n", + " -0.0926532 , -0.11705144, 0.0379782 , -0.05661604, -0.11245252,\n", + " -0.1087203 ]])}\n", + "{0: array([[-0.0758517 , -0.10922462, 0.03927368, ..., -0.10116138,\n", + " -0.06562225, 0.08434671],\n", + " [ 0.0605631 , 0.07263726, -0.02723178, ..., -0.11961522,\n", + " 0.11282753, -0.02994306],\n", + " [ 
0.11871606, 0.07256746, -0.01404285, ..., 0.10602606,\n", + " 0.10491185, -0.05867995],\n", + " ...,\n", + " [ 0.0919954 , -0.04510364, 0.03600806, ..., 0.11729077,\n", + " -0.05894698, 0.06732849],\n", + " [-0.11695147, 0.10688355, 0.0685417 , ..., -0.03359375,\n", + " -0.0944231 , 0.05190055],\n", + " [-0.01826086, 0.09777763, 0.08879194, ..., 0.0213223 ,\n", + " -0.01198016, 0.08821151]]), 1: array([[-0.0758517 , -0.10922462, 0.03927368, 0.04616654, -0.03550035,\n", + " -0.04301836, -0.05111002, -0.01779515, 0.01221125, 0.01600122,\n", + " -0.07813955, -0.04786295, 0.08427782, 0.09664291, 0.02393845,\n", + " -0.11691394, -0.09485863, 0.10624724, 0.04490283, -0.10627072,\n", + " 0.10434709, 0.09732701, 0.00282358, -0.00734285, -0.09995941,\n", + " -0.11459936],\n", + " [-0.06447097, 0.08244544, 0.1055696 , 0.07628717, 0.07190959,\n", + " -0.03122521, 0.11683788, -0.11039357, -0.06716706, 0.01594479,\n", + " -0.00133989, 0.03575698, 0.05225359, 0.00083552, 0.02347394,\n", + " -0.06132896, 0.05638961, -0.0044371 , -0.02010712, -0.10999934,\n", + " -0.04849349, 0.04404446, 0.07483567, -0.06039278, 0.04542441,\n", + " 0.08904211],\n", + " [-0.04615776, -0.07739299, -0.09182918, -0.02440465, -0.05477409,\n", + " 0.03588698, 0.02053208, -0.04363991, 0.04534481, 0.05309741,\n", + " -0.07172521, -0.01942062, -0.06344989, 0.09863689, -0.11781185,\n", + " 0.02971791, -0.02973962, 0.06304532, -0.07716626, -0.03389946,\n", + " -0.04426616, 0.03890041, -0.07181278, 0.02769418, 0.00820932,\n", + " 0.10949384],\n", + " [-0.0584883 , -0.05273799, -0.04073093, -0.08154635, 0.08999456,\n", + " -0.09110997, 0.09805592, 0.02330922, 0.07835466, -0.11295456,\n", + " -0.05768334, 0.00250513, -0.00909849, -0.00671458, 0.06267393,\n", + " 0.07735554, 0.05565781, -0.06221527, 0.10644233, 0.03333939,\n", + " 0.02334794, -0.01852243, 0.03946706, 0.11171577, 0.00829028,\n", + " -0.05008512],\n", + " [-0.00948572, 0.00763778, 0.07092984, 0.03798784, -0.07694375,\n", + " 0.05564401, 0.11472868, 0.11388296, 0.08657028, -0.01318174,\n", + " 0.02493628, 0.01862749, 0.01416905, -0.10815415, 0.08573075,\n", + " 0.02036101, 0.06934405, 0.11281956, 0.02856743, -0.06820671,\n", + " -0.08479958, 0.02668589, -0.05561203, 0.05716293, 0.11849236,\n", + " 0.05245313],\n", + " [ 0.11395225, -0.07448341, -0.11355455, 0.07997803, -0.02016351,\n", + " 0.02623673, -0.09786482, -0.08886998, -0.02424251, 0.06848556,\n", + " -0.11399175, 0.01630017, 0.00199946, 0.00148151, -0.03053501,\n", + " 0.05940618, -0.05865 , 0.06081712, -0.06157728, -0.11024059,\n", + " -0.0677528 , 0.06100844, 0.02996631, -0.03733193, 0.09442967,\n", + " 0.0271904 ],\n", + " [ 0.03050059, -0.03451764, 0.07158443, 0.0541165 , 0.01873904,\n", + " -0.05535262, 0.00458515, -0.09848468, -0.01277639, -0.10496153,\n", + " -0.06116952, -0.0284652 , 0.0300631 , -0.02659276, 0.09268343,\n", + " -0.08086429, -0.07301074, -0.03411321, 0.1054892 , 0.0424244 ,\n", + " 0.09827251, 0.03980845, -0.09431661, -0.0580831 , -0.04872072,\n", + " 0.106885 ],\n", + " [ 0.08076684, -0.00780408, 0.06917175, 0.10370648, -0.00244977,\n", + " -0.09103661, -0.03319441, -0.10700324, 0.03875014, -0.02056288,\n", + " -0.01949595, -0.05121848, 0.10714613, -0.00404258, 0.0173522 ,\n", + " -0.05759117, -0.08206716, 0.08263817, -0.00864864, -0.08316974,\n", + " 0.08279706, 0.04957311, 0.03934321, 0.05675562, 0.04299622,\n", + " 0.04064601],\n", + " [ 0.00825281, -0.07706374, -0.00922871, 0.05605853, 0.00982105,\n", + " -0.05653799, -0.06617444, -0.08152387, 0.09066151, 0.00207551,\n", + " 
-0.03963645, 0.09282233, 0.02758925, 0.01784172, 0.11217704,\n", + " 0.05094281, 0.08854876, -0.09565834, 0.00443037, -0.01511557,\n", + " 0.10326956, -0.06927156, -0.0166677 , 0.0913672 , 0.06746135,\n", + " -0.04688244],\n", + " [ 0.02260412, 0.00678681, 0.00549161, -0.11994145, 0.04870088,\n", + " -0.05051432, -0.1141186 , 0.06037819, 0.04170217, -0.0586402 ,\n", + " -0.10248884, 0.01742958, -0.01947546, 0.06129252, 0.07150439,\n", + " -0.06523626, 0.09166035, 0.09504693, -0.03253129, -0.06043063,\n", + " -0.0926532 , -0.11705144, 0.0379782 , -0.05661604, -0.11245252,\n", + " -0.1087203 ]])}\n" ] } ],
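The output above prints the theta dictionaries produced by the new helpers. For reference, here is a minimal, self-contained sketch of the same unroll/roll round trip that thetas_unroll and thetas_roll implement; the layers configuration and all variable names below are made up for illustration, and the read offset is advanced after each layer so every matrix is cut from its own slice of the flat vector.

    # Standalone sketch of the unroll/roll round trip (illustrative, not the repository code).
    import numpy as np

    layers = [4, 3, 2]  # hypothetical config: 4 inputs, 3 hidden units, 2 outputs

    # Per-layer matrices of shape (layer_out, layer_in + 1); the +1 column is the bias unit.
    thetas = {
        index: np.random.uniform(-0.05, 0.05, (layers[index + 1], layers[index] + 1))
        for index in range(len(layers) - 1)
    }

    # Unroll: flatten every matrix row by row and concatenate into one long vector.
    unrolled = np.concatenate([theta.flatten() for theta in thetas.values()])

    # Roll: slice the vector back into matrices, moving the offset past each layer.
    rolled = {}
    shift = 0
    for index in range(len(layers) - 1):
        height = layers[index + 1]
        width = layers[index] + 1
        volume = height * width
        rolled[index] = unrolled[shift:shift + volume].reshape((height, width))
        shift += volume  # advance to the next layer's parameters

    # The round trip should reproduce the original matrices exactly.
    for index in thetas:
        assert np.array_equal(thetas[index], rolled[index])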
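The new gradient_descent method delegates the optimization to scipy.optimize.minimize. Below is a toy example of the same call pattern (conjugate gradient with an explicit jac and a callback that records progress), using a simple quadratic instead of the network cost; the objective and all names here are invented for the illustration.

    # Toy illustration of the minimize(..., method='CG', jac=..., callback=...) pattern.
    import numpy as np
    from scipy.optimize import minimize

    def cost(theta):
        # Simple convex objective with its minimum at theta = 3.
        return float(np.sum((theta - 3.0) ** 2))

    def gradient(theta):
        # Analytic gradient of the objective above.
        return 2.0 * (theta - 3.0)

    cost_history = []

    result = minimize(
        cost,                                                     # function to minimize
        np.zeros(5),                                              # initial parameter vector
        method='CG',                                              # conjugate gradient
        jac=gradient,                                             # explicit gradient callable
        callback=lambda theta: cost_history.append(cost(theta)),  # record progress per iteration
        options={'maxiter': 100},
    )

    if not result.success:
        raise ArithmeticError('Can not minimize cost function: ' + result.message)

    print(result.x)          # close to [3. 3. 3. 3. 3.]
    print(cost_history[-1])  # near zero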