Skip to content
This repository was archived by the owner on Dec 11, 2018. It is now read-only.

Commit 2098bad

Browse files
committed
neural network works
1 parent a5c27af commit 2098bad

7 files changed

+252
-17
lines changed

config.json

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
{
22
"numInputs": 3,
33
"numOutputs": 1,
4-
"numGenerations": 50,
5-
"populationSize": 50,
4+
"numGenerations": 1,
5+
"populationSize": 10,
66
"initFitness": 9999.0,
77
"minimizeFitness": true,
88
"ratePerturb": 0.3,

config_template.json

+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
{
2+
"numInputs": 0,
3+
"numOutputs": 0,
4+
"numGenerations": 0,
5+
"populationSize": 0,
6+
"initFitness": 0.0,
7+
"minimizeFitness": false,
8+
"ratePerturb": 0.0,
9+
"rateAddNode": 0.0,
10+
"rateAddConn": 0.0,
11+
"distanceThreshold": 0.0,
12+
"coeffUnmatching": 0.0,
13+
"coeffMatching": 0.0
14+
}

genome.go

+16-4
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,11 @@ func NewConnGene(from, to *NodeGene, weight float64) *ConnGene {
4444

4545
// String returns the string representation of this connection.
4646
func (c *ConnGene) String() string {
47-
connectivity := fmt.Sprintf("%.3f", c.Weight)
47+
connectivity := fmt.Sprintf("{%.3f}", c.Weight)
4848
if c.Disabled {
49-
connectivity = "/"
49+
connectivity = " / "
5050
}
51-
return fmt.Sprintf("%s--%s--%s", c.From.String(), connectivity, c.To.String())
51+
return fmt.Sprintf("%s-%s->%s", c.From.String(), connectivity, c.To.String())
5252
}
5353

5454
// Genome encodes the weights and topology of the output network as a collection
@@ -57,6 +57,8 @@ type Genome struct {
5757
ID int // genome ID
5858
NodeGenes []*NodeGene // nodes in the genome
5959
ConnGenes []*ConnGene // connections in the genome
60+
Fitness float64 // fitness score
61+
evaluated bool // true if already evaluated
6062
}
6163

6264
// NewGenome returns an instance of initial Genome with fully connected input
@@ -83,13 +85,23 @@ func NewGenome(id, numInputs, numOutputs int) *Genome {
8385

8486
// String returns the string representation of the genome.
8587
func (g *Genome) String() string {
86-
str := fmt.Sprintf("Genome(%d):\n", g.ID)
88+
str := fmt.Sprintf("Genome(%d, %.3f):\n", g.ID, g.Fitness)
8789
for _, conn := range g.ConnGenes {
8890
str += conn.String() + "\n"
8991
}
9092
return str[:len(str)-1]
9193
}
9294

95+
// Evaluate takes an evaluation function and evaluates its fitness. Only perform
96+
// the evaluation if it hasn't yet.
97+
func (g *Genome) Evaluate(eval EvaluationFunc) {
98+
if g.evaluated {
99+
return
100+
}
101+
g.Fitness = eval(NewNeuralNetwork(g))
102+
g.evaluated = true
103+
}
104+
93105
// Mutate mutates the genome in three ways, by perturbing each connection's
94106
// weight, by adding a node between two connected nodes, and by adding a
95107
// connection between two nodes that are not connected.

neat.go

+35-9
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@ import (
44
"encoding/json"
55
"fmt"
66
"os"
7+
"runtime"
8+
"sync"
79
"text/tabwriter"
10+
"time"
811
)
912

1013
// Config consists of all hyperparameter settings for NEAT. It can be imported
@@ -86,21 +89,20 @@ func (c *Config) Summarize() {
8689
type NEAT struct {
8790
nextGenomeID int // genome ID that is assigned to a newly created genome
8891

89-
Config *Config // configuration
90-
Population map[*Genome]float64 // population of genome
91-
Species []*Species // subpopulations of genomes grouped by species
92-
Evaluation EvaluationFunc // evaluation function
93-
Best *Genome // best performing genome
92+
Config *Config // configuration
93+
Population []*Genome // population of genome
94+
Species []*Species // subpopulations of genomes grouped by species
95+
Evaluation EvaluationFunc // evaluation function
96+
Best *Genome // best performing genome
9497
}
9598

9699
// New creates a new instance of NEAT with provided argument configuration and
97100
// an evaluation function.
98101
func New(config *Config, evaluation EvaluationFunc) *NEAT {
99102
nextGenomeID := 0
100-
population := make(map[*Genome]float64)
103+
population := make([]*Genome, config.PopulationSize)
101104
for i := 0; i < config.PopulationSize; i++ {
102-
g := NewGenome(nextGenomeID, config.NumInputs, config.NumOutputs)
103-
population[g] = config.InitFitness
105+
population[i] = NewGenome(nextGenomeID, config.NumInputs, config.NumOutputs)
104106
nextGenomeID++
105107
}
106108
return &NEAT{
@@ -112,7 +114,31 @@ func New(config *Config, evaluation EvaluationFunc) *NEAT {
112114
}
113115
}
114116

115-
// Run
117+
// evaluateParallel evaluates all genomes in the population in parallel.
118+
func (n *NEAT) evaluateParallel() {
119+
runtime.GOMAXPROCS(n.Config.PopulationSize)
120+
121+
var wg sync.WaitGroup
122+
wg.Add(n.Config.PopulationSize)
123+
124+
for _, genome := range n.Population {
125+
go func(genome *Genome, evalfn EvaluationFunc) {
126+
defer wg.Done()
127+
genome.Evaluate(evalfn)
128+
}(genome, n.Evaluation)
129+
time.Sleep(time.Millisecond)
130+
}
131+
132+
wg.Wait()
133+
}
134+
135+
// Run executes evolution.
116136
func (n *NEAT) Run() {
137+
for i := 0; i < n.Config.NumGenerations; i++ {
138+
n.evaluateParallel()
117139

140+
for _, genome := range n.Population {
141+
fmt.Println("Genome", genome.ID, "fitness:", genome.Fitness)
142+
}
143+
}
118144
}

neat_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ func NEATUnitTest() {
1717
config.Summarize()
1818

1919
fmt.Println("\x1b[32m=Testing creating and running NEAT...\x1b[0m")
20-
New(config, func(*NeuralNetwork) float64 {
21-
return 1.0
20+
New(config, func(n *NeuralNetwork) float64 {
21+
return rand.Float64()
2222
}).Run()
2323
}
2424

neural_network.go

+134
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,140 @@
11
package neat
22

3+
import (
4+
"fmt"
5+
"sort"
6+
)
7+
8+
// Neuron is an implementation of a single neuron of a neural network.
9+
type Neuron struct {
10+
ID int // neuron ID
11+
Type string // neuron type
12+
Activated bool // true if it has been activated
13+
Signal float64 // signal held by this neuron
14+
Synapses map[*Neuron]float64 // synapse from input neurons
15+
Activation *ActivationFunc // activation function
16+
}
17+
18+
// NewNeuron returns a new instance of neuron, given a node gene.
19+
func NewNeuron(nodeGene *NodeGene) *Neuron {
20+
return &Neuron{
21+
ID: nodeGene.ID,
22+
Type: nodeGene.Type,
23+
Activated: false,
24+
Signal: 0.0,
25+
Synapses: make(map[*Neuron]float64),
26+
Activation: nodeGene.Activation,
27+
}
28+
}
29+
30+
// String returns the string representation of Neuron.
31+
func (n *Neuron) String() string {
32+
if len(n.Synapses) == 0 {
33+
return fmt.Sprintf("[%s(%d, %s)]", n.Type, n.ID, n.Activation.Name)
34+
}
35+
str := fmt.Sprintf("[%s(%d, %s)] (\n", n.Type, n.ID, n.Activation.Name)
36+
for neuron, weight := range n.Synapses {
37+
str += fmt.Sprintf(" <--{%.3f}--[%s(%d, %s)]\n",
38+
weight, neuron.Type, neuron.ID, neuron.Activation.Name)
39+
}
40+
return str + ")"
41+
}
42+
43+
// Activate retrieves signal from neurons that are connected to this neuron and
44+
// return its signal.
45+
func (n *Neuron) Activate() float64 {
46+
// if the neuron's already activated, or it isn't connected from any neurons,
47+
// return its current signal.
48+
if n.Activated || len(n.Synapses) == 0 {
49+
return n.Signal
50+
}
51+
n.Activated = true
52+
53+
inputSum := 0.0
54+
for neuron, weight := range n.Synapses {
55+
inputSum += neuron.Activate() * weight
56+
}
57+
n.Signal = n.Activation.Fn(inputSum)
58+
return n.Signal
59+
}
60+
361
// NeuralNetwork is an implementation of the phenotype neural network that is
462
// decoded from a genome.
563
type NeuralNetwork struct {
64+
NumInputs int // number of inputs
65+
NumOutputs int // number of outputs
66+
Neurons []*Neuron // neurons in the neural network
67+
}
68+
69+
// NewNeuralNetwork returns a new instance of NeuralNetwork given a genome to
70+
// decode from.
71+
func NewNeuralNetwork(g *Genome) *NeuralNetwork {
72+
sort.Slice(g.NodeGenes, func(i, j int) bool {
73+
return g.NodeGenes[i].ID < g.NodeGenes[j].ID
74+
})
75+
76+
numInputs := 0
77+
numOutputs := 0
78+
79+
neurons := make([]*Neuron, 0, len(g.NodeGenes))
80+
for _, nodeGene := range g.NodeGenes {
81+
if nodeGene.Type == "input" {
82+
numInputs++
83+
} else if nodeGene.Type == "output" {
84+
numOutputs++
85+
}
86+
neurons = append(neurons, NewNeuron(nodeGene))
87+
}
88+
89+
for _, connGene := range g.ConnGenes {
90+
if !connGene.Disabled {
91+
if in := sort.Search(len(neurons), func(i int) bool {
92+
return neurons[i].ID >= connGene.From.ID
93+
}); in < len(neurons) && neurons[in].ID == connGene.From.ID {
94+
if out := sort.Search(len(neurons), func(i int) bool {
95+
return neurons[i].ID >= connGene.To.ID
96+
}); out < len(neurons) && neurons[out].ID == connGene.To.ID {
97+
neurons[out].Synapses[neurons[in]] = connGene.Weight
98+
}
99+
}
100+
}
101+
}
102+
return &NeuralNetwork{numInputs, numOutputs, neurons}
103+
}
104+
105+
// String returns the string representation of NeuralNetwork.
106+
func (n *NeuralNetwork) String() string {
107+
str := fmt.Sprintf("NeuralNetwork(%d, %d):\n", n.NumInputs, n.NumOutputs)
108+
for _, neuron := range n.Neurons {
109+
str += neuron.String() + "\n"
110+
}
111+
return str[:len(str)-1]
112+
}
113+
114+
// Feedforward propagates inputs signals from input neurons to output neurons,
115+
// and return output signals.
116+
func (n *NeuralNetwork) FeedForward(inputs []float64) ([]float64, error) {
117+
if len(inputs) != n.NumInputs {
118+
errStr := "Invalid number of inputs: %d != %d"
119+
return nil, fmt.Errorf(errStr, n.NumInputs, len(inputs))
120+
}
121+
122+
// register sensor inputs
123+
for i := 0; i < n.NumInputs; i++ {
124+
n.Neurons[i].Signal = inputs[i]
125+
}
126+
127+
// recursively propagate from input neurons to output neurons
128+
outputs := make([]float64, 0, n.NumOutputs)
129+
for i := n.NumInputs; i < n.NumInputs+n.NumOutputs; i++ {
130+
outputs = append(outputs, n.Neurons[i].Activate())
131+
}
132+
133+
// reset all neurons
134+
for _, neuron := range n.Neurons {
135+
neuron.Activated = false
136+
neuron.Signal = 0.0
137+
}
138+
139+
return outputs, nil
6140
}

neural_network_test.go

+49
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
package neat
2+
3+
import (
4+
"fmt"
5+
"math/rand"
6+
"testing"
7+
)
8+
9+
func NeuralNetworkUnitTest() {
10+
fmt.Println("===== Neural Network Unit Test =====")
11+
12+
g0 := NewGenome(0, 3, 1)
13+
Mutate(g0, 1.0, 1.0, 1.0)
14+
Mutate(g0, 1.0, 1.0, 1.0)
15+
Mutate(g0, 1.0, 1.0, 1.0)
16+
Mutate(g0, 1.0, 1.0, 1.0)
17+
Mutate(g0, 1.0, 1.0, 1.0)
18+
Mutate(g0, 1.0, 1.0, 1.0)
19+
Mutate(g0, 1.0, 1.0, 1.0)
20+
Mutate(g0, 1.0, 1.0, 1.0)
21+
Mutate(g0, 1.0, 1.0, 1.0)
22+
Mutate(g0, 1.0, 1.0, 1.0)
23+
Mutate(g0, 1.0, 1.0, 1.0)
24+
Mutate(g0, 1.0, 1.0, 1.0)
25+
Mutate(g0, 1.0, 1.0, 1.0)
26+
Mutate(g0, 1.0, 1.0, 1.0)
27+
Mutate(g0, 1.0, 1.0, 1.0)
28+
Mutate(g0, 1.0, 1.0, 1.0)
29+
Mutate(g0, 1.0, 1.0, 1.0)
30+
Mutate(g0, 1.0, 1.0, 1.0)
31+
Mutate(g0, 1.0, 1.0, 1.0)
32+
Mutate(g0, 1.0, 1.0, 1.0)
33+
n0 := NewNeuralNetwork(g0)
34+
fmt.Println(n0.String())
35+
36+
fmt.Println("=Testing feedforward...")
37+
inputs := []float64{rand.NormFloat64(), rand.NormFloat64(), 1.0}
38+
fmt.Println("inputs:", inputs)
39+
outputs, err := n0.FeedForward(inputs)
40+
if err != nil {
41+
fmt.Println(err)
42+
}
43+
fmt.Println("outputs:", outputs)
44+
}
45+
46+
func TestNeuralNetwork(t *testing.T) {
47+
rand.Seed(0)
48+
NeuralNetworkUnitTest()
49+
}

0 commit comments

Comments
 (0)