Skip to content

Commit 9b3dcc7

Browse files
authored
Merge pull request #5 from dmi3coder/feature/dl4j
Logistic regression with dl4j
2 parents a3000ae + b5525a4 commit 9b3dcc7

File tree

11 files changed

+719
-354
lines changed

11 files changed

+719
-354
lines changed

.idea/workspace.xml

Lines changed: 371 additions & 232 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

build.gradle

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@ dependencies {
2727
implementation 'com.google.guava:guava:28.0-jre'
2828
implementation 'io.reactivex.rxjava3:rxjava:3.0.0-RC2'
2929

30+
//dl4j
31+
implementation "org.deeplearning4j:deeplearning4j-core:0.9.1"
32+
implementation "org.nd4j:nd4j-native-platform:0.9.1"
33+
3034
// Use JUnit test framework
3135
testImplementation 'junit:junit:4.12'
3236
}

src/main/java/de/dmi3y/behaiv/Behaiv.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,6 @@ public Behaiv setProvider(@Nonnull Provider provider) {
5050
return this;
5151
}
5252

53-
public Behaiv setBehaivNode() {
54-
return this;
55-
}
56-
5753
public Behaiv register(@Nonnull BehaivNode node, @Nullable String name) {
5854
if (node instanceof ActionableNode) {
5955
currentSession.captureLabel(name);
@@ -72,6 +68,10 @@ public void startCapturing(boolean predict) {
7268
currentSession.start(this);
7369
}
7470

71+
protected CaptureSession getCurrentSession() {
72+
return currentSession;
73+
}
74+
7575
@Override
7676
public void onFeaturesCaptured(List<Pair<Double, String>> features) {
7777
if (kernel.readyToPredict() && predict) {
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
package de.dmi3y.behaiv.kernel;
2+
3+
import org.apache.commons.lang3.ArrayUtils;
4+
import org.apache.commons.math3.util.Pair;
5+
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
6+
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
7+
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
8+
import org.deeplearning4j.nn.conf.layers.OutputLayer;
9+
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
10+
import org.deeplearning4j.nn.weights.WeightInit;
11+
import org.nd4j.linalg.activations.Activation;
12+
import org.nd4j.linalg.cpu.nativecpu.NDArray;
13+
import org.nd4j.linalg.learning.config.Nesterovs;
14+
15+
import java.util.ArrayList;
16+
import java.util.List;
17+
import java.util.stream.Collectors;
18+
19+
public class LogisticRegressionKernel extends Kernel {
20+
21+
private List<String> labels = new ArrayList<>();
22+
private OutputLayer outputLayer;
23+
private MultiLayerNetwork network;
24+
25+
@Override
26+
public void fit(ArrayList<Pair<ArrayList<Double>, String>> data) {
27+
this.data = data;
28+
labels = this.data.stream().map(Pair::getSecond).distinct().collect(Collectors.toList());
29+
if (readyToPredict()) {
30+
outputLayer = new OutputLayer.Builder()
31+
.nIn(this.data.get(0).getFirst().size())
32+
.nOut(labels.size())
33+
.weightInit(WeightInit.DISTRIBUTION)
34+
.activation(Activation.SOFTMAX)
35+
.build();
36+
MultiLayerConfiguration config = new NeuralNetConfiguration.Builder().seed(123).learningRate(0.1).iterations(100).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(new Nesterovs(0.9)) //High Level Configuration
37+
.list() //For configuring MultiLayerNetwork we call the list method
38+
.layer(0, outputLayer) // <----- output layer fed here
39+
.pretrain(true).backprop(true) //Pretraining and Backprop Configuration
40+
.build();//Building Configuration
41+
42+
network = new MultiLayerNetwork(config);
43+
network.init();
44+
45+
//features
46+
double[][] inputs = this.data.stream().map(Pair::getFirst).map(l -> l.toArray(new Double[0]))
47+
.map(ArrayUtils::toPrimitive)
48+
.toArray(double[][]::new);
49+
50+
//labels
51+
double[][] labelArray = new double[data.size()][labels.size()];
52+
for (int i = 0; i < data.size(); i++) {
53+
int dummyPos = labels.indexOf(data.get(i).getSecond());
54+
labelArray[i][dummyPos] = 1.0;
55+
}
56+
57+
NDArray inputResults = new NDArray(inputs);
58+
NDArray outputResults = new NDArray(labelArray);
59+
60+
network.fit(inputResults, outputResults);
61+
}
62+
63+
}
64+
65+
@Override
66+
public void updateSingle(ArrayList<Double> features, String label) {
67+
super.updateSingle(features, label);
68+
}
69+
70+
@Override
71+
public String predictOne(ArrayList<Double> features) {
72+
NDArray testInput = new NDArray(new double[][]{ArrayUtils.toPrimitive(features.toArray(new Double[0]))});
73+
int[] predict = network.predict(testInput);
74+
return labels.get(predict[0]);
75+
}
76+
}

src/main/java/de/dmi3y/behaiv/session/CaptureSession.java

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,29 +19,29 @@ public CaptureSession(List<Provider> providers) {
1919
this.providers = providers;
2020
}
2121

22-
public void init() {
23-
24-
}
25-
2622
public void start(Behaiv behaiv) {
2723
new Thread(() -> {
28-
List<Double> capturedFeatures = providers.stream().flatMap(provider -> provider.getFeature().blockingGet().stream()).collect(Collectors.toList());
29-
List<String> capturedNames = providers.stream().flatMap((Provider provider) -> provider.availableFeatures().stream()).collect(Collectors.toList());
30-
if (capturedFeatures.size() != capturedNames.size()) {
31-
throw new InputMismatchException("Features size should match it's names");
32-
}
33-
34-
features = new ArrayList<>();
35-
for (int i = 0; i < capturedFeatures.size(); i++) {
36-
features.add(new Pair<>(capturedFeatures.get(i), capturedNames.get(i)));
37-
38-
}
39-
if (behaiv != null) {
40-
behaiv.onFeaturesCaptured(features);
41-
}
24+
startBlocking(behaiv);
4225
}).start();
4326
}
4427

28+
public void startBlocking(Behaiv behaiv) {
29+
List<Double> capturedFeatures = providers.stream().flatMap(provider -> provider.getFeature().blockingGet().stream()).collect(Collectors.toList());
30+
List<String> capturedNames = providers.stream().flatMap((Provider provider) -> provider.availableFeatures().stream()).collect(Collectors.toList());
31+
if (capturedFeatures.size() != capturedNames.size()) {
32+
throw new InputMismatchException("Features size should match it's names");
33+
}
34+
35+
features = new ArrayList<>();
36+
for (int i = 0; i < capturedFeatures.size(); i++) {
37+
features.add(new Pair<>(capturedFeatures.get(i), capturedNames.get(i)));
38+
39+
}
40+
if (behaiv != null) {
41+
behaiv.onFeaturesCaptured(features);
42+
}
43+
}
44+
4545
public void captureLabel(String name) {
4646
label = name;
4747
}
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
package de.dmi3y.behaiv;
2+
3+
import de.dmi3y.behaiv.kernel.DummyKernel;
4+
import de.dmi3y.behaiv.session.CaptureSession;
5+
import org.junit.Before;
6+
import org.junit.Test;
7+
8+
import static org.junit.Assert.assertNotEquals;
9+
import static org.junit.Assert.assertNotNull;
10+
import static org.junit.Assert.assertNull;
11+
import static org.junit.Assert.assertTrue;
12+
13+
public class BehaivTest {
14+
15+
private Behaiv behaiv;
16+
private DummyKernel testKernel;
17+
18+
@Before
19+
public void setUp() throws Exception {
20+
testKernel = new DummyKernel();
21+
behaiv = Behaiv.with(testKernel);
22+
}
23+
24+
@Test
25+
public void setKernel() {
26+
DummyKernel newKernel = new DummyKernel();
27+
behaiv.setKernel(newKernel);
28+
assertTrue(!testKernel.equals(newKernel));
29+
30+
}
31+
32+
@Test
33+
public void stopCapturing_whenDiscard_sessionShouldBeNull() {
34+
behaiv.startCapturing(false);
35+
CaptureSession currentSession = behaiv.getCurrentSession();
36+
assertNotNull(currentSession);
37+
behaiv.stopCapturing(true);
38+
assertNull(behaiv.getCurrentSession());
39+
}
40+
}

src/test/java/de/dmi3y/behaiv/LibraryTest.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@
1414

1515
import java.util.ArrayList;
1616

17-
import static de.dmi3y.behaiv.kernel.DummyKernelTest.GYM;
18-
import static de.dmi3y.behaiv.kernel.DummyKernelTest.HOME;
19-
import static de.dmi3y.behaiv.kernel.DummyKernelTest.JOG;
20-
import static de.dmi3y.behaiv.kernel.DummyKernelTest.WORK;
17+
import static de.dmi3y.behaiv.kernel.KernelTest.GYM;
18+
import static de.dmi3y.behaiv.kernel.KernelTest.HOME;
19+
import static de.dmi3y.behaiv.kernel.KernelTest.JOG;
20+
import static de.dmi3y.behaiv.kernel.KernelTest.WORK;
2121
import static org.junit.Assert.assertEquals;
2222

2323
public class LibraryTest {

src/test/java/de/dmi3y/behaiv/kernel/DummyKernelTest.java

Lines changed: 9 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -5,116 +5,29 @@
55

66
import java.util.ArrayList;
77

8-
import static org.junit.Assert.*;
8+
import static de.dmi3y.behaiv.kernel.KernelTest.HOME;
9+
import static de.dmi3y.behaiv.kernel.KernelTest.WORK;
10+
import static org.junit.Assert.assertEquals;
911

1012
public class DummyKernelTest {
1113

12-
public static Double[] HOME = {1.1, 1.2};
13-
public static Double[] GYM = {2.1, 2.2};
14-
public static Double[] JOG = {3.1, 3.2};
15-
public static Double[] WORK = {5.1, 5.2};
16-
1714

1815
@Test
1916
public void predictOne() {
20-
ArrayList<Double> list = new ArrayList<>();
21-
ArrayList<Pair<ArrayList<Double>, String>> data = new ArrayList<>();
22-
23-
24-
list.add(5 * 60 + 00.0);
25-
list.add(HOME[0]);
26-
list.add(HOME[1]);
27-
list.add(0.0);
28-
data.add(new Pair<>(list,"SELFIMPROVEMENT_SCREEN"));
29-
list = new ArrayList<>();
30-
list.add(5 * 60 + 10.0);
31-
list.add(HOME[0]);
32-
list.add(HOME[1]);
33-
list.add(0.0);
34-
data.add(new Pair<>(list,"SELFIMPROVEMENT_SCREEN"));
35-
list = new ArrayList<>();
36-
list.add(6 * 60 + 10.0);
37-
list.add(GYM[0]);
38-
list.add(GYM[1]);
39-
list.add(1.0);
40-
data.add(new Pair<>(list,"SPORT_SCREEN"));
41-
list = new ArrayList<>();
42-
list.add(7 * 60 + 30.0);
43-
list.add(HOME[0]);
44-
list.add(HOME[1]);
45-
list.add(1.0);
46-
data.add(new Pair<>(list,"SELFIMPROVEMENT_SCREEN"));
47-
list = new ArrayList<>();
48-
list.add(8 * 60 + 30.0);
49-
list.add(WORK[0]);
50-
list.add(WORK[1]);
51-
list.add(0.0);
52-
data.add(new Pair<>(list,"WORK_SCREEN"));
53-
list = new ArrayList<>();
54-
list.add(10 * 60 + 30.0);
55-
list.add(WORK[0]);
56-
list.add(WORK[1]);
57-
list.add(1.0);
58-
data.add(new Pair<>(list,"WORK_SCREEN"));
59-
list = new ArrayList<>();
60-
list.add(11 * 60 + 30.0);
61-
list.add(WORK[0]);
62-
list.add(WORK[1]);
63-
list.add(1.0);
64-
data.add(new Pair<>(list,"WORK_SCREEN"));
65-
list = new ArrayList<>();
66-
list.add(16 * 60 + 30.0);
67-
list.add(WORK[0]);
68-
list.add(WORK[1]);
69-
list.add(0.0);
70-
data.add(new Pair<>(list,"WORK_SCREEN"));
71-
list = new ArrayList<>();
72-
list.add(17 * 60 + 10.0);
73-
list.add(WORK[0]);
74-
list.add(WORK[1]);
75-
list.add(0.0);
76-
data.add(new Pair<>(list,"WORK_SCREEN"));
77-
list = new ArrayList<>();
78-
list.add(18 * 60 + 50.0);
79-
list.add(WORK[0]);
80-
list.add(WORK[1]);
81-
list.add(0.0);
82-
data.add(new Pair<>(list,"WORK_SCREEN"));
83-
list = new ArrayList<>();
84-
list.add(19 * 60 + 5.0);
85-
list.add(JOG[0]);
86-
list.add(JOG[1]);
87-
list.add(1.0);
88-
data.add(new Pair<>(list,"SPORT_SCREEN"));
89-
list = new ArrayList<>();
90-
list.add(19 * 60 + 10.0);
91-
list.add(JOG[0]);
92-
list.add(JOG[1]);
93-
list.add(1.0);
94-
data.add(new Pair<>(list,"SPORT_SCREEN"));
95-
list = new ArrayList<>();
96-
list.add(19 * 60 + 25.0);
97-
list.add(JOG[0]);
98-
list.add(JOG[1]);
99-
list.add(1.0);
100-
data.add(new Pair<>(list,"SPORT_SCREEN"));
101-
list = new ArrayList<>();
102-
list.add(21 * 60 + 00.0);
103-
list.add(HOME[0]);
104-
list.add(HOME[1]);
105-
list.add(0.0);
106-
data.add(new Pair<>(list,"ADD_SCREEN"));
107-
list = new ArrayList<>();
108-
DummyKernel dummyKernel = new DummyKernel();
17+
ArrayList<Pair<ArrayList<Double>, String>> data = KernelTest.getTrainingData();
18+
Kernel dummyKernel = new DummyKernel();
10919
dummyKernel.fit(data);
11020
ArrayList<Double> predictList = new ArrayList<>();
111-
predictList.add(10 * 60 + 30.0);
21+
predictList.add((10 * 60 + 30.0) / (24 * 60));
11222
predictList.add(WORK[0]);
11323
predictList.add(WORK[1]);
11424
predictList.add(1.0);
11525

11626
dummyKernel.update(null);
11727
String prediction = dummyKernel.predictOne(predictList);
11828
assertEquals("WORK_SCREEN", prediction);
29+
30+
//TODO predictOne fails in dummy because of re-usage of data
11931
}
32+
12033
}

0 commit comments

Comments
 (0)