Commit 80e6397 ("prep")
1 parent: 4254734

File tree: 6 files changed (+33, -10 lines)

README.md

Lines changed: 8 additions & 1 deletion

@@ -1,5 +1,5 @@
 # tensorflow-lite-rest-server
-Expose tensorflow-lite models via a rest API, and currently object detection is supported. Can be hosted on any of the common platforms including RPi, linux desktop, Mac and Windows.
+Expose tensorflow-lite models via a rest API; currently object, face & scene detection are supported. Can be hosted on any of the common platforms including RPi, linux desktop, Mac and Windows.
 
 ## Setup
 In this process we create a virtual environment (venv), then install tensorflow-lite [as per these instructions](https://www.tensorflow.org/lite/guide/python) which is platform specific, and finally install the remaining requirements. Note on an RPi (only) it is necessary to manually install pip3, numpy, pillow.

@@ -58,6 +58,13 @@ To detect faces:
 curl -X POST -F image=@tests/faces.jpg 'http://localhost:5000/v1/vision/face'
 ```
 
+To run scene classification:
+```
+curl -X POST -F image=@tests/cat.jpg 'http://localhost:5000/v1/vision/scene'
+or
+curl -X POST -F image=@tests/dog.jpg 'http://localhost:5000/v1/vision/scene'
+```
+
 ## Deepstack, Home Assistant & UI
 This API can be used as a drop-in replacement for [deepstack object detection](https://github.com/robmarkcole/HASS-Deepstack-object) and [deepstack face detection](https://github.com/robmarkcole/HASS-Deepstack-face) (configuring `detect_only: True`) in Home Assistant. I also created a UI for viewing the predictions of the object detection model [here](https://github.com/robmarkcole/deepstack-ui).
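Since the API is deepstack-compatible, the endpoints above return JSON prediction payloads. A minimal sketch of consuming such a response in Python; the `success`/`predictions` field names are an assumption based on the deepstack-style schema, not something this commit shows:

```python
import json

# Hypothetical deepstack-style response body (field names are an assumption)
sample = (
    '{"success": true, "predictions": '
    '[{"label": "cat", "confidence": 0.92}, {"label": "dog", "confidence": 0.07}]}'
)

def top_prediction(payload: str):
    """Return (label, confidence) of the highest-confidence prediction, or None."""
    data = json.loads(payload)
    if not data.get("success") or not data.get("predictions"):
        return None
    best = max(data["predictions"], key=lambda p: p["confidence"])
    return best["label"], best["confidence"]

print(top_prediction(sample))  # ('cat', 0.92)
```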

helpers.py

Lines changed: 2 additions & 2 deletions

@@ -3,9 +3,9 @@
 """
 
 
-def read_coco_labels(file_path):
+def read_labels(file_path):
     """
-    Helper for loading coco_labels.txt
+    Helper for loading labels.txt
     """
     with open(file_path, "r", encoding="utf-8") as f:
         lines = f.readlines()
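The hunk above shows only the opening lines of `read_labels`. A sketch of what a complete reader might look like, assuming the coco_labels.txt convention of one label per line with an optional numeric index prefix; the parsing details are an assumption, as they are not shown in this diff:

```python
def read_labels(file_path):
    """Load a labels file into a {class_index: label} dict (details assumed)."""
    with open(file_path, "r", encoding="utf-8") as f:
        lines = f.readlines()
    labels = {}
    for row_number, line in enumerate(lines):
        stripped = line.strip()
        if not stripped:
            continue  # skip blank lines
        parts = stripped.split(maxsplit=1)
        if len(parts) == 2 and parts[0].isdigit():
            labels[int(parts[0])] = parts[1]  # "0 person" style
        else:
            labels[row_number] = stripped     # bare "cat" style
    return labels
```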
Lines changed: 1 addition & 1 deletion

@@ -1,5 +1,5 @@
 ## dogs vs cats
 * Custom model trained on [teachable machine image classification](https://teachablemachine.withgoogle.com/train/image)
 * Dataset: 30 cat and 30 dog images from [kaggle dogs vs cats](https://www.kaggle.com/c/dogs-vs-cats)
-* Input size: TBC
+* Input size: 224x224
 * Type: tensorflow lite quantized

tests/cat.jpg (binary image, 27.7 KB)

tests/dog.jpg (binary image, 10.9 KB)

tflite-server.py

Lines changed: 22 additions & 6 deletions

@@ -11,7 +11,7 @@
 import tflite_runtime.interpreter as tflite
 from PIL import Image
 
-from helpers import read_coco_labels
+from helpers import read_labels
 
 app = flask.Flask(__name__)
 

@@ -32,6 +32,10 @@
 OBJ_MODEL = "models/object_detection/mobilenet_ssd_v2_coco/mobilenet_ssd_v2_coco_quant_postprocess.tflite"
 OBJ_LABELS = "models/object_detection/mobilenet_ssd_v2_coco/coco_labels.txt"
 
+SCENE_URL = "/v1/vision/scene"
+SCENE_MODEL = "models/classification/dogs-vs-cats/model.tflite"
+SCENE_LABEL = "models/classification/dogs-vs-cats/labels.txt"
+
 
 @app.route("/")
 def info():

@@ -69,7 +73,9 @@ def predict_face():
     # Process image and get predictions
     face_interpreter.invoke()
     boxes = face_interpreter.get_tensor(face_output_details[0]["index"])[0]
-    classes = face_interpreter.get_tensor(face_output_details[1]["index"])[0]
+    classes = face_interpreter.get_tensor(face_output_details[1]["index"])[
+        0
+    ]
     scores = face_interpreter.get_tensor(face_output_details[2]["index"])[0]
 
     faces = []

@@ -141,7 +147,9 @@ def predict_object():
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Flask app exposing tflite models")
+    parser = argparse.ArgumentParser(
+        description="Flask app exposing tflite models"
+    )
     parser.add_argument("--port", default=5000, type=int, help="port number")
     args = parser.parse_args()
 

@@ -152,14 +160,22 @@ def predict_object():
     obj_output_details = obj_interpreter.get_output_details()
     obj_input_height = obj_input_details[0]["shape"][1]
     obj_input_width = obj_input_details[0]["shape"][2]
-    obj_labels = read_coco_labels(OBJ_LABELS)
+    obj_labels = read_labels(OBJ_LABELS)
 
     # Setup face detection
     face_interpreter = tflite.Interpreter(model_path=FACE_MODEL)
     face_interpreter.allocate_tensors()
     face_input_details = face_interpreter.get_input_details()
     face_output_details = face_interpreter.get_output_details()
-    face_input_height = 320
-    face_input_width = 320
+    face_input_height = face_input_details[0]["shape"][1]  # 320
+    face_input_width = face_input_details[0]["shape"][2]  # 320
+
+    # Setup scene classification
+    scene_interpreter = tflite.Interpreter(model_path=SCENE_MODEL)
+    scene_interpreter.allocate_tensors()
+    scene_input_details = scene_interpreter.get_input_details()
+    scene_output_details = scene_interpreter.get_output_details()
+    scene_input_height = scene_input_details[0]["shape"][1]
+    scene_input_width = scene_input_details[0]["shape"][2]
 
     app.run(host="0.0.0.0", debug=True, port=args.port)

Note: the committed hunk labeled the new block "# Setup face detection" and constructed the scene interpreter from FACE_MODEL, an apparent copy-paste slip (SCENE_MODEL was added but never used); the lines above use SCENE_MODEL as intended.
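The setup hunks read each model's input size from `get_input_details()`, but the image preprocessing that feeds those sizes is not part of this diff. A sketch of the typical preparation step for a quantized tflite classifier; the `preprocess` name and the uint8 input assumption are mine, not from the commit:

```python
import numpy as np
from PIL import Image

def preprocess(image: Image.Image, height: int, width: int) -> np.ndarray:
    """Resize to the model's input size and add a batch axis (uint8 for quantized models)."""
    resized = image.convert("RGB").resize((width, height), Image.BILINEAR)
    return np.expand_dims(np.asarray(resized, dtype=np.uint8), axis=0)

batch = preprocess(Image.new("RGB", (640, 480)), 224, 224)
print(batch.shape)  # (1, 224, 224, 3)
```

The resulting array has the NHWC layout that `interpreter.set_tensor` expects for these models.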
