20 changes: 10 additions & 10 deletions README.md
@@ -2,7 +2,7 @@
Expose tensorflow-lite models via a REST API. Currently object, face & scene detection are supported. The server can be hosted on any of the common platforms including RPi, Linux desktop, Mac and Windows. A service can be used to have the server run automatically on an RPi.

## Setup
In this process we create a virtual environment (venv), then install tensorflow-lite [as per these instructions](https://www.tensorflow.org/lite/guide/python) which is platform specific, and finally install the remaining requirements. Note on an RPi (only) it is necessary to manually install pip3, numpy, pillow.
In this process we create a virtual environment (venv), then install tensorflow-lite [as per these instructions](https://www.tensorflow.org/lite/guide/python), which is platform specific, and finally install the remaining requirements. **Note:** on an RPi (only) it is necessary to install pip3, numpy and pillow system-wide.

All instructions for mac:
```
@@ -18,16 +18,16 @@ For convenience a couple of models are included in this repo and used by default
If you want to create custom models, there is the easy way, and the longer but more flexible way. The easy way is to use [teachablemachine](https://teachablemachine.withgoogle.com/train/image), which I have done in this repo for the dogs-vs-cats model. The teachablemachine service is limited to image classification but is very straightforward to use. The longer way allows you to use any neural network architecture to produce a tensorflow model, which you then convert to an optimized tflite model. An example of this approach is described in [this article](https://towardsdatascience.com/inferences-from-a-tf-lite-model-transfer-learning-on-a-pre-trained-model-e16e7c5f0ee6), or jump straight [to the code](https://github.com/arshren/TFLite/blob/master/Transfer%20Learning%20with%20TFLite-Copy1.ipynb).

## Usage
Start the tflite-server on port 5000 (default is port 5000):
Start the tflite-server on port 5000:
```
(venv) $ python3 tflite-server.py --port 5000
(venv) $ uvicorn tflite-server:app --reload --port 5000
```

You can check that the tflite-server is running by visiting `http://ip:5000/` from any machine, where `ip` is the ip address of the host (`localhost` if querying from the same machine).
You can check that the tflite-server is running by visiting `http://ip:5000/` from any machine, where `ip` is the IP address of the host (`localhost` if querying from the same machine). The interactive API docs can be viewed at `http://localhost:5000/docs`.

Post an image to detect objects via cURL:
```
curl -X POST -F image=@tests/people_car.jpg 'http://localhost:5000/v1/vision/detection'
curl -X POST "http://localhost:5000/v1/vision/detection" -H "accept: application/json" -H "Content-Type: multipart/form-data" -F "file=@tests/people_car.jpg;type=image/jpeg"
```
Which should return:
```
@@ -57,25 +57,25 @@ Which should return:

To detect faces:
```
curl -X POST -F image=@tests/faces.jpg 'http://localhost:5000/v1/vision/face'
curl -X POST "http://localhost:5000/v1/vision/face" -H "accept: application/json" -H "Content-Type: multipart/form-data" -F "file=@tests/faces.jpg;type=image/jpeg"
```

To detect the scene (dogs vs cats model):
```
curl -X POST -F image=@tests/cat.jpg 'http://localhost:5000/v1/vision/scene'
curl -X POST "http://localhost:5000/v1/vision/scene" -H "accept: application/json" -H "Content-Type: multipart/form-data" -F "file=@tests/cat.jpg;type=image/jpeg"
```

## Add tflite-server as a service
You can run tflite-server as a [service](https://www.raspberrypi.org/documentation/linux/usage/systemd.md), which means tflite-server will automatically start on RPi boot, and can be easily started & stopped. Create the service file in the appropriate location on the rpi using: ```sudo nano /etc/systemd/system/tflite-server.service```
You can run tflite-server as a [service](https://www.raspberrypi.org/documentation/linux/usage/systemd.md), which means tflite-server will automatically start on RPi boot, and can be easily started & stopped. Create the service file in the appropriate location on the RPi using: ```sudo nano /etc/systemd/system/tflite-server.service```

Enter the following (adapted for your `tflite-server.py` file location and args):
```
[Unit]
Description=Flask app exposing tensorflow lite models
Description=App exposing tensorflow lite models
After=network.target

[Service]
ExecStart=/usr/bin/python3 -u tflite-server.py
ExecStart=/usr/bin/uvicorn tflite-server:app --port 5000
WorkingDirectory=/home/pi/github/tensorflow-lite-rest-server
StandardOutput=inherit
StandardError=inherit
6 changes: 4 additions & 2 deletions requirements.txt
@@ -1,3 +1,5 @@
Flask
Pillow # manual install on rpi
numpy # manual install on rpi
numpy # manual install on rpi
fastapi==0.59.0
uvicorn==0.11.5
python-multipart
159 changes: 76 additions & 83 deletions tflite-server.py
@@ -1,19 +1,18 @@
"""
Expose a tflite models via a rest API.
Expose tflite models via a REST API.
"""
import argparse
import io
import logging
import sys

import flask
import numpy as np
import tflite_runtime.interpreter as tflite
from fastapi import FastAPI, File, HTTPException, UploadFile
from PIL import Image

from helpers import read_labels, set_input_tensor, classify_image
from helpers import classify_image, read_labels, set_input_tensor

app = flask.Flask(__name__)
app = FastAPI()

LOGFORMAT = "%(asctime)s %(levelname)s %(name)s : %(message)s"
logging.basicConfig(
@@ -36,29 +36,52 @@
SCENE_MODEL = "models/classification/dogs-vs-cats/model.tflite"
SCENE_LABELS = "models/classification/dogs-vs-cats/labels.txt"


@app.route("/")
def info():
# Setup object detection
obj_interpreter = tflite.Interpreter(model_path=OBJ_MODEL)
obj_interpreter.allocate_tensors()
obj_input_details = obj_interpreter.get_input_details()
obj_output_details = obj_interpreter.get_output_details()
obj_input_height = obj_input_details[0]["shape"][1]
obj_input_width = obj_input_details[0]["shape"][2]
obj_labels = read_labels(OBJ_LABELS)

# Setup face detection
face_interpreter = tflite.Interpreter(model_path=FACE_MODEL)
face_interpreter.allocate_tensors()
face_input_details = face_interpreter.get_input_details()
face_output_details = face_interpreter.get_output_details()
face_input_height = face_input_details[0]["shape"][1] # 320
face_input_width = face_input_details[0]["shape"][2] # 320

# Setup scene classification
scene_interpreter = tflite.Interpreter(model_path=SCENE_MODEL)
scene_interpreter.allocate_tensors()
scene_input_details = scene_interpreter.get_input_details()
scene_output_details = scene_interpreter.get_output_details()
scene_input_height = scene_input_details[0]["shape"][1]
scene_input_width = scene_input_details[0]["shape"][2]
scene_labels = read_labels(SCENE_LABELS)


@app.get("/")
async def info():
return f"""
Object detection model: {OBJ_MODEL.split("/")[-2]} \n
Face detection model: {FACE_MODEL.split("/")[-2]} \n
Scene model: {SCENE_MODEL.split("/")[-2]} \n
""".replace(
"\n", "<br>"
)
Object detection model: {OBJ_MODEL.split("/")[-2]}
Face detection model: {FACE_MODEL.split("/")[-2]}
Scene model: {SCENE_MODEL.split("/")[-2]}
"""


@app.route(FACE_DETECTION_URL, methods=["POST"])
def predict_face():
@app.post(FACE_DETECTION_URL)
async def predict_face(file: UploadFile = File(...)):
data = {"success": False}
if not flask.request.method == "POST":
return

if flask.request.files.get("image"):
# Open image and get bytes and size
image_file = flask.request.files["image"]
image_bytes = image_file.read()
image = Image.open(io.BytesIO(image_bytes)) # A PIL image
if not file.content_type.startswith("image/"):
raise HTTPException(
status_code=400, detail=f"File '{file.filename}' is not an image."
)
try:
contents = await file.read()
image = Image.open(io.BytesIO(contents)) # A PIL image
image_width = image.size[0]
image_height = image.size[1]

@@ -92,20 +114,22 @@ def predict_face():

data["predictions"] = faces
data["success"] = True
return flask.jsonify(data)
return data
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))


@app.route(OBJ_DETECTION_URL, methods=["POST"])
def predict_object():
@app.post(OBJ_DETECTION_URL)
async def predict_object(file: UploadFile = File(...)):
data = {"success": False}
if not flask.request.method == "POST":
return

if flask.request.files.get("image"):
# Open image and get bytes and size
image_file = flask.request.files["image"]
image_bytes = image_file.read()
image = Image.open(io.BytesIO(image_bytes)) # A PIL image
if not file.content_type.startswith("image/"):
raise HTTPException(
status_code=400, detail=f"File '{file.filename}' is not an image."
)
try:
contents = await file.read()
image = Image.open(io.BytesIO(contents)) # A PIL image
image_width = image.size[0]
image_height = image.size[1]

@@ -136,20 +160,22 @@ def predict_object():

data["predictions"] = objects
data["success"] = True
return flask.jsonify(data)
return data
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))


@app.route(SCENE_URL, methods=["POST"])
def predict_scene():
@app.post(SCENE_URL)
async def predict_scene(file: UploadFile = File(...)):
data = {"success": False}
if not flask.request.method == "POST":
return

if flask.request.files.get("image"):
# Open image and get bytes and size
image_file = flask.request.files["image"]
image_bytes = image_file.read()
image = Image.open(io.BytesIO(image_bytes)) # A PIL image
if not file.content_type.startswith("image/"):
raise HTTPException(
status_code=400, detail=f"File '{file.filename}' is not an image."
)
try:
contents = await file.read()
image = Image.open(io.BytesIO(contents)) # A PIL image
# Format data and send to interpreter
resized_image = image.resize(
(scene_input_width, scene_input_height), Image.ANTIALIAS
@@ -164,40 +190,7 @@ def predict_scene():
data["label"] = scene_labels[label_id]
data["confidence"] = prob
data["success"] = True
return flask.jsonify(data)


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Flask app exposing tflite models"
)
parser.add_argument("--port", default=5000, type=int, help="port number")
args = parser.parse_args()

# Setup object detection
obj_interpreter = tflite.Interpreter(model_path=OBJ_MODEL)
obj_interpreter.allocate_tensors()
obj_input_details = obj_interpreter.get_input_details()
obj_output_details = obj_interpreter.get_output_details()
obj_input_height = obj_input_details[0]["shape"][1]
obj_input_width = obj_input_details[0]["shape"][2]
obj_labels = read_labels(OBJ_LABELS)

# Setup face detection
face_interpreter = tflite.Interpreter(model_path=FACE_MODEL)
face_interpreter.allocate_tensors()
face_input_details = face_interpreter.get_input_details()
face_output_details = face_interpreter.get_output_details()
face_input_height = face_input_details[0]["shape"][1] # 320
face_input_width = face_input_details[0]["shape"][2] # 320

# Setup face detection
scene_interpreter = tflite.Interpreter(model_path=SCENE_MODEL)
scene_interpreter.allocate_tensors()
scene_input_details = scene_interpreter.get_input_details()
scene_output_details = scene_interpreter.get_output_details()
scene_input_height = scene_input_details[0]["shape"][1]
scene_input_width = scene_input_details[0]["shape"][2]
scene_labels = read_labels(SCENE_LABELS)

app.run(host="0.0.0.0", debug=True, port=args.port)
return data
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))