
Commit 73f43cd

Merge pull request robmarkcole#25 from robmarkcole/adopt-fastapi
Adopt fastapi
2 parents: 047cce2 + d8dedd5
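In short: this merge replaces Flask with FastAPI. The server is now run with uvicorn, images are uploaded as `multipart/form-data` under a `file` field, and interactive API docs are served at `/docs`.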

3 files changed: +90 −95 lines changed

README.md

Lines changed: 10 additions & 10 deletions
@@ -2,7 +2,7 @@
 Expose tensorflow-lite models via a rest API. Currently object, face & scene detection is supported. Can be hosted on any of the common platforms including RPi, linux desktop, Mac and Windows. A service can be used to have the server run automatically on an RPi.
 
 ## Setup
-In this process we create a virtual environment (venv), then install tensorflow-lite [as per these instructions](https://www.tensorflow.org/lite/guide/python) which is platform specific, and finally install the remaining requirements. Note on an RPi (only) it is necessary to manually install pip3, numpy, pillow.
+In this process we create a virtual environment (venv), then install tensorflow-lite [as per these instructions](https://www.tensorflow.org/lite/guide/python), which is platform specific, and finally install the remaining requirements. **Note:** on an RPi (only) it is necessary to install pip3, numpy and pillow system-wide.
 
 All instructions for mac:
 ```
@@ -18,16 +18,16 @@ For convenience a couple of models are included in this repo and used by default
 If you want to create custom models, there is the easy way, and the longer but more flexible way. The easy way is to use [teachablemachine](https://teachablemachine.withgoogle.com/train/image), which I have done in this repo for the dogs-vs-cats model. The teachablemachine service is limited to image classification but is very straightforward to use. The longer way allows you to use any neural network architecture to produce a tensorflow model, which you then convert to an optimized tflite model. An example of this approach is described in [this article](https://towardsdatascience.com/inferences-from-a-tf-lite-model-transfer-learning-on-a-pre-trained-model-e16e7c5f0ee6), or jump straight [to the code](https://github.com/arshren/TFLite/blob/master/Transfer%20Learning%20with%20TFLite-Copy1.ipynb).
 
 ## Usage
-Start the tflite-server on port 5000 (default is port 5000):
+Start the tflite-server on port 5000:
 ```
-(venv) $ python3 tflite-server.py --port 5000
+(venv) $ uvicorn tflite-server:app --reload --port 5000
 ```
 
-You can check that the tflite-server is running by visiting `http://ip:5000/` from any machine, where `ip` is the ip address of the host (`localhost` if querying from the same machine).
+You can check that the tflite-server is running by visiting `http://ip:5000/` from any machine, where `ip` is the ip address of the host (`localhost` if querying from the same machine). The docs can be viewed at `http://localhost:5000/docs`.
 
 Post an image to detect objects via cURL:
 ```
-curl -X POST -F image=@tests/people_car.jpg 'http://localhost:5000/v1/vision/detection'
+curl -X POST "http://localhost:5000/v1/vision/detection" -H "accept: application/json" -H "Content-Type: multipart/form-data" -F "file=@tests/people_car.jpg;type=image/jpeg"
 ```
 Which should return:
 ```
@@ -57,25 +57,25 @@ Which should return:
 
 To detect faces:
 ```
-curl -X POST -F image=@tests/faces.jpg 'http://localhost:5000/v1/vision/face'
+curl -X POST "http://localhost:5000/v1/vision/face" -H "accept: application/json" -H "Content-Type: multipart/form-data" -F "file=@tests/faces.jpg;type=image/jpeg"
 ```
 
 To detect the scene (dogs vs cats model):
 ```
-curl -X POST -F image=@tests/cat.jpg 'http://localhost:5000/v1/vision/scene'
+curl -X POST "http://localhost:5000/v1/vision/scene" -H "accept: application/json" -H "Content-Type: multipart/form-data" -F "file=@tests/cat.jpg;type=image/jpeg"
 ```
 
 ## Add tflite-server as a service
-You can run tflite-server as a [service](https://www.raspberrypi.org/documentation/linux/usage/systemd.md), which means tflite-server will automatically start on RPi boot, and can be easily started & stopped. Create the service file in the appropriate location on the rpi using: ```sudo nano /etc/systemd/system/tflite-server.service```
+You can run tflite-server as a [service](https://www.raspberrypi.org/documentation/linux/usage/systemd.md), which means tflite-server will automatically start on RPi boot, and can be easily started & stopped. Create the service file in the appropriate location on the RPi using: ```sudo nano /etc/systemd/system/tflite-server.service```
 
 Enter the following (adapted for your `tflite-server.py` file location and args):
 ```
 [Unit]
-Description=Flask app exposing tensorflow lite models
+Description=App exposing tensorflow lite models
 After=network.target
 
 [Service]
-ExecStart=/usr/bin/python3 -u tflite-server.py
+ExecStart=/usr/bin/uvicorn tflite-server:app --reload --port 5000 # check
 WorkingDirectory=/home/pi/github/tensorflow-lite-rest-server
 StandardOutput=inherit
 StandardError=inherit
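
For reference, the same endpoints can be called from Python. A minimal client sketch, assuming the `requests` package (not a dependency of this repo) and a server running locally on port 5000; the endpoint path and the `file` form field come from the diff above:

```
# Minimal sketch: POST a test image to the object detection endpoint.
# Assumes `pip install requests` and that tflite-server is running locally.
import requests

URL = "http://localhost:5000/v1/vision/detection"  # /face and /scene work the same way

with open("tests/people_car.jpg", "rb") as f:
    # FastAPI reads the upload from the multipart form field named "file"
    response = requests.post(URL, files={"file": ("people_car.jpg", f, "image/jpeg")})

response.raise_for_status()
print(response.json())  # e.g. {"success": true, "predictions": [...]}
```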

requirements.txt

Lines changed: 4 additions & 2 deletions
@@ -1,3 +1,5 @@
-Flask
 Pillow # manual install on rpi
-numpy # manual install on rpi
+numpy # manual install on rpi
+fastapi==0.59.0
+uvicorn==0.11.5
+python-multipart
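
A note on the new pins: `python-multipart` is the package FastAPI relies on to parse `multipart/form-data` bodies, so it is needed for the `UploadFile`/`File(...)` parameters introduced in `tflite-server.py` below. A minimal sketch of the relationship (the `/echo-filename` route is hypothetical, for illustration only):

```
# Hypothetical illustration: declaring a File(...) parameter makes FastAPI
# depend on python-multipart; without it installed, FastAPI raises an error
# asking you to install the package.
from fastapi import FastAPI, File, UploadFile

app = FastAPI()


@app.post("/echo-filename")  # hypothetical route, not part of this repo
async def echo_filename(file: UploadFile = File(...)):
    return {"filename": file.filename}
```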

tflite-server.py

Lines changed: 76 additions & 83 deletions
@@ -1,19 +1,18 @@
 """
-Expose a tflite models via a rest API.
+Expose tflite models via a rest API.
 """
-import argparse
 import io
 import logging
 import sys
 
-import flask
 import numpy as np
 import tflite_runtime.interpreter as tflite
+from fastapi import FastAPI, File, HTTPException, UploadFile
 from PIL import Image
 
-from helpers import read_labels, set_input_tensor, classify_image
+from helpers import classify_image, read_labels, set_input_tensor
 
-app = flask.Flask(__name__)
+app = FastAPI()
 
 LOGFORMAT = "%(asctime)s %(levelname)s %(name)s : %(message)s"
 logging.basicConfig(
@@ -36,29 +35,52 @@
 SCENE_MODEL = "models/classification/dogs-vs-cats/model.tflite"
 SCENE_LABELS = "models/classification/dogs-vs-cats/labels.txt"
 
-
-@app.route("/")
-def info():
+# Setup object detection
+obj_interpreter = tflite.Interpreter(model_path=OBJ_MODEL)
+obj_interpreter.allocate_tensors()
+obj_input_details = obj_interpreter.get_input_details()
+obj_output_details = obj_interpreter.get_output_details()
+obj_input_height = obj_input_details[0]["shape"][1]
+obj_input_width = obj_input_details[0]["shape"][2]
+obj_labels = read_labels(OBJ_LABELS)
+
+# Setup face detection
+face_interpreter = tflite.Interpreter(model_path=FACE_MODEL)
+face_interpreter.allocate_tensors()
+face_input_details = face_interpreter.get_input_details()
+face_output_details = face_interpreter.get_output_details()
+face_input_height = face_input_details[0]["shape"][1]  # 320
+face_input_width = face_input_details[0]["shape"][2]  # 320
+
+# Setup scene detection
+scene_interpreter = tflite.Interpreter(model_path=SCENE_MODEL)
+scene_interpreter.allocate_tensors()
+scene_input_details = scene_interpreter.get_input_details()
+scene_output_details = scene_interpreter.get_output_details()
+scene_input_height = scene_input_details[0]["shape"][1]
+scene_input_width = scene_input_details[0]["shape"][2]
+scene_labels = read_labels(SCENE_LABELS)
+
+
+@app.get("/")
+async def info():
     return f"""
-    Object detection model: {OBJ_MODEL.split("/")[-2]} \n
-    Face detection model: {FACE_MODEL.split("/")[-2]} \n
-    Scene model: {SCENE_MODEL.split("/")[-2]} \n
-    """.replace(
-        "\n", "<br>"
-    )
+    Object detection model: {OBJ_MODEL.split("/")[-2]}
+    Face detection model: {FACE_MODEL.split("/")[-2]}
+    Scene model: {SCENE_MODEL.split("/")[-2]}
+    """
 
 
-@app.route(FACE_DETECTION_URL, methods=["POST"])
-def predict_face():
+@app.post(FACE_DETECTION_URL)
+async def predict_face(file: UploadFile = File(...)):
     data = {"success": False}
-    if not flask.request.method == "POST":
-        return
-
-    if flask.request.files.get("image"):
-        # Open image and get bytes and size
-        image_file = flask.request.files["image"]
-        image_bytes = image_file.read()
-        image = Image.open(io.BytesIO(image_bytes))  # A PIL image
+    if file.content_type.startswith("image/") is False:
+        raise HTTPException(
+            status_code=400, detail=f"File '{file.filename}' is not an image."
+        )
+    try:
+        contents = await file.read()
+        image = Image.open(io.BytesIO(contents))  # A PIL image
         image_width = image.size[0]
         image_height = image.size[1]
 
@@ -92,20 +114,22 @@ def predict_face():
 
         data["predictions"] = faces
         data["success"] = True
-        return flask.jsonify(data)
+        return data
+    except:
+        e = sys.exc_info()[1]
+        raise HTTPException(status_code=500, detail=str(e))
 
 
-@app.route(OBJ_DETECTION_URL, methods=["POST"])
-def predict_object():
+@app.post(OBJ_DETECTION_URL)
+async def predict_object(file: UploadFile = File(...)):
     data = {"success": False}
-    if not flask.request.method == "POST":
-        return
-
-    if flask.request.files.get("image"):
-        # Open image and get bytes and size
-        image_file = flask.request.files["image"]
-        image_bytes = image_file.read()
-        image = Image.open(io.BytesIO(image_bytes))  # A PIL image
+    if file.content_type.startswith("image/") is False:
+        raise HTTPException(
+            status_code=400, detail=f"File '{file.filename}' is not an image."
+        )
+    try:
+        contents = await file.read()
+        image = Image.open(io.BytesIO(contents))  # A PIL image
         image_width = image.size[0]
         image_height = image.size[1]
 
@@ -136,20 +160,22 @@ def predict_object():
 
         data["predictions"] = objects
         data["success"] = True
-        return flask.jsonify(data)
+        return data
+    except:
+        e = sys.exc_info()[1]
+        raise HTTPException(status_code=500, detail=str(e))
 
 
-@app.route(SCENE_URL, methods=["POST"])
-def predict_scene():
+@app.post(SCENE_URL)
+async def predict_scene(file: UploadFile = File(...)):
     data = {"success": False}
-    if not flask.request.method == "POST":
-        return
-
-    if flask.request.files.get("image"):
-        # Open image and get bytes and size
-        image_file = flask.request.files["image"]
-        image_bytes = image_file.read()
-        image = Image.open(io.BytesIO(image_bytes))  # A PIL image
+    if file.content_type.startswith("image/") is False:
+        raise HTTPException(
+            status_code=400, detail=f"File '{file.filename}' is not an image."
+        )
+    try:
+        contents = await file.read()
+        image = Image.open(io.BytesIO(contents))  # A PIL image
         # Format data and send to interpreter
         resized_image = image.resize(
             (scene_input_width, scene_input_height), Image.ANTIALIAS
@@ -164,40 +190,7 @@ def predict_scene():
         data["label"] = scene_labels[label_id]
         data["confidence"] = prob
         data["success"] = True
-        return flask.jsonify(data)
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(
-        description="Flask app exposing tflite models"
-    )
-    parser.add_argument("--port", default=5000, type=int, help="port number")
-    args = parser.parse_args()
-
-    # Setup object detection
-    obj_interpreter = tflite.Interpreter(model_path=OBJ_MODEL)
-    obj_interpreter.allocate_tensors()
-    obj_input_details = obj_interpreter.get_input_details()
-    obj_output_details = obj_interpreter.get_output_details()
-    obj_input_height = obj_input_details[0]["shape"][1]
-    obj_input_width = obj_input_details[0]["shape"][2]
-    obj_labels = read_labels(OBJ_LABELS)
-
-    # Setup face detection
-    face_interpreter = tflite.Interpreter(model_path=FACE_MODEL)
-    face_interpreter.allocate_tensors()
-    face_input_details = face_interpreter.get_input_details()
-    face_output_details = face_interpreter.get_output_details()
-    face_input_height = face_input_details[0]["shape"][1]  # 320
-    face_input_width = face_input_details[0]["shape"][2]  # 320
-
-    # Setup face detection
-    scene_interpreter = tflite.Interpreter(model_path=SCENE_MODEL)
-    scene_interpreter.allocate_tensors()
-    scene_input_details = scene_interpreter.get_input_details()
-    scene_output_details = scene_interpreter.get_output_details()
-    scene_input_height = scene_input_details[0]["shape"][1]
-    scene_input_width = scene_input_details[0]["shape"][2]
-    scene_labels = read_labels(SCENE_LABELS)
-
-    app.run(host="0.0.0.0", debug=True, port=args.port)
+        return data
+    except:
+        e = sys.exc_info()[1]
+        raise HTTPException(status_code=500, detail=str(e))
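
One editorial note on the new handlers: a bare `except:` also catches `KeyboardInterrupt` and `SystemExit`, and `sys.exc_info()[1]` is an indirect way to reach the active exception. A narrower alternative sketch (an idiom suggestion, not what this commit ships):

```
# Alternative sketch: catch Exception only and chain it, so interpreter
# shutdown signals are not swallowed and the original traceback is kept.
from fastapi import HTTPException

try:
    ...  # read the upload and run inference, as in the handlers above
except Exception as exc:
    raise HTTPException(status_code=500, detail=str(exc)) from exc
```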
