Skip to content

Commit

Permalink
Update Android Detect demo to use models exported using the Tensorflow Object Detection API. Resolves tensorflow#6738.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 164802542
  • Loading branch information
tensorflower-gardener committed Aug 10, 2017
1 parent 22730fd commit 53aabd5
Show file tree
Hide file tree
Showing 9 changed files with 292 additions and 23 deletions.
10 changes: 10 additions & 0 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,16 @@ new_http_archive(
],
)

# SSD-MobileNet detection model exported with the TensorFlow Object Detection
# API; fetched as an external repository and consumed by the Android TF Detect
# demo via @mobile_ssd//:model_files. Two mirror URLs are listed so the fetch
# can fall back if one host is unavailable; sha256 pins the archive contents.
new_http_archive(
name = "mobile_ssd",
build_file = "models.BUILD",
sha256 = "bddd81ea5c80a97adfac1c9f770e6f55cbafd7cce4d3bbe15fbeb041e6b8f3e8",
urls = [
"http://storage.googleapis.com/download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_android_export.zip",
"http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_android_export.zip",
],
)

new_http_archive(
name = "mobile_multibox",
build_file = "models.BUILD",
Expand Down
4 changes: 2 additions & 2 deletions tensorflow/core/framework/register_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ limitations under the License.
#define TF_CALL_resource(m)
#define TF_CALL_complex64(m)
#define TF_CALL_int64(m) m(::tensorflow::int64)
#define TF_CALL_bool(m)
#define TF_CALL_bool(m) m(bool)

#define TF_CALL_qint8(m) m(::tensorflow::qint8)
#define TF_CALL_quint8(m) m(::tensorflow::quint8)
Expand All @@ -122,7 +122,7 @@ limitations under the License.
#define TF_CALL_resource(m)
#define TF_CALL_complex64(m)
#define TF_CALL_int64(m)
#define TF_CALL_bool(m)
#define TF_CALL_bool(m) m(bool)

#define TF_CALL_qint8(m)
#define TF_CALL_quint8(m)
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/examples/android/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ filegroup(
name = "external_assets",
srcs = [
"@inception5h//:model_files",
"@mobile_multibox//:model_files",
"@mobile_ssd//:model_files",
"@stylize//:model_files",
],
)
Expand Down
12 changes: 7 additions & 5 deletions tensorflow/examples/android/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,14 @@ on API >= 14 devices.
model to classify camera frames in real-time, displaying the top results
in an overlay on the camera image.
2. [TF Detect](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java):
Demonstrates a model based on [Scalable Object Detection
using Deep Neural Networks](https://arxiv.org/abs/1312.2249) to
localize and track people in the camera preview in real-time.
Demonstrates an SSD-Mobilenet model trained using the
[Tensorflow Object Detection API](https://github.com/tensorflow/models/tree/master/object_detection/)
introduced in [Speed/accuracy trade-offs for modern convolutional object detectors](https://arxiv.org/abs/1611.10012) to
localize and track objects (from 80 categories) in the camera preview
in real-time.
3. [TF Stylize](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/android/src/org/tensorflow/demo/StylizeActivity.java):
Uses a model based on [A Learned Representation For Artistic
Style](https://arxiv.org/abs/1610.07629) to restyle the camera preview
Style](https://arxiv.org/abs/1610.07629) to restyle the camera preview
image to that of a number of different artists.

<img src="sample_images/classify1.jpg" width="30%"><img src="sample_images/stylize1.jpg" width="30%"><img src="sample_images/detect1.jpg" width="30%">
Expand Down Expand Up @@ -149,7 +151,7 @@ and extract the archives yourself to the `assets` directory in the source tree:

```bash
BASE_URL=https://storage.googleapis.com/download.tensorflow.org/models
for MODEL_ZIP in inception5h.zip mobile_multibox_v1a.zip stylize_v1.zip
for MODEL_ZIP in inception5h.zip ssd_mobilenet_v1_android_export.zip stylize_v1.zip
do
curl -L ${BASE_URL}/${MODEL_ZIP} -o /tmp/${MODEL_ZIP}
unzip /tmp/${MODEL_ZIP} -d tensorflow/examples/android/assets/
Expand Down
1 change: 1 addition & 0 deletions tensorflow/examples/android/download-models.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
// hard coded model files
// LINT.IfChange
def models = ['inception5h.zip',
'ssd_mobilenet_v1_android_export.zip',
'mobile_multibox_v1a.zip',
'stylize_v1.zip']
// LINT.ThenChange(//tensorflow/examples/android/BUILD)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
import android.util.Size;
import android.util.TypedValue;
import android.view.Display;
import android.widget.Toast;
import java.io.IOException;
import java.util.LinkedList;
import java.util.List;
import java.util.Vector;
Expand Down Expand Up @@ -62,6 +64,11 @@ public class DetectorActivity extends CameraActivity implements OnImageAvailable
private static final String MB_LOCATION_FILE =
"file:///android_asset/multibox_location_priors.txt";

private static final int TF_OD_API_INPUT_SIZE = 300;
private static final String TF_OD_API_MODEL_FILE =
"file:///android_asset/ssd_mobilenet_v1_android_export.pb";
private static final String TF_OD_API_LABELS_FILE = "file:///android_asset/coco_labels_list.txt";

// Configuration values for tiny-yolo-voc. Note that the graph is not included with TensorFlow and
// must be manually placed in the assets/ directory by the user.
// Graphs and models downloaded from http://pjreddie.com/darknet/yolo/ may be converted e.g. via
Expand All @@ -73,15 +80,20 @@ public class DetectorActivity extends CameraActivity implements OnImageAvailable
private static final String YOLO_OUTPUT_NAMES = "output";
private static final int YOLO_BLOCK_SIZE = 32;

// Default to the included multibox model.
private static final boolean USE_YOLO = false;

private static final int CROP_SIZE = USE_YOLO ? YOLO_INPUT_SIZE : MB_INPUT_SIZE;
// Which detection model to use: by default uses Tensorflow Object Detection API frozen
// checkpoints. Optionally use legacy Multibox (trained using an older version of the API)
// or YOLO.
private enum DetectorMode {
TF_OD_API, MULTIBOX, YOLO;
}
private static final DetectorMode MODE = DetectorMode.TF_OD_API;

// Minimum detection confidence to track a detection.
private static final float MINIMUM_CONFIDENCE = USE_YOLO ? 0.25f : 0.1f;
private static final float MINIMUM_CONFIDENCE_TF_OD_API = 0.6f;
private static final float MINIMUM_CONFIDENCE_MULTIBOX = 0.1f;
private static final float MINIMUM_CONFIDENCE_YOLO = 0.25f;

private static final boolean MAINTAIN_ASPECT = USE_YOLO;
private static final boolean MAINTAIN_ASPECT = MODE == DetectorMode.YOLO;

private static final Size DESIRED_PREVIEW_SIZE = new Size(640, 480);

Expand Down Expand Up @@ -126,8 +138,8 @@ public void onPreviewSizeChosen(final Size size, final int rotation) {

tracker = new MultiBoxTracker(this);


if (USE_YOLO) {
int cropSize = TF_OD_API_INPUT_SIZE;
if (MODE == DetectorMode.YOLO) {
detector =
TensorFlowYoloDetector.create(
getAssets(),
Expand All @@ -136,7 +148,8 @@ public void onPreviewSizeChosen(final Size size, final int rotation) {
YOLO_INPUT_NAME,
YOLO_OUTPUT_NAMES,
YOLO_BLOCK_SIZE);
} else {
cropSize = YOLO_INPUT_SIZE;
} else if (MODE == DetectorMode.MULTIBOX) {
detector =
TensorFlowMultiBoxDetector.create(
getAssets(),
Expand All @@ -147,6 +160,20 @@ public void onPreviewSizeChosen(final Size size, final int rotation) {
MB_INPUT_NAME,
MB_OUTPUT_LOCATIONS_NAME,
MB_OUTPUT_SCORES_NAME);
cropSize = MB_INPUT_SIZE;
} else {
try {
detector = TensorFlowObjectDetectionAPIModel.create(
getAssets(), TF_OD_API_MODEL_FILE, TF_OD_API_LABELS_FILE, TF_OD_API_INPUT_SIZE);
cropSize = TF_OD_API_INPUT_SIZE;
} catch (final IOException e) {
LOGGER.e("Exception initializing classifier!", e);
Toast toast =
Toast.makeText(
getApplicationContext(), "Classifier could not be initialized", Toast.LENGTH_SHORT);
toast.show();
finish();
}
}

previewWidth = size.getWidth();
Expand All @@ -162,12 +189,12 @@ public void onPreviewSizeChosen(final Size size, final int rotation) {
LOGGER.i("Initializing at size %dx%d", previewWidth, previewHeight);
rgbBytes = new int[previewWidth * previewHeight];
rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Config.ARGB_8888);
croppedBitmap = Bitmap.createBitmap(CROP_SIZE, CROP_SIZE, Config.ARGB_8888);
croppedBitmap = Bitmap.createBitmap(cropSize, cropSize, Config.ARGB_8888);

frameToCropTransform =
ImageUtils.getTransformationMatrix(
previewWidth, previewHeight,
CROP_SIZE, CROP_SIZE,
cropSize, cropSize,
sensorOrientation, MAINTAIN_ASPECT);

cropToFrameTransform = new Matrix();
Expand Down Expand Up @@ -322,12 +349,19 @@ public void run() {
paint.setStyle(Style.STROKE);
paint.setStrokeWidth(2.0f);

float minimumConfidence = MINIMUM_CONFIDENCE_TF_OD_API;
switch (MODE) {
case TF_OD_API: minimumConfidence = MINIMUM_CONFIDENCE_TF_OD_API; break;
case MULTIBOX: minimumConfidence = MINIMUM_CONFIDENCE_MULTIBOX; break;
case YOLO: minimumConfidence = MINIMUM_CONFIDENCE_YOLO; break;
}

final List<Classifier.Recognition> mappedRecognitions =
new LinkedList<Classifier.Recognition>();

for (final Classifier.Recognition result : results) {
final RectF location = result.getLocation();
if (location != null && result.getConfidence() >= MINIMUM_CONFIDENCE) {
if (location != null && result.getConfidence() >= minimumConfidence) {
canvas.drawRect(location, paint);

cropToFrameTransform.mapRect(location);
Expand All @@ -347,7 +381,7 @@ public void run() {
Trace.endSection();
}

protected void processImageRGBbytes(int[] rgbBytes ) {}
protected void processImageRGBbytes(int[] rgbBytes ) {}

@Override
protected int getLayoutId() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
public class TensorFlowMultiBoxDetector implements Classifier {
private static final Logger LOGGER = new Logger();

// Only return this many results with at least this confidence.
// Only return this many results.
private static final int MAX_RESULTS = Integer.MAX_VALUE;

// Config values.
Expand Down
Loading

0 comments on commit 53aabd5

Please sign in to comment.