Skip to content

Commit 072550f

Browse files
committed
cleanup
1 parent 4d98b7a commit 072550f

File tree

2 files changed

+11
-17
lines changed

2 files changed

+11
-17
lines changed

examples/image_grid.rb

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@ def generate_png_from_data
2222

2323
maxes, mins = [], []
2424
@data.length.times do |img|
25-
maxes << @data[img].max
26-
mins << @data[img].min
25+
imgdata = @data[img].slice(0, 28*28)
26+
maxes << imgdata.max
27+
mins << imgdata.min
2728
end
2829
max = maxes.max
2930
min = mins.min
@@ -61,4 +62,4 @@ def generate_png_from_data
6162

6263
# require_relative '../neural_net'
6364
# nn = Marshal.load(File.read('examples/mnist/network.dump'))
64-
# ImageGrid.new(nn.weights[1]).to_file 'test.png'
65+
# ImageGrid.new(nn.weights[1]).to_file 'examples/mnist/test.png'

examples/mnist.rb

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
[image, target]
3939
end
4040

41-
data.shuffle!
41+
# data.shuffle!
4242

4343
train_size = (ARGV[0] || 100).to_i
4444
test_size = 100
@@ -62,6 +62,7 @@
6262
x_test = x_data.slice(train_size, test_size)
6363
y_test = y_data.slice(train_size, test_size)
6464

65+
6566
puts "Initializing network with #{hidden_layer_size} hidden neurons."
6667
nn = NeuralNet.new [28*28,hidden_layer_size, 50, 10]
6768

@@ -72,11 +73,8 @@
7273
(errors.inject(0) {|sum, err| sum += err**2}) / errors.length.to_f
7374
}
7475

75-
prediction_success = -> (actual, ideal) {
76-
predicted = (0..9).max_by{|i| actual[i]}
77-
ideal = (0..9).max_by{|i| ideal[i]}
78-
predicted == ideal
79-
}
76+
decode_output = -> (output) { (0..9).max_by {|i| output[i]} }
77+
prediction_success = -> (actual, ideal) { decode_output.(actual) == decode_output.(ideal) }
8078

8179
run_test = -> (nn, inputs, expected_outputs) {
8280
success, failure, errsum = 0,0,0
@@ -96,7 +94,7 @@
9694

9795
puts "\nTraining the network with #{train_size} data samples...\n\n"
9896
t = Time.now
99-
result = nn.train(x_train, y_train, log_every: 1, iterations: 100)
97+
result = nn.train(x_train, y_train, log_every: 1, max_iterations: 100)
10098

10199
puts "\nDone training the network: #{result[:iterations]} iterations, error #{result[:error].round(5)}, #{(Time.now - t).round(1)}s"
102100

@@ -111,11 +109,6 @@
111109

112110
puts "Trained prediction success: #{success}, failure: #{failure} (Error rate: #{error_rate.(failure, x_test.length)}%, mse: #{avg_mse.round(5)})"
113111

114-
begin
115-
require_relative './image_grid'
116-
rescue LoadError
117-
puts "\ngem install chunky_png to output visualization of hidden weights"
118-
exit
119-
end
120112

121-
ImageGrid.new(nn.weights[1]).to_file 'examples/mnist/hidden1_weights.png'
113+
# require_relative './image_grid'
114+
# ImageGrid.new(nn.weights[1]).to_file 'examples/mnist/hidden_weights.png'

0 commit comments

Comments
 (0)