|
| 1 | +{ |
| 2 | + "cells": [ |
| 3 | + { |
| 4 | + "cell_type": "code", |
| 5 | + "execution_count": 8, |
| 6 | + "metadata": { |
| 7 | + "collapsed": true |
| 8 | + }, |
| 9 | + "outputs": [], |
| 10 | + "source": [ |
| 11 | + "'''\n", |
| 12 | + "linear regression experiment, hope you can know:\n", |
| 13 | + "1. how to design the learning model\n", |
| 14 | + "2. optimize the model\n", |
| 15 | + "3. dealing with the dataset\n", |
| 16 | + "\n", |
| 17 | + "Original Author: Aymeric Damien\n", |
| 18 | + "Edited by Wei Li for ChinaHadoop Deep learning course\n", |
| 19 | + "Project: https://github.com/aymericdamien/TensorFlow-Examples/\n", |
| 20 | + "'''\n", |
| 21 | + "\n", |
| 22 | + "\n", |
| 23 | + "import tensorflow as tf\n", |
| 24 | + "import numpy\n", |
| 25 | + "rng = numpy.random\n", |
| 26 | + "\n", |
| 27 | + "# model params\n", |
| 28 | + "learning_rate = 0.02\n", |
| 29 | + "training_epochs = 3000\n", |
| 30 | + "display_step=50\n", |
| 31 | + "# \n", |
| 32 | + "train_X = numpy.asarray([3.3,4.4,5.5,6.71,6.93,4.168,9.779,6.182,7.59,2.167,\n", |
| 33 | + " 7.042,10.791,5.313,7.997,5.654,9.27,3.1])\n", |
| 34 | + "train_Y = numpy.asarray([1.7,2.76,2.09,3.19,1.694,1.573,3.366,2.596,2.53,1.221,\n", |
| 35 | + " 2.827,3.465,1.65,2.904,2.42,2.94,1.3])\n", |
| 36 | + "n_samples = train_X.shape[0]\n", |
| 37 | + "\n", |
| 38 | + "# tf Graph Input\n", |
| 39 | + "X = tf.placeholder(\"float\")\n", |
| 40 | + "Y = tf.placeholder(\"float\")\n", |
| 41 | + "\n", |
| 42 | + "# Set model weights\n", |
| 43 | + "W = tf.Variable(rng.randn(), name=\"weight\")\n", |
| 44 | + "b = tf.Variable(rng.randn(), name=\"bias\")\n", |
| 45 | + "\n", |
| 46 | + "# Construct a linear model\n", |
| 47 | + "pred = tf.add(tf.multiply(X, W), b)\n", |
| 48 | + "\n", |
| 49 | + "# Mean squared error\n", |
| 50 | + "cost = tf.reduce_sum(tf.pow(pred-Y, 2))/(2*n_samples)\n", |
| 51 | + "# Gradient descent\n", |
| 52 | + "optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)\n", |
| 53 | + "\n", |
| 54 | + "# Initializing the variables\n", |
| 55 | + "init = tf.global_variables_initializer()" |
| 56 | + ] |
| 57 | + }, |
| 58 | + { |
| 59 | + "cell_type": "code", |
| 60 | + "execution_count": 9, |
| 61 | + "metadata": { |
| 62 | + "collapsed": false |
| 63 | + }, |
| 64 | + "outputs": [ |
| 65 | + { |
| 66 | + "name": "stdout", |
| 67 | + "output_type": "stream", |
| 68 | + "text": [ |
| 69 | + "('Epoch:', '0050', 'cost=', '0.178423569', 'W=', 0.42291793, 'b=', -0.4734658)\n", |
| 70 | + "('Epoch:', '0100', 'cost=', '0.156202286', 'W=', 0.40251526, 'b=', -0.32475927)\n", |
| 71 | + "('Epoch:', '0150', 'cost=', '0.138855815', 'W=', 0.38448787, 'b=', -0.19336548)\n", |
| 72 | + "('Epoch:', '0200', 'cost=', '0.125314981', 'W=', 0.36855927, 'b=', -0.077268951)\n", |
| 73 | + "('Epoch:', '0250', 'cost=', '0.114744954', 'W=', 0.35448512, 'b=', 0.025311502)\n", |
| 74 | + "('Epoch:', '0300', 'cost=', '0.106494129', 'W=', 0.34204948, 'b=', 0.11594942)\n", |
| 75 | + "('Epoch:', '0350', 'cost=', '0.100053802', 'W=', 0.3310616, 'b=', 0.19603507)\n", |
| 76 | + "('Epoch:', '0400', 'cost=', '0.095026731', 'W=', 0.32135299, 'b=', 0.26679745)\n", |
| 77 | + "('Epoch:', '0450', 'cost=', '0.091103002', 'W=', 0.31277463, 'b=', 0.3293213)\n", |
| 78 | + "('Epoch:', '0500', 'cost=', '0.088040523', 'W=', 0.30519509, 'b=', 0.38456526)\n", |
| 79 | + "('Epoch:', '0550', 'cost=', '0.085650302', 'W=', 0.29849792, 'b=', 0.43337804)\n", |
| 80 | + "('Epoch:', '0600', 'cost=', '0.083784848', 'W=', 0.29258049, 'b=', 0.47650799)\n", |
| 81 | + "('Epoch:', '0650', 'cost=', '0.082329050', 'W=', 0.28735185, 'b=', 0.51461679)\n", |
| 82 | + "('Epoch:', '0700', 'cost=', '0.081192940', 'W=', 0.28273201, 'b=', 0.54828918)\n", |
| 83 | + "('Epoch:', '0750', 'cost=', '0.080306433', 'W=', 0.27865005, 'b=', 0.57804072)\n", |
| 84 | + "('Epoch:', '0800', 'cost=', '0.079614699', 'W=', 0.27504328, 'b=', 0.60432887)\n", |
| 85 | + "('Epoch:', '0850', 'cost=', '0.079074971', 'W=', 0.27185646, 'b=', 0.6275565)\n", |
| 86 | + "('Epoch:', '0900', 'cost=', '0.078653932', 'W=', 0.26904064, 'b=', 0.64807951)\n", |
| 87 | + "('Epoch:', '0950', 'cost=', '0.078325450', 'W=', 0.26655263, 'b=', 0.66621393)\n", |
| 88 | + "('Epoch:', '1000', 'cost=', '0.078069247', 'W=', 0.26435426, 'b=', 0.68223649)\n", |
| 89 | + "('Epoch:', '1050', 'cost=', '0.077869445', 'W=', 0.26241186, 'b=', 0.69639373)\n", |
| 90 | + "('Epoch:', '1100', 'cost=', '0.077713616', 'W=', 0.26069549, 'b=', 0.70890343)\n", |
| 91 | + "('Epoch:', '1150', 'cost=', '0.077592134', 'W=', 0.25917912, 'b=', 0.71995574)\n", |
| 92 | + "('Epoch:', '1200', 'cost=', '0.077497423', 'W=', 0.2578392, 'b=', 0.72972184)\n", |
| 93 | + "('Epoch:', '1250', 'cost=', '0.077423617', 'W=', 0.25665545, 'b=', 0.73834991)\n", |
| 94 | + "('Epoch:', '1300', 'cost=', '0.077366099', 'W=', 0.2556093, 'b=', 0.74597466)\n", |
| 95 | + "('Epoch:', '1350', 'cost=', '0.077321291', 'W=', 0.25468507, 'b=', 0.75271124)\n", |
| 96 | + "('Epoch:', '1400', 'cost=', '0.077286400', 'W=', 0.25386831, 'b=', 0.75866407)\n", |
| 97 | + "('Epoch:', '1450', 'cost=', '0.077259235', 'W=', 0.25314665, 'b=', 0.76392406)\n", |
| 98 | + "('Epoch:', '1500', 'cost=', '0.077238098', 'W=', 0.252509, 'b=', 0.76857102)\n", |
| 99 | + "('Epoch:', '1550', 'cost=', '0.077221632', 'W=', 0.25194564, 'b=', 0.77267736)\n", |
| 100 | + "('Epoch:', '1600', 'cost=', '0.077208854', 'W=', 0.25144795, 'b=', 0.77630514)\n", |
| 101 | + "('Epoch:', '1650', 'cost=', '0.077198923', 'W=', 0.25100803, 'b=', 0.77951139)\n", |
| 102 | + "('Epoch:', '1700', 'cost=', '0.077191189', 'W=', 0.25061971, 'b=', 0.78234196)\n", |
| 103 | + "('Epoch:', '1750', 'cost=', '0.077185199', 'W=', 0.25027612, 'b=', 0.78484607)\n", |
| 104 | + "('Epoch:', '1800', 'cost=', '0.077180564', 'W=', 0.24997255, 'b=', 0.78705853)\n", |
| 105 | + "('Epoch:', '1850', 'cost=', '0.077176966', 'W=', 0.2497045, 'b=', 0.78901207)\n", |
| 106 | + "('Epoch:', '1900', 'cost=', '0.077174187', 'W=', 0.24946776, 'b=', 0.79073763)\n", |
| 107 | + "('Epoch:', '1950', 'cost=', '0.077172041', 'W=', 0.24925858, 'b=', 0.79226238)\n", |
| 108 | + "('Epoch:', '2000', 'cost=', '0.077170387', 'W=', 0.24907368, 'b=', 0.7936098)\n", |
| 109 | + "('Epoch:', '2050', 'cost=', '0.077169113', 'W=', 0.24891038, 'b=', 0.79480028)\n", |
| 110 | + "('Epoch:', '2100', 'cost=', '0.077168114', 'W=', 0.24876596, 'b=', 0.79585338)\n", |
| 111 | + "('Epoch:', '2150', 'cost=', '0.077167362', 'W=', 0.24863829, 'b=', 0.79678357)\n", |
| 112 | + "('Epoch:', '2200', 'cost=', '0.077166796', 'W=', 0.24852541, 'b=', 0.79760629)\n", |
| 113 | + "('Epoch:', '2250', 'cost=', '0.077166334', 'W=', 0.24842578, 'b=', 0.79833227)\n", |
| 114 | + "('Epoch:', '2300', 'cost=', '0.077165999', 'W=', 0.2483376, 'b=', 0.79897529)\n", |
| 115 | + "('Epoch:', '2350', 'cost=', '0.077165760', 'W=', 0.24825987, 'b=', 0.79954147)\n", |
| 116 | + "('Epoch:', '2400', 'cost=', '0.077165581', 'W=', 0.24819092, 'b=', 0.80004394)\n", |
| 117 | + "('Epoch:', '2450', 'cost=', '0.077165432', 'W=', 0.24813022, 'b=', 0.80048668)\n", |
| 118 | + "('Epoch:', '2500', 'cost=', '0.077165321', 'W=', 0.24807698, 'b=', 0.80087441)\n", |
| 119 | + "('Epoch:', '2550', 'cost=', '0.077165253', 'W=', 0.24802969, 'b=', 0.80121905)\n", |
| 120 | + "('Epoch:', '2600', 'cost=', '0.077165186', 'W=', 0.24798796, 'b=', 0.80152339)\n", |
| 121 | + "('Epoch:', '2650', 'cost=', '0.077165157', 'W=', 0.2479513, 'b=', 0.8017906)\n", |
| 122 | + "('Epoch:', '2700', 'cost=', '0.077165119', 'W=', 0.24791868, 'b=', 0.80202842)\n", |
| 123 | + "('Epoch:', '2750', 'cost=', '0.077165097', 'W=', 0.24789007, 'b=', 0.80223686)\n", |
| 124 | + "('Epoch:', '2800', 'cost=', '0.077165097', 'W=', 0.24786451, 'b=', 0.80242288)\n", |
| 125 | + "('Epoch:', '2850', 'cost=', '0.077165082', 'W=', 0.24784194, 'b=', 0.80258781)\n", |
| 126 | + "('Epoch:', '2900', 'cost=', '0.077165082', 'W=', 0.24782193, 'b=', 0.80273348)\n", |
| 127 | + "('Epoch:', '2950', 'cost=', '0.077165082', 'W=', 0.24780463, 'b=', 0.80285954)\n", |
| 128 | + "('Epoch:', '3000', 'cost=', '0.077165082', 'W=', 0.24778947, 'b=', 0.80296975)\n", |
| 129 | + "('Training cost=', 0.077165082, 'W=', 0.24778947, 'b=', 0.80296975, '\\n')\n", |
| 130 | + "Tssting...\n", |
| 131 | + "('Test LOSS=', 0.079976395)\n", |
| 132 | + "('Final Loss:', 0.0028113127)\n" |
| 133 | + ] |
| 134 | + } |
| 135 | + ], |
| 136 | + "source": [ |
| 137 | + "\n", |
| 138 | + "# Launch the graph\n", |
| 139 | + "with tf.Session() as sess:\n", |
| 140 | + " sess.run(init)\n", |
| 141 | + "\n", |
| 142 | + " # Fit all training data\n", |
| 143 | + " for epoch in range(training_epochs):\n", |
| 144 | + " for (x, y) in zip(train_X, train_Y):\n", |
| 145 | + " sess.run(optimizer, feed_dict={X: x, Y: y})\n", |
| 146 | + "\n", |
| 147 | + " # Display logs per epoch step\n", |
| 148 | + " if (epoch+1) % display_step == 0:\n", |
| 149 | + " c = sess.run(cost, feed_dict={X: train_X, Y:train_Y})\n", |
| 150 | + " print(\"Epoch:\", '%04d' % (epoch+1), \"cost=\", \"{:.9f}\".format(c), \\\n", |
| 151 | + " \"W=\", sess.run(W), \"b=\", sess.run(b))\n", |
| 152 | + "\n", |
| 153 | + "\n", |
| 154 | + " training_cost = sess.run(cost, feed_dict={X: train_X, Y: train_Y})\n", |
| 155 | + " print(\"Training cost=\", training_cost, \"W=\", sess.run(W), \"b=\", sess.run(b), '\\n')\n", |
| 156 | + "\n", |
| 157 | + " \n", |
| 158 | + "\n", |
| 159 | + " # the testing data\n", |
| 160 | + " test_X = numpy.asarray([6.83, 4.668, 8.9, 7.91, 5.7, 8.7, 3.1, 2.1])\n", |
| 161 | + " test_Y = numpy.asarray([1.84, 2.273, 3.2, 2.831, 2.92, 3.24, 1.35, 1.03])\n", |
| 162 | + "\n", |
| 163 | + " print(\"Tssting...\")\n", |
| 164 | + " testing_cost = sess.run(\n", |
| 165 | + " tf.reduce_sum(tf.pow(pred - Y, 2)) / (2 * test_X.shape[0]),\n", |
| 166 | + " feed_dict={X: test_X, Y: test_Y}) # same function as cost above\n", |
| 167 | + " print(\"Test LOSS=\", testing_cost)\n", |
| 168 | + " print(\"Final Loss:\", abs(\n", |
| 169 | + " training_cost - testing_cost))" |
| 170 | + ] |
| 171 | + }, |
| 172 | + { |
| 173 | + "cell_type": "code", |
| 174 | + "execution_count": null, |
| 175 | + "metadata": { |
| 176 | + "collapsed": true |
| 177 | + }, |
| 178 | + "outputs": [], |
| 179 | + "source": [] |
| 180 | + } |
| 181 | + ], |
| 182 | + "metadata": { |
| 183 | + "kernelspec": { |
| 184 | + "display_name": "Python 2", |
| 185 | + "language": "python", |
| 186 | + "name": "python2" |
| 187 | + }, |
| 188 | + "language_info": { |
| 189 | + "codemirror_mode": { |
| 190 | + "name": "ipython", |
| 191 | + "version": 2 |
| 192 | + }, |
| 193 | + "file_extension": ".py", |
| 194 | + "mimetype": "text/x-python", |
| 195 | + "name": "python", |
| 196 | + "nbconvert_exporter": "python", |
| 197 | + "pygments_lexer": "ipython2", |
| 198 | + "version": "2.7.12" |
| 199 | + } |
| 200 | + }, |
| 201 | + "nbformat": 4, |
| 202 | + "nbformat_minor": 2 |
| 203 | +} |
0 commit comments