nex3z/tfmobile-mnist-android

MNIST with TensorFlow Mobile on Android

This project demonstrates how to use TensorFlow Mobile on Android to classify handwritten digits from MNIST.

A prebuilt APK can be downloaded from here.

If you are interested in a TensorFlow Lite version, please refer to tflite-mnist-android.

How to build from scratch

Requirements

  • Python 3.6, TensorFlow 1.8.0
  • Android Studio 3.0, Gradle 4.1

Step 1. Training

The model is defined in mnist.py. Run the following command to train it:

python train.py --model_dir ./saved_model --iterations 10000
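
train.py takes its options as command-line flags. The sketch below shows how flags like --model_dir and --iterations can be declared with Python's argparse. It is an assumption about the interface, not the contents of train.py, and the defaults are illustrative rather than the values train.py actually uses.

```python
import argparse


def parse_args(argv=None):
    """Parse the two flags used in the training command.

    Hypothetical sketch: the defaults are illustrative assumptions,
    not necessarily what train.py uses.
    """
    parser = argparse.ArgumentParser(description="Train an MNIST classifier.")
    parser.add_argument("--model_dir", default="./saved_model",
                        help="Directory for checkpoints and the frozen mnist.pb")
    parser.add_argument("--iterations", type=int, default=10000,
                        help="Number of training iterations to run")
    return parser.parse_args(argv)


# Parse the same arguments as the command above.
args = parse_args(["--model_dir", "./saved_model", "--iterations", "10000"])
print(args.model_dir, args.iterations)  # ./saved_model 10000
```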

train.py uses a simple convolutional neural network. train_bn.py provides a larger network with batch normalization, which is expected to reach about 99.5% accuracy on the validation set within 10000 iterations, as shown below.

After training, a collection of checkpoint files and a frozen GraphDef file mnist.pb will be generated in ./saved_model.
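
For readers new to batch normalization as used in train_bn.py: at inference time in TensorFlow 1.x it uses the moving mean and variance collected during training, so each channel is computed as y = gamma * (x - mean) / sqrt(var + eps) + beta with fixed statistics. Because that is a fixed affine map, it can be folded into the preceding layer's weight and bias. TensorFlow's optimize_for_inference can fold some batch-norm patterns this way, though whether it does so for this graph depends on which ops it uses. The sketch checks the folding algebra for a single channel; the numbers and the eps default of 1e-3 are illustrative assumptions.

```python
import math


def batch_norm_inference(x, mean, var, gamma, beta, eps=1e-3):
    """Normalize one channel value using fixed (moving) statistics."""
    return gamma * (x - mean) / math.sqrt(var + eps) + beta


def fold_into_linear(w, b, mean, var, gamma, beta, eps=1e-3):
    """Fold batch norm into the preceding y = w * x + b (single channel).

    Returns (w2, b2) such that w2 * x + b2 equals
    batch_norm_inference(w * x + b, mean, var, gamma, beta, eps).
    """
    scale = gamma / math.sqrt(var + eps)
    return w * scale, (b - mean) * scale + beta


# Illustrative numbers, not taken from train_bn.py.
w, b = 0.8, 0.1
mean, var, gamma, beta = 0.5, 2.0, 1.5, -0.2
w2, b2 = fold_into_linear(w, b, mean, var, gamma, beta)
for x in (-1.0, 0.0, 3.0):
    unfolded = batch_norm_inference(w * x + b, mean, var, gamma, beta)
    assert math.isclose(unfolded, w2 * x + b2)  # same result, one op fewer
```

After folding, the graph no longer needs the separate normalization ops at inference time.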

You can evaluate the model on the test set using the command below.

python test.py --model_dir ./saved_model
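
test.py reports how well the model does on the MNIST test set. The sketch below shows one common definition of classification accuracy: the fraction of images whose highest-scoring class equals the true label. It is not the code in test.py, which presumably computes this with TensorFlow ops; the logits and labels are made-up values.

```python
def argmax(values):
    """Index of the largest value (the first one wins on ties)."""
    return max(range(len(values)), key=lambda i: values[i])


def accuracy(logits_batch, labels):
    """Fraction of examples whose highest-scoring class equals the label."""
    if len(logits_batch) != len(labels):
        raise ValueError("logits_batch and labels must have the same length")
    correct = sum(argmax(row) == label for row, label in zip(logits_batch, labels))
    return correct / len(labels)


# Made-up scores for 4 images over 3 classes.
logits = [[0.1, 2.0, 0.3], [1.5, 0.2, 0.1], [0.0, 0.1, 0.9], [0.2, 0.3, 0.1]]
labels = [1, 0, 2, 0]
print(accuracy(logits, labels))  # 0.75 (the last image is misclassified)
```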

A pre-trained model can be downloaded from here.

Step 2. Model optimization

TensorFlow provides optimize_for_inference.py, which optimizes the model by removing the parts of the graph that are only needed for training.

Navigate to the TensorFlow repository directory and run the following command to optimize the model:

python tensorflow/python/tools/optimize_for_inference.py \
    --input=model_path/mnist.pb \
    --output=output_path/mnist_optimized.pb \
    --input_names=x \
    --output_names=output

The --input argument should point to the TensorFlow GraphDef file (mnist.pb) generated in Step 1. The --output argument specifies where the optimized model is written.
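
One thing optimize_for_inference.py does is drop nodes that the requested outputs do not depend on, such as loss and optimizer ops. The sketch below illustrates that idea by walking backward from the output node (output) and stopping at the input node (x). A plain dictionary stands in for a GraphDef (a hypothetical, simplified format), and the node names other than x and output are made up. The real tool also performs other rewrites, for example removing training-only Identity nodes and folding batch normalization.

```python
def node_name(input_ref):
    """Strip a control-dependency '^' prefix and an ':N' output suffix."""
    return input_ref.lstrip("^").split(":")[0]


def strip_unused(graph, input_names, output_names):
    """Keep only the nodes needed to compute output_names from input_names.

    graph maps node name -> list of input references, a simplified,
    hypothetical stand-in for a GraphDef.
    """
    keep = set()
    stack = list(output_names)
    while stack:
        name = stack.pop()
        if name in keep:
            continue
        keep.add(name)
        if name in input_names:
            continue  # inputs are fed at run time; don't walk past them
        stack.extend(node_name(ref) for ref in graph[name])
    return {name: inputs for name, inputs in graph.items() if name in keep}


# Toy graph: "loss" and "train_op" are only needed for training.
graph = {
    "x": [], "labels": [], "w": [], "b": [],
    "matmul": ["x", "w"],
    "logits": ["matmul", "b"],
    "output": ["logits"],
    "loss": ["logits", "labels"],
    "train_op": ["loss", "^w"],
}
pruned = strip_unused(graph, {"x"}, ["output"])
print(sorted(pruned))  # ['b', 'logits', 'matmul', 'output', 'w', 'x']
```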

Note that the mnist.pb generated by train.py is already frozen; otherwise, we would have to freeze the graph with freeze_graph.py before optimization.
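
Freezing means replacing each variable node with a constant node that holds the trained value from the checkpoint, so the graph file alone is enough for inference. The sketch below shows that substitution on plain dictionaries, which are simplified, hypothetical stand-ins for a GraphDef and a checkpoint; freeze_graph.py itself works on real protobufs and also prunes unused nodes. Variable and VariableV2 are the TensorFlow 1.x variable op types.

```python
VARIABLE_OPS = {"Variable", "VariableV2"}


def freeze(graph_ops, checkpoint):
    """Return a copy of the graph with every variable turned into a constant.

    graph_ops maps node name -> op type, and checkpoint maps variable
    name -> trained value. Both are hypothetical, simplified stand-ins.
    """
    missing = [n for n, op in graph_ops.items()
               if op in VARIABLE_OPS and n not in checkpoint]
    if missing:
        raise KeyError(f"no checkpoint value for variables: {missing}")
    frozen = {}
    for name, op in graph_ops.items():
        if op in VARIABLE_OPS:
            frozen[name] = ("Const", checkpoint[name])  # value baked into the graph
        else:
            frozen[name] = (op, None)  # other nodes are copied unchanged
    return frozen


graph_ops = {"x": "Placeholder", "w": "VariableV2", "matmul": "MatMul"}
frozen = freeze(graph_ops, {"w": [[0.5]]})
print(frozen)  # {'x': ('Placeholder', None), 'w': ('Const', [[0.5]]), 'matmul': ('MatMul', None)}
```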

An optimized model file can be downloaded from here.

Step 3. Build Android app

Copy the mnist_optimized.pb generated in Step 2 to /android/app/src/main/assets, then build and run the app.

The Classifier creates a TensorFlowInferenceInterface from mnist_optimized.pb. TensorFlowInferenceInterface provides methods for running inference and summarizing performance, and is included in the following library:

implementation "org.tensorflow:tensorflow-android:1.8.0"
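
In the TensorFlow Mobile Java API, inference with TensorFlowInferenceInterface follows a feed/run/fetch pattern: feed(...) copies a float array into an input node such as x, run(...) executes the graph up to the named output nodes, and fetch(...) copies the results into a float array. Around that call, an app has to turn the drawn bitmap into model input and pick the most likely digit from the output. The Python sketch below shows those two surrounding steps under stated assumptions: pixels arrive as packed ARGB ints (the format returned by Android's Bitmap.getPixels), the model expects 28×28 values in [0, 1] with bright strokes on a dark background like MNIST, and the output holds one score per digit. This is not necessarily what the repository's Classifier does.

```python
def argb_to_input(pixels, invert=True):
    """Convert packed ARGB ints to grayscale floats in [0, 1].

    Assumption: the drawing view paints dark strokes on a light
    background, so values are inverted to match MNIST's bright strokes.
    For a 28x28 bitmap, pixels should hold 784 entries.
    """
    values = []
    for p in pixels:
        # Masking with 0xFF also works for Java's signed 32-bit ints.
        r = (p >> 16) & 0xFF
        g = (p >> 8) & 0xFF
        b = p & 0xFF
        gray = (r + g + b) / (3 * 255.0)
        values.append(1.0 - gray if invert else gray)
    return values


def predict_digit(scores):
    """Return the digit (index) with the highest score."""
    return max(range(len(scores)), key=lambda i: scores[i])


print(argb_to_input([0xFFFFFFFF, 0xFF000000]))  # [0.0, 1.0]
probs = [0.0] * 10
probs[7] = 0.9
print(predict_digit(probs))  # 7
```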

Credits