Skip to content

Commit aee1400

Browse files
committed
refactor
1 parent 4f41a7b commit aee1400

File tree

3 files changed

+40
-39
lines changed

3 files changed

+40
-39
lines changed

src/main/java/Classificator.java

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import java.nio.file.Files;
66
import java.nio.file.Path;
77
import java.nio.file.Paths;
8+
import java.util.ArrayList;
89
import java.util.List;
910

1011
public class Classificator {
@@ -23,37 +24,39 @@ public Classificator() {
2324
modelGraph = new Graph();
2425
modelGraph.importGraphDef(graphData);
2526
session = new Session(modelGraph);
26-
27-
//Just print two main operations to look at shapes
28-
System.out.println(modelGraph.operation("input").output(0));
29-
System.out.println(modelGraph.operation("output").output(0));
3027
} catch(Exception e) {e.printStackTrace(); throw new RuntimeException(e);}
3128
}
3229

33-
public String classify(float[][][][] imageData) {
30+
public List<String> classify(float[][][][] imageData) {
3431
Tensor imageTensor = Tensor.create(imageData, Float.class);
35-
float[][] output = predict(imageTensor);
36-
return findPredictedLabel(output);
32+
float[][] prediction = predict(imageTensor);
33+
return findPredictedLabel(prediction);
3734
}
3835

3936
private float[][] predict(Tensor imageTensor) {
4037
Tensor result = session.runner()
4138
.feed("input", imageTensor)
4239
.fetch("output").run().get(0);
40+
int batchSize = (int)result.shape()[0];
4341
//create prediction buffer
44-
float[][] prediction = new float[1][1008];
42+
float[][] prediction = new float[batchSize][1008];
4543
result.copyTo(prediction);
4644
return prediction;
4745
}
4846

49-
private String findPredictedLabel(float[][] prediction) {
50-
int maxValueIndex = 0;
51-
for (int i = 1; i < prediction[0].length; i++) {
52-
if (prediction[0][maxValueIndex] < prediction[0][i]) {
53-
maxValueIndex = i;
47+
private List<String> findPredictedLabel(float[][] prediction) {
48+
List<String> result = new ArrayList<>();
49+
int batchSize = prediction.length;
50+
for (int i = 0; i < batchSize; i++) {
51+
//Finding maximum value for each predicted image
52+
int maxValueIndex = 0;
53+
for (int j = 1; j < prediction[i].length; j++) {
54+
if (prediction[i][maxValueIndex] < prediction[i][j]) {
55+
maxValueIndex = j;
56+
}
5457
}
58+
result.add(labels.get(maxValueIndex) + ": " + (prediction[i][maxValueIndex] * 100) + "%");
5559
}
56-
System.out.println(prediction[0][maxValueIndex]);
57-
return labels.get(maxValueIndex);
60+
return result;
5861
}
5962
}

src/main/java/ImageProcessor.java

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,22 +25,27 @@ public class ImageProcessor {
2525
public ImageProcessor() {
2626
}
2727

28-
public float[][][][] loadAndNormalizeImage(String path) {
29-
IplImage origImg = cvLoadImage(getFullPath(path));
30-
//Creating image placeholder to put resized image data
31-
IplImage resizedImg = IplImage.create(width, height, origImg.depth(), origImg.nChannels());
32-
cvResize(origImg, resizedImg);
33-
return getRGBArray(resizedImg);
28+
public float[][][][] loadAndNormalizeImages(String... path) {
29+
//First dimension of the result is a number of images, because we may accept multiple paths
30+
float[][][][] result = new float[path.length][height][width][3];
31+
for (int i = 0; i < path.length; i++) {
32+
IplImage origImg = cvLoadImage(getFullPath(path[i]));
33+
//Creating image placeholder to put resized image data
34+
IplImage resizedImg = IplImage.create(width, height, origImg.depth(), origImg.nChannels());
35+
cvResize(origImg, resizedImg);
36+
result[i] = getRGBArray(resizedImg);
37+
}
38+
return result;
3439
}
3540

36-
private float[][][][] getRGBArray(IplImage image) {
37-
float[][][][] result = new float[1][image.height()][image.width()][3];
41+
private float[][][] getRGBArray(IplImage image) {
42+
float[][][] result = new float[image.height()][image.width()][3];
3843
for (int i = 0; i < image.height(); i++) {
3944
for (int j = 0; j < image.width(); j++) {
4045
CvScalar pixel = cvGet2D(image, i, j);
41-
result[0][i][j][0] = (float)(pixel.val(2) - mean) / scale; //R
42-
result[0][i][j][1] = (float)(pixel.val(1) - mean) / scale; //G
43-
result[0][i][j][2] = (float)(pixel.val(0) - mean) / scale; //B
46+
result[i][j][0] = (float)(pixel.val(2) - mean) / scale; //R
47+
result[i][j][1] = (float)(pixel.val(1) - mean) / scale; //G
48+
result[i][j][2] = (float)(pixel.val(0) - mean) / scale; //B
4449
}
4550
}
4651
return result;

src/main/java/Main.java

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,17 @@
1-
import org.tensorflow.*;
2-
3-
import java.io.PrintStream;
4-
import java.nio.ByteBuffer;
5-
import java.nio.file.Files;
6-
import java.nio.file.Path;
7-
import java.nio.file.Paths;
8-
import java.time.LocalDateTime;
9-
import java.util.Iterator;
101
import java.util.List;
112

123
public class Main {
134

145
public static void main(String[] args) throws Exception {
15-
/* ImageProcessor returns 4-dimensional float array -> float[1][224][224][3]
6+
/* ImageProcessor returns 4-dimensional float array -> float[?][224][224][3]
167
* First dimension of the array is a batch for neural network,
17-
* for this simple example I use just one image at a time for classification
18-
* The code can be easily modified to load and classify multiple images.
8+
* You may pass multiple image paths to ImageProcessor
199
* */
2010
ImageProcessor imageProcessor = new ImageProcessor();
2111
Classificator classificator = new Classificator();
22-
System.out.println(classificator.classify(imageProcessor.loadAndNormalizeImage("images/hyndai.jpg")));
12+
List<String> result = classificator.classify(imageProcessor.loadAndNormalizeImages("images/ship.jpg", "images/hyndai.jpg"));
13+
for(String label: result) {
14+
System.out.println(label);
15+
}
2316
}
2417
}

0 commit comments

Comments
 (0)