5
5
import java .nio .file .Files ;
6
6
import java .nio .file .Path ;
7
7
import java .nio .file .Paths ;
8
+ import java .util .ArrayList ;
8
9
import java .util .List ;
9
10
10
11
public class Classificator {
@@ -23,37 +24,39 @@ public Classificator() {
23
24
modelGraph = new Graph ();
24
25
modelGraph .importGraphDef (graphData );
25
26
session = new Session (modelGraph );
26
-
27
- //Just print two main operations to look at shapes
28
- System .out .println (modelGraph .operation ("input" ).output (0 ));
29
- System .out .println (modelGraph .operation ("output" ).output (0 ));
30
27
} catch (Exception e ) {e .printStackTrace (); throw new RuntimeException (e );}
31
28
}
32
29
33
- public String classify (float [][][][] imageData ) {
30
+ public List < String > classify (float [][][][] imageData ) {
34
31
Tensor imageTensor = Tensor .create (imageData , Float .class );
35
- float [][] output = predict (imageTensor );
36
- return findPredictedLabel (output );
32
+ float [][] prediction = predict (imageTensor );
33
+ return findPredictedLabel (prediction );
37
34
}
38
35
39
36
private float [][] predict (Tensor imageTensor ) {
40
37
Tensor result = session .runner ()
41
38
.feed ("input" , imageTensor )
42
39
.fetch ("output" ).run ().get (0 );
40
+ int batchSize = (int )result .shape ()[0 ];
43
41
//create prediction buffer
44
- float [][] prediction = new float [1 ][1008 ];
42
+ float [][] prediction = new float [batchSize ][1008 ];
45
43
result .copyTo (prediction );
46
44
return prediction ;
47
45
}
48
46
49
- private String findPredictedLabel (float [][] prediction ) {
50
- int maxValueIndex = 0 ;
51
- for (int i = 1 ; i < prediction [0 ].length ; i ++) {
52
- if (prediction [0 ][maxValueIndex ] < prediction [0 ][i ]) {
53
- maxValueIndex = i ;
47
+ private List <String > findPredictedLabel (float [][] prediction ) {
48
+ List <String > result = new ArrayList <>();
49
+ int batchSize = prediction .length ;
50
+ for (int i = 0 ; i < batchSize ; i ++) {
51
+ //Finding maximum value for each predicted image
52
+ int maxValueIndex = 0 ;
53
+ for (int j = 1 ; j < prediction [i ].length ; j ++) {
54
+ if (prediction [i ][maxValueIndex ] < prediction [i ][j ]) {
55
+ maxValueIndex = j ;
56
+ }
54
57
}
58
+ result .add (labels .get (maxValueIndex ) + ": " + (prediction [i ][maxValueIndex ] * 100 ) + "%" );
55
59
}
56
- System .out .println (prediction [0 ][maxValueIndex ]);
57
- return labels .get (maxValueIndex );
60
+ return result ;
58
61
}
59
62
}
0 commit comments