-
Notifications
You must be signed in to change notification settings - Fork 1
/
Classifier.java
78 lines (57 loc) · 2.71 KB
/
Classifier.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
package dev.prateek.com.dig_draw;
import android.content.Context;
import android.graphics.Bitmap;
import android.os.SystemClock;
import android.util.Log;
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
import java.util.Arrays;
/**
* Created by prateek on 19-05-2018.
*/
/**
 * Runs a frozen TensorFlow graph ({@code mnist_optimized.pb} from the app assets) on a bitmap and
 * returns the per-category scores of its {@value #CATEGORY_COUNT} outputs.
 *
 * <p>Instances reuse per-instance pixel, input and output buffers, so a single instance must not
 * be used from several threads at once. Call {@link #close()} when finished.
 */
public class Classifier {
    private static final String LOG_TAG = Classifier.class.getSimpleName();
    private static final String MODEL_PATH = "file:///android_asset/mnist_optimized.pb";
    private static final int DIM_BATCH_SIZE = 1;
    public static final int DIM_IMG_SIZE_HEIGHT = 28;
    public static final int DIM_IMG_SIZE_WIDTH = 28;
    private static final int DIM_PIXEL_SIZE = 1;
    private static final int CATEGORY_COUNT = 10;
    private static final String INPUT_NAME = "x";
    private static final String OUTPUT_NAME = "output";
    private static final String[] OUTPUT_NAMES = { OUTPUT_NAME };

    // Reused between calls to avoid per-classification allocations.
    private final int[] mImagePixels = new int[DIM_IMG_SIZE_HEIGHT * DIM_IMG_SIZE_WIDTH];
    private final float[] mImageData = new float[DIM_IMG_SIZE_HEIGHT * DIM_IMG_SIZE_WIDTH];
    private final float[] mResult = new float[CATEGORY_COUNT];

    private final TensorFlowInferenceInterface mInferenceInterface;

    /**
     * Loads the model graph from the application's assets.
     *
     * @param context any context; only its {@code AssetManager} is used
     */
    public Classifier(Context context) {
        mInferenceInterface = new TensorFlowInferenceInterface(context.getAssets(), MODEL_PATH);
    }

    /**
     * Classifies an image.
     *
     * <p>The bitmap is scaled to {@value #DIM_IMG_SIZE_WIDTH}x{@value #DIM_IMG_SIZE_HEIGHT} if it
     * has any other size. It is then converted to normalized grey-scale values and fed to the
     * graph input {@code "x"}.
     *
     * @param bitmap the image to classify
     * @return a {@code Result} holding a private copy of the {@value #CATEGORY_COUNT} output
     *     scores and the inference time in milliseconds. The time covers feed, run and fetch;
     *     it excludes bitmap preprocessing.
     */
    public Result classify(Bitmap bitmap) {
        convertBitmap(bitmap);

        long startTime = SystemClock.uptimeMillis();
        // Feed the input tensor with shape [batch, height, width, channels].
        mInferenceInterface.feed(INPUT_NAME, mImageData, DIM_BATCH_SIZE, DIM_IMG_SIZE_HEIGHT,
                DIM_IMG_SIZE_WIDTH, DIM_PIXEL_SIZE);
        mInferenceInterface.run(OUTPUT_NAMES);
        mInferenceInterface.fetch(OUTPUT_NAME, mResult);
        long timeCost = SystemClock.uptimeMillis() - startTime;

        Log.v(LOG_TAG, "classify(): result = " + Arrays.toString(mResult));
        // mResult is overwritten on every call. Hand out a copy so a Result kept by the caller is
        // not silently changed by a later classification. (Whether Result copies on its own is
        // not visible here.)
        return new Result(mResult.clone(), timeCost);
    }

    /** Releases the native TensorFlow resources held by this classifier. */
    public void close() {
        mInferenceInterface.close();
    }

    /**
     * Fills {@code mImageData} with the grey-scale values of {@code bitmap}, scaled to the
     * model's input size.
     *
     * <p>The pixel buffer holds exactly one 28x28 image. Reading a bitmap of any other size
     * directly would overflow the buffer (larger bitmaps) or leave stale pixels from a previous
     * call in it (smaller bitmaps).
     */
    private void convertBitmap(Bitmap bitmap) {
        boolean needsScaling = bitmap.getWidth() != DIM_IMG_SIZE_WIDTH
                || bitmap.getHeight() != DIM_IMG_SIZE_HEIGHT;
        Bitmap source = needsScaling
                ? Bitmap.createScaledBitmap(bitmap, DIM_IMG_SIZE_WIDTH, DIM_IMG_SIZE_HEIGHT, true)
                : bitmap;
        try {
            source.getPixels(mImagePixels, 0, DIM_IMG_SIZE_WIDTH, 0, 0,
                    DIM_IMG_SIZE_WIDTH, DIM_IMG_SIZE_HEIGHT);
        } finally {
            // Only recycle the temporary scaled copy, never the caller's bitmap.
            if (source != bitmap) {
                source.recycle();
            }
        }
        for (int i = 0; i < mImageData.length; i++) {
            mImageData[i] = convertToGreyScale(mImagePixels[i]);
        }
    }

    /**
     * Maps an ARGB color to the mean of its R, G and B channels, scaled to [0, 1]. Alpha is
     * ignored.
     */
    private float convertToGreyScale(int color) {
        return (((color >> 16) & 0xFF) + ((color >> 8) & 0xFF) + (color & 0xFF)) / 3.0f / 255.0f;
    }
}