Skip to content

Proofed chapter 7 #7

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 21 additions & 15 deletions CCSPiJ/src/chapter7/IrisTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,15 @@
package chapter7;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

public class IrisTest {
public static final String IRIS_SETOSA = "Iris-setosa";
public static final String IRIS_VERSICOLOR = "Iris-versicolor";
public static final String IRIS_VIRGINICA = "Iris-virginica";

private List<double[]> irisParameters = new ArrayList<>();
private List<double[]> irisClassifications = new ArrayList<>();
private List<String> irisSpecies = new ArrayList<>();
Expand All @@ -32,19 +37,20 @@ public IrisTest() {
Collections.shuffle(irisDataset);
for (String[] iris : irisDataset) {
// first four items are parameters (doubles)
double[] parameters = new double[4];
for (int i = 0; i < parameters.length; i++) {
parameters[i] = Double.parseDouble(iris[i]);
}
double[] parameters = Arrays.stream(iris)
.limit(4)
.mapToDouble(Double::parseDouble)
.toArray();
irisParameters.add(parameters);
// last item is species
String species = iris[4];
if (species.equals("Iris-setosa")) {
irisClassifications.add(new double[] { 1.0, 0.0, 0.0 });
} else if (species.equals("Iris-versicolor")) {
irisClassifications.add(new double[] { 0.0, 1.0, 0.0 });
} else { // Iris-virginica
irisClassifications.add(new double[] { 0.0, 0.0, 1.0 });
switch (species) {
case IRIS_SETOSA :
irisClassifications.add(new double[] { 1.0, 0.0, 0.0 }); break;
case IRIS_VERSICOLOR :
irisClassifications.add(new double[] { 0.0, 1.0, 0.0 }); break;
default :
irisClassifications.add(new double[] { 0.0, 0.0, 1.0 }); break;
}
irisSpecies.add(species);
}
Expand All @@ -54,12 +60,12 @@ public IrisTest() {
public String irisInterpretOutput(double[] output) {
double max = Util.max(output);
if (max == output[0]) {
return "Iris-setosa";
} else if (max == output[1]) {
return "Iris-versicolor";
} else {
return "Iris-virginica";
return IRIS_SETOSA;
}
if (max == output[1]) {
return IRIS_VERSICOLOR;
}
return IRIS_VIRGINICA;
}

public Network<String>.Results classify() {
Expand Down
7 changes: 3 additions & 4 deletions CCSPiJ/src/chapter7/Layer.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,16 @@

public class Layer {
public Optional<Layer> previousLayer;
public List<Neuron> neurons;
public List<Neuron> neurons = new ArrayList<>();
public double[] outputCache;

public Layer(Optional<Layer> previousLayer, int numNeurons, double learningRate,
DoubleUnaryOperator activationFunction, DoubleUnaryOperator derivativeActivationFunction) {
this.previousLayer = previousLayer;
neurons = new ArrayList<>();
Random random = new Random();
for (int i = 0; i < numNeurons; i++) {
double[] randomWeights = null;
if (previousLayer.isPresent()) {
Random random = new Random();
randomWeights = random.doubles(previousLayer.get().neurons.size()).toArray();
}
Neuron neuron = new Neuron(randomWeights, learningRate, activationFunction, derivativeActivationFunction);
Expand Down Expand Up @@ -63,7 +62,7 @@ public void calculateDeltasForOutputLayer(double[] expected) {
// should not be called on output layer
public void calculateDeltasForHiddenLayer(Layer nextLayer) {
for (int i = 0; i < neurons.size(); i++) {
final int index = i;
int index = i;
double[] nextWeights = nextLayer.neurons.stream().mapToDouble(n -> n.weights[index]).toArray();
double[] nextDeltas = nextLayer.neurons.stream().mapToDouble(n -> n.delta).toArray();
double sumWeightsAndDeltas = Util.dotProduct(nextWeights, nextDeltas);
Expand Down
9 changes: 2 additions & 7 deletions CCSPiJ/src/chapter7/Network.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,13 @@
import java.util.function.Function;

public class Network<T> {
private List<Layer> layers;
private List<Layer> layers = new ArrayList<>();

public Network(int[] layerStructure, double learningRate,
DoubleUnaryOperator activationFunction, DoubleUnaryOperator derivativeActivationFunction) {
if (layerStructure.length < 3) {
throw new IllegalArgumentException("Error: Should be at least 3 layers (1 input, 1 hidden, 1 output).");
}
layers = new ArrayList<>();
// input layer
Layer inputLayer = new Layer(Optional.empty(), layerStructure[0], learningRate, activationFunction,
derivativeActivationFunction);
Expand All @@ -47,11 +46,7 @@ public Network(int[] layerStructure, double learningRate,
// Pushes input data to the first layer, then output from the first
// as input to the second, second to the third, etc.
private double[] outputs(double[] input) {
double[] result = input;
for (Layer layer : layers) {
result = layer.outputs(result);
}
return result;
return layers.stream().reduce(input, (r, l) -> l.outputs(r), (r1, r2) -> r1);
}

// Figure out each neuron's changes based on the errors of the output
Expand Down
27 changes: 15 additions & 12 deletions CCSPiJ/src/chapter7/Util.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@
package chapter7;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -63,22 +65,23 @@ public static void normalizeByFeatureScaling(List<double[]> dataset) {

// Load a CSV file into a List of String arrays
public static List<String[]> loadCSV(String filename) {
InputStream inputStream = Util.class.getResourceAsStream(filename);
InputStreamReader inputStreamReader = new InputStreamReader(inputStream);
BufferedReader bufferedReader = new BufferedReader(inputStreamReader);
return bufferedReader.lines().map(line -> line.split(","))
.collect(Collectors.toList());
try (InputStream inputStream = Util.class.getResourceAsStream(filename)) {
InputStreamReader inputStreamReader = new InputStreamReader(inputStream);
BufferedReader bufferedReader = new BufferedReader(inputStreamReader);
return bufferedReader.lines().map(line -> line.split(","))
.collect(Collectors.toList());
}
catch (IOException e) {
e.printStackTrace();
throw new RuntimeException(e.getMessage(), e);
}
}

// Find the maximum in an array of doubles
public static double max(double[] numbers) {
double m = Double.MIN_VALUE;
for (double number : numbers) {
if (number > m) {
m = number;
}
}
return m;
return Arrays.stream(numbers)
.max()
.orElse(Double.MIN_VALUE);
}

}
28 changes: 15 additions & 13 deletions CCSPiJ/src/chapter7/WineTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package chapter7;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

Expand All @@ -32,19 +33,20 @@ public WineTest() {
Collections.shuffle(wineDataset);
for (String[] wine : wineDataset) {
// last thirteen items are parameters (doubles)
double[] parameters = new double[13];
for (int i = 1; i < (parameters.length + 1); i++) {
parameters[i - 1] = Double.parseDouble(wine[i]);
}
double[] parameters = Arrays.stream(wine)
.skip(1)
.mapToDouble(Double::parseDouble)
.toArray();
wineParameters.add(parameters);
// first item is species
int species = Integer.parseInt(wine[0]);
if (species == 1) {
wineClassifications.add(new double[] { 1.0, 0.0, 0.0 });
} else if (species == 2) {
wineClassifications.add(new double[] { 0.0, 1.0, 0.0 });
} else { // 3
wineClassifications.add(new double[] { 0.0, 0.0, 1.0 });
switch (species) {
case 1 :
wineClassifications.add(new double[] { 1.0, 0.0, 0.0 }); break;
case 2 :
wineClassifications.add(new double[] { 0.0, 1.0, 0.0 }); break;
default :
wineClassifications.add(new double[] { 0.0, 0.0, 1.0 });; break;
}
wineSpecies.add(species);
}
Expand All @@ -55,11 +57,11 @@ public Integer wineInterpretOutput(double[] output) {
double max = Util.max(output);
if (max == output[0]) {
return 1;
} else if (max == output[1]) {
}
if (max == output[1]) {
return 2;
} else {
return 3;
}
return 3;
}

public Network<Integer>.Results classify() {
Expand Down
2 changes: 0 additions & 2 deletions CCSPiJ/src/module-info.java

This file was deleted.