
Commit cbb73c8

Enable marshal dump and load
1 parent 9b5b6b9 commit cbb73c8

2 files changed: 26 additions, 10 deletions

mnist.rb

Lines changed: 4 additions & 0 deletions
@@ -77,6 +77,10 @@
 
 puts "\nDone training the network: #{result[:iterations]} iterations, error #{result[:error].round(5)}, #{(Time.now - t).round(1)}s"
 
+# # Marshal test
+# dumpfile = 'mnist/network.dump'
+# File.write(dumpfile, Marshal.dump(nn))
+# nn = Marshal.load(File.read(dumpfile))
 
 puts "\nTesting the trained network..."
 
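Uncommented, that Marshal test round-trips the trained network through a dump file. A minimal sketch of the same idea, assuming nn holds a trained NeuralNet and a writable mnist/ directory exists (test_input is a hypothetical stand-in for any input vector of the size the network expects):

  dumpfile = 'mnist/network.dump'
  File.write(dumpfile, Marshal.dump(nn))   # serialize shape, weights and weight update values
  nn = Marshal.load(File.read(dumpfile))   # rebuild an equivalent NeuralNet from the file
  nn.run(test_input)                       # the restored network can be run straight away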

neural_net.rb

Lines changed: 22 additions & 10 deletions
@@ -7,18 +7,17 @@ class NeuralNet
     error_threshold: 0.01
   }
 
-  def initialize(shape)
+  def initialize shape
     @shape = shape
-    @output_layer = @shape.length - 1
-    set_initial_weight_values
   end
 
   def run input
     # Input to this method represents the output of the first layer (i.e., the input layer)
     @outputs = [input]
+    set_initial_weight_values if @weights.nil?
 
     # Now calculate output of neurons in subsequent layers:
-    1.upto(@output_layer).each do |layer|
+    1.upto(output_layer).each do |layer|
       source_layer = layer - 1 # i.e, the layer that is feeding into this one
       source_outputs = @outputs[source_layer]
 
@@ -36,7 +35,7 @@ def run input
     end
 
     # Outputs of neurons in the last layer is the final result
-    @outputs[@output_layer]
+    @outputs[output_layer]
   end
 
   def train data, opts = {}
@@ -45,8 +44,9 @@ def train data, opts = {}
     iteration = 0
     error = nil
 
-    set_weight_changes_to_zeros
+    set_initial_weight_values if @weights.nil?
     set_initial_weight_update_values if @weight_update_values.nil?
+    set_weight_changes_to_zeros
     set_previous_gradients_to_zeroes
 
     while iteration < opts[:max_iterations]
@@ -84,15 +84,15 @@ def train_on_batch data
   end
 
   def calculate_training_error ideal_output
-    @outputs[@output_layer].map.with_index do |output, i|
+    @outputs[output_layer].map.with_index do |output, i|
       output - ideal_output[i]
     end
   end
 
   def update_gradients training_error
     deltas = []
     # Starting from output layer and working backwards, backpropagating the training error
-    @output_layer.downto(1).each do |layer|
+    output_layer.downto(1).each do |layer|
       deltas[layer] = []
       source_layer = layer - 1
       source_neurons = @shape[source_layer] + 1 # account for bias neuron
@@ -103,7 +103,7 @@ def update_gradients training_error
         activation_derivative = output * (1.0 - output)
 
         # calculate delta for neuron
-        delta = deltas[layer][neuron] = if layer == @output_layer
+        delta = deltas[layer][neuron] = if layer == output_layer
           # For neurons in output layer, use training error
           -training_error[neuron] * activation_derivative
         else
@@ -132,7 +132,7 @@ def update_gradients training_error
   # Now that we've calculated gradients for the batch, we can use these to update the weights
   # Using the RPROP algorithm - somewhat more complicated than classic backpropagation algorithm, but much faster
   def update_weights
-    1.upto(@output_layer) do |layer|
+    1.upto(output_layer) do |layer|
       source_layer = layer - 1
       source_neurons = @shape[source_layer] + 1 # account for bias neuron
 
@@ -213,6 +213,10 @@ def build_matrix
     end
   end
 
+  def output_layer
+    @shape.length - 1
+  end
+
   def sigmoid x
     1 / (1 + Math::E**-x)
   end
@@ -230,4 +234,12 @@ def sign x
       x <=> 0 # returns 1 if postitive, -1 if negative
     end
   end
+
+  def marshal_dump
+    [@shape, @weights, @weight_update_values]
+  end
+
+  def marshal_load array
+    @shape, @weights, @weight_update_values = array
+  end
 end
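With these hooks defined, Marshal.dump hands off to NeuralNet#marshal_dump and serializes only [@shape, @weights, @weight_update_values]; Marshal.load allocates a blank instance and passes that array to marshal_load. Everything else (weight changes, gradients, and the output_layer index) is re-derived lazily the next time run or train is called. A rough usage sketch, with the shape [784, 40, 10] purely illustrative:

  require_relative 'neural_net'

  nn = NeuralNet.new [784, 40, 10]
  # ...train nn here...

  blob = Marshal.dump(nn)        # invokes nn.marshal_dump behind the scenes
  restored = Marshal.load(blob)  # invokes marshal_load on a freshly allocated NeuralNet
  # restored.run(input) works without retraining: the weights are back in place,
  # and output_layer is recomputed from @shape on demand.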
