|
38 | 38 | [image, target]
|
39 | 39 | end
|
40 | 40 |
|
41 |
| -data.shuffle! |
| 41 | +# data.shuffle! |
42 | 42 |
|
43 | 43 | train_size = (ARGV[0] || 100).to_i
|
44 | 44 | test_size = 100
|
|
62 | 62 | x_test = x_data.slice(train_size, test_size)
|
63 | 63 | y_test = y_data.slice(train_size, test_size)
|
64 | 64 |
|
| 65 | + |
65 | 66 | puts "Initializing network with #{hidden_layer_size} hidden neurons."
|
66 | 67 | nn = NeuralNet.new [28*28,hidden_layer_size, 50, 10]
|
67 | 68 |
|
|
72 | 73 | (errors.inject(0) {|sum, err| sum += err**2}) / errors.length.to_f
|
73 | 74 | }
|
74 | 75 |
|
75 |
| -prediction_success = -> (actual, ideal) { |
76 |
| - predicted = (0..9).max_by{|i| actual[i]} |
77 |
| - ideal = (0..9).max_by{|i| ideal[i]} |
78 |
| - predicted == ideal |
79 |
| -} |
| 76 | +decode_output = -> (output) { (0..9).max_by {|i| output[i]} } |
| 77 | +prediction_success = -> (actual, ideal) { decode_output.(actual) == decode_output.(ideal) } |
80 | 78 |
|
81 | 79 | run_test = -> (nn, inputs, expected_outputs) {
|
82 | 80 | success, failure, errsum = 0,0,0
|
|
96 | 94 |
|
97 | 95 | puts "\nTraining the network with #{train_size} data samples...\n\n"
|
98 | 96 | t = Time.now
|
99 |
| -result = nn.train(x_train, y_train, log_every: 1, iterations: 100) |
| 97 | +result = nn.train(x_train, y_train, log_every: 1, max_iterations: 100) |
100 | 98 |
|
101 | 99 | puts "\nDone training the network: #{result[:iterations]} iterations, error #{result[:error].round(5)}, #{(Time.now - t).round(1)}s"
|
102 | 100 |
|
|
111 | 109 |
|
112 | 110 | puts "Trained prediction success: #{success}, failure: #{failure} (Error rate: #{error_rate.(failure, x_test.length)}%, mse: #{avg_mse.round(5)})"
|
113 | 111 |
|
114 |
| -begin |
115 |
| - require_relative './image_grid' |
116 |
| -rescue LoadError |
117 |
| - puts "\ngem install chunky_png to output visualization of hidden weights" |
118 |
| - exit |
119 |
| -end |
120 | 112 |
|
121 |
| -ImageGrid.new(nn.weights[1]).to_file 'examples/mnist/hidden1_weights.png' |
| 113 | +# require_relative './image_grid' |
| 114 | +# ImageGrid.new(nn.weights[1]).to_file 'examples/mnist/hidden_weights.png' |
0 commit comments