Skip to content

Commit 5d8ce7e

Browse files
committed
Updated Learning Rule
1 parent 37271a1 commit 5d8ce7e

File tree

2 files changed

+48
-39
lines changed

2 files changed

+48
-39
lines changed

and_net.py

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,12 @@ def __init__(self, neurons, connections):
1010
self.neurons = [] # A list of lif_neuron objects
1111
self.connections = [] # A list of lif_connection objects
1212

13+
run_time = 1000
14+
1315
# Function calculates current from an array of spikes from a neuron's output activity
1416
def spikes_to_current(self, a_spikes, conversion_factor):
1517

16-
spike_count = 0
17-
for i in range(len(a_spikes)):
18-
if a_spikes[i] == 1:
19-
spike_count += 1
20-
21-
spikes_per_second = spike_count / len(a_spikes)
22-
23-
return spikes_per_second * conversion_factor * 100000
18+
return sum(a_spikes) / len(a_spikes) * conversion_factor * 50000
2419

2520
# Configure the network to solve problem 3
2621
def create_simple_net(self):
@@ -47,17 +42,17 @@ def spikes_to_binary(self, output_neuron, threshold=4):
4742
if output_neuron.output[i] == 1:
4843
spike_count += 1
4944

50-
spike_frequency = spike_count / len(output_neuron.output)
45+
spike_frequency = spike_count
5146

5247
if (spike_frequency > threshold):
5348
return 1
5449

5550
return 0
5651

57-
# Converts a single bit into 1000 ms array of spikes
52+
# Converts a single bit into an array of spikes
5853
# A value of 0 makes 5Hz spikes
5954
# A value of 1 makes 15Hz spikes
60-
def binary_to_spikes(self, val, nsteps=1000):
55+
def binary_to_spikes(self, val, nsteps=run_time):
6156

6257
a_spikes = [0] * nsteps
6358
nspikes = 5
@@ -66,33 +61,36 @@ def binary_to_spikes(self, val, nsteps=1000):
6661
nspikes = 15
6762

6863
interval = int(nsteps / nspikes)
69-
for i in range(1000):
64+
for i in range(nsteps):
7065
if (i % interval == 0):
7166
a_spikes[i] = 1
7267

7368
return a_spikes
7469

7570
# Runs the neural net
7671
# Maybe include option to run without training neuron
77-
def run_net(self, input):
72+
def run_net(self, input, time = run_time):
7873
# Initializes a current for the three input neurons
7974
xIn = self.spikes_to_current(self.binary_to_spikes(input[0]), self.connections[0].weight)
8075
yIn = self.spikes_to_current(self.binary_to_spikes(input[1]), self.connections[1].weight)
8176
# Change and to xor/or
82-
teachIn = self.spikes_to_current(self.binary_to_spikes(input[1] and input[0]), self.connections[2].weight)
83-
self.neurons[0].sim(xIn, 1000)
84-
self.neurons[1].sim(yIn, 1000)
85-
self.neurons[2].sim(teachIn, 1000)
77+
#teachIn = self.spikes_to_current(self.binary_to_spikes(input[1] and input[0]), self.connections[2].weight)
78+
self.neurons[0].sim(xIn, time)
79+
self.neurons[1].sim(yIn, time)
80+
#self.neurons[2].sim(teachIn, time)
8681
# Calculates the input current for the output by using an or function on all 3 spike outputs
8782
outIn = []
88-
for i in range(1000):
89-
outIn.append(self.neurons[0].output[i] or self.neurons[1].output[i] or self.neurons[2].output[i])
90-
self.neurons[3].sim(self.spikes_to_current(outIn, 1), 1000)
83+
for i in range(time):
84+
if(self.neurons[0].output[i] == 1 or self.neurons[1].output[i] == 1):
85+
outIn.append(1)
86+
else:
87+
outIn.append(0)
88+
#Conversion factor scales # of spikes from output
89+
self.neurons[3].sim(self.spikes_to_current(outIn, .5), time)
9190
output = 0
92-
for i in range(1000):
91+
for i in range(time):
9392
output = output + self.neurons[3].output[i]
9493
if output > 5:
9594
return 1
9695
else:
97-
return 0
98-
96+
return 0

training.py

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
from and_net import and_net
2-
import matplotlib.pyplot as plt
2+
#import matplotlib.pyplot as plt
33

44
net = and_net([], [])
55
net.create_simple_net()
66
# X at [0], Y at [1], Teacher at [2]
7-
net.connections[0].weight = 1
8-
net.connections[1].weight = 1
7+
net.connections[0].weight = .6
8+
net.connections[1].weight = .4
99
net.connections[2].weight = 1
1010
# data[0] holds X, data[1] holds Y, output calculated by data[0][i] && data[1][i]
1111
data = []
1212
i = 0
1313
j = 0
1414
k = 0
1515

16-
16+
"""
1717
def raster_plot(net):
1818
plt.figure()
1919
plt.title('Raster plot for network activity')
@@ -46,7 +46,7 @@ def raster_plot(net):
4646
plt.axvline(k)
4747
4848
plt.show()
49-
49+
"""
5050

5151
#lines = open(input("Enter Filepath: "), "r").readlines()
5252
lines = open("data.txt", "r").readlines()
@@ -63,21 +63,32 @@ def raster_plot(net):
6363
# Trains the network
6464
for r in range(runs):
6565
net.run_net(data[r])
66-
#TODO Find better formula
67-
for i in range(1000):
68-
for j in range(2):
69-
if net.neurons[j].output[i] == net.neurons[3].output[i] == 1:
70-
#increase
71-
net.connections[j].weight += .1
72-
elif net.neurons[j].output[i] == 1 and net.neurons[3].output[i] == 0:
73-
#decrease
74-
if(net.connections[j].weight != 0):
75-
net.connections[j].weight -= .1
66+
time = len(net.neurons[0].output)
67+
w_max = 1
68+
for i in range(int(time/100)-1):
69+
#Finds rate for each neuron in a time step of 100
70+
v_x = sum(net.neurons[0].output[100*i:100*(i+1)])
71+
v_y = sum(net.neurons[1].output[100*i:100*(i+1)])
72+
v_o = sum(net.neurons[3].output[100*i:100*(i+1)])
73+
#Sets a_corr value
74+
a_corr_x = w_max - net.connections[0].weight
75+
a_corr_y = w_max - net.connections[1].weight
76+
#If the rate is below 5, set value to 0
77+
if v_x <= 5:
78+
v_x = 0
79+
if v_y <= 5:
80+
v_y = 0
81+
if v_y <= 5:
82+
v_y = 0
83+
#Hebb with postsynaptic LTP/LTD threshold
84+
#TODO change .01 to something else?
85+
net.connections[0].weight += a_corr_x * v_x * (v_o - .01)
86+
net.connections[1].weight += a_corr_y * v_y * (v_o - .01)
7687
print((data[r][0] and data[r][1]), "||", net.spikes_to_binary(net.neurons[3]), "\n")
7788
x_spikes = sum(net.neurons[0].output)
7889
y_spikes = sum(net.neurons[1].output)
7990
o_spikes = sum(net.neurons[3].output)
8091
print("X =", net.connections[0].weight, x_spikes, "\nY =", net.connections[1].weight, y_spikes, "\n", o_spikes)
8192
for i in range(3):
8293
net.neurons[i].clear
83-
raster_plot(net)
94+
#raster_plot(net)

0 commit comments

Comments
 (0)