
Add lite version of the license plate reader example #994


Merged
2 changes: 1 addition & 1 deletion examples/keras/document-denoiser/README.md
@@ -26,7 +26,7 @@ Then run the following piped commands
```bash
curl "${ENDPOINT}" -X POST -H "Content-Type: application/json" -d '{"url":"'${IMAGE_URL}'"}' |
sed 's/"//g' |
base64 -d >> prediction.png
base64 -d > prediction.png
```

Once this has run, we'll see a `prediction.png` file saved to disk. This is the result.
75 changes: 62 additions & 13 deletions examples/tensorflow/license-plate-reader/README.md
@@ -24,16 +24,53 @@ Out of these three models (*YOLOv3*, *CRAFT* and *CRNN*) only *YOLOv3* has been

The other two models, *CRAFT* and *CRNN*, can be found in [keras-ocr](https://github.com/faustomorales/keras-ocr).

## Deploying
## Deployment - Lite Version

The recommended number of instances to run this smoothly on a video stream is about 12 GPU instances (2 GPU instances for *YOLOv3* and 10 for *CRNN* + *CRAFT*). `cortex.yaml` is already set up to use these 12 instances. Note: this is the optimal number of instances when using the `g4dn.xlarge` instance type. For the client to work smoothly, the number of workers per replica can be adjusted, especially for `p3` or `g4` instances, where the GPU has a lot of compute capacity.
A lite version of the deployment is available with `cortex_lite.yaml`. The lite version accepts an image as input and returns an image with the recognized license plates overlaid on top. A single GPU is required for this deployment (e.g. `g4dn.xlarge`).

Once the cortex cluster is created, run
```bash
cortex deploy cortex_lite.yaml
```

And monitor the API with
```bash
cortex get --watch
```

To run an inference on the lite version, the only 3 tools you need are `curl`, `sed` and `base64`. The API expects a URL pointing to the image on which the inference is run. This covers the detection of license plates with *YOLOv3* and the recognition part with the *CRAFT* + *CRNN* models.

Export the endpoint & the image's URL by running
```bash
export ENDPOINT=your-api-endpoint
export IMAGE_URL=https://i.imgur.com/r8xdI7P.png
```

Then run the following piped commands
```bash
curl "${ENDPOINT}" -X POST -H "Content-Type: application/json" -d '{"url":"'${IMAGE_URL}'"}' |
sed 's/"//g' |
base64 -d > prediction.jpg
```
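
If you'd rather call the endpoint from Python, below is a minimal sketch of the same request using the `requests` package (an assumption on our part — any HTTP client would do). It mirrors the `curl | sed | base64` pipeline above, treating the response body as a base64-encoded JSON string:

```python
# Minimal Python equivalent of the curl | sed | base64 pipeline (sketch).
import base64
import os

import requests

endpoint = os.environ["ENDPOINT"]    # same exports as in the shell example
image_url = os.environ["IMAGE_URL"]

# the API returns the annotated image as a base64-encoded JSON string
resp = requests.post(endpoint, json={"url": image_url})
resp.raise_for_status()

with open("prediction.jpg", "wb") as f:
    f.write(base64.b64decode(resp.json()))
```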

The resulting image is the same as the one in [Verifying the Deployed APIs](#verifying-the-deployed-apis).

For another prediction, let's use a generic image from the web. Export [this image's URL](https://i.imgur.com/mYuvMOs.jpg) and re-run the prediction. This is what we get.

![annotated sample image](https://i.imgur.com/tg1PE1E.jpg)

*The above prediction has the bounding boxes colored differently to distinguish them from the cars' red bodies*

## Deployment - Full Version

The recommended setup for running this smoothly on a video stream is about 12 GPU instances (2 for *YOLOv3* and 10 for *CRNN* + *CRAFT*). `cortex_full.yaml` is already set up to use these 12 instances. Note: this is the optimal number of instances when using the `g4dn.xlarge` instance type. For the client to work smoothly, the number of workers per replica can be adjusted, especially for `p3` or `g4` instances, where the GPU has a lot of compute capacity.

If you don't have access to this many GPU-equipped instances, you could just lower the number and expect dropped frames. It will still prove the point, albeit at a much lower framerate and with higher latency. More on that [here](https://github.com/RobertLucian/cortex-license-plate-reader-client).

After the cortex cluster is created, run

```bash
cortex deploy
cortex deploy cortex_full.yaml
```

And monitor the APIs with
@@ -42,10 +79,6 @@ And monitor the APIs with
```bash
cortex get --watch
```

## Launching the Client

### Verifying the Deployed APIs

We can run the inference on a sample image to verify that both APIs are working as expected before we move on to running the client. Here is an example image:

![sample image](https://i.imgur.com/r8xdI7P.png)
@@ -81,13 +114,29 @@ Once the APIs are up and running, launch the streaming client by following the instructions

## Customization/Optimization

### Uploading the SavedModel to S3
### Uploading the Model to S3

The only model to upload to an S3 bucket (for Cortex to deploy) is the *YOLOv3* model. The other two models are downloaded automatically upon deploying the service.

If you would like to host the model from your own bucket, or if you want to fine tune the model for your needs, here's what you can do.

The only model that has to be uploaded to an S3 bucket (for Cortex to deploy) is the *YOLOv3* model. The other two models are downloaded automatically upon deploying the service.
#### Lite Version

*Note: the Keras model from [here](https://github.com/experiencor/keras-yolo3) has been converted to a SavedModel instead.*
Download the *Keras* model:

```bash
wget -O license_plate.h5 "https://www.dropbox.com/s/vsvgoyricooksyv/license_plate.h5?dl=1"
```

And then upload it to your bucket (also make sure [cortex_lite.yaml](cortex_lite.yaml) points to this bucket):

```bash
BUCKET=my-bucket
YOLO3_PATH=examples/tensorflow/license-plate-reader/yolov3_keras
aws s3 cp license_plate.h5 "s3://$BUCKET/$YOLO3_PATH/model.h5"
```
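
If you prefer Python over the AWS CLI, a rough `boto3` equivalent of the upload might look like this (a sketch; it assumes your AWS credentials are already configured and reuses the placeholder bucket and path from the shell example):

```python
# Upload the Keras model with boto3 instead of the AWS CLI (sketch).
import boto3

bucket = "my-bucket"  # placeholder, as in the shell example
yolo3_path = "examples/tensorflow/license-plate-reader/yolov3_keras"

s3 = boto3.client("s3")
s3.upload_file("license_plate.h5", bucket, f"{yolo3_path}/model.h5")
```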

If you would like to host the model from your own bucket, or if you want to fine tune the model for your needs, you can:
#### Full Version

Download the *SavedModel*:

@@ -101,11 +150,11 @@ Unzip it:
```bash
unzip yolov3.zip -d yolov3
```

And then upload it to your bucket (also make sure [cortex.yaml](cortex.yaml) points to this bucket):
And then upload it to your bucket (also make sure [cortex_full.yaml](cortex_full.yaml) points to this bucket):

```bash
BUCKET=my-bucket
YOLO3_PATH=examples/tensorflow/license-plate-reader/yolov3
YOLO3_PATH=examples/tensorflow/license-plate-reader/yolov3_tf
aws s3 cp yolov3/ "s3://$BUCKET/$YOLO3_PATH" --recursive
```

@@ -4,7 +4,7 @@
predictor:
type: tensorflow
path: predictor_yolo.py
model: s3://cortex-examples/tensorflow/license-plate-reader/yolov3
model: s3://cortex-examples/tensorflow/license-plate-reader/yolov3_tf
signature_key: serving_default
config:
model_config: config.json
13 changes: 13 additions & 0 deletions examples/tensorflow/license-plate-reader/cortex_lite.yaml
@@ -0,0 +1,13 @@
# WARNING: you are on the master branch, please refer to the examples on the branch that matches your `cortex version`

- name: license-plate-reader
predictor:
type: python
path: predictor_lite.py
config:
yolov3: s3://cortex-examples/tensorflow/license-plate-reader/yolov3_keras
yolov3_model_config: config.json
compute:
cpu: 1
gpu: 1
mem: 4G
115 changes: 115 additions & 0 deletions examples/tensorflow/license-plate-reader/predictor_lite.py
@@ -0,0 +1,115 @@
# WARNING: you are on the master branch, please refer to the examples on the branch that matches your `cortex version`

import boto3, base64, cv2, re, os, requests, json
import keras_ocr

from botocore import UNSIGNED
from botocore.client import Config
from tensorflow.keras.models import load_model
import utils.utils as utils
import utils.bbox as bbox_utils
import utils.preprocess as preprocess_utils


class PythonPredictor:
def __init__(self, config):
# download yolov3 model
bucket, key = re.match("s3://(.+?)/(.+)", config["yolov3"]).groups()
s3 = boto3.client("s3", config=Config(signature_version=UNSIGNED))
model_name = "model.h5"
s3.download_file(bucket, os.path.join(key, model_name), model_name)

# load yolov3 model
self.yolov3_model = load_model(model_name)

# get configuration for yolov3 model
with open(config["yolov3_model_config"]) as json_file:
data = json.load(json_file)
for key in data:
setattr(self, key, data[key])
self.box_confidence_score = 0.8

# keras-ocr automatically downloads the pretrained
# weights for the detector and recognizer
self.recognition_model_pipeline = keras_ocr.pipeline.Pipeline()

def predict(self, payload):
# download image
img_url = payload["url"]
image = preprocess_utils.get_url_image(img_url)

# detect the bounding boxes
boxes = utils.get_yolo_boxes(
self.yolov3_model,
image,
self.net_h,
self.net_w,
self.anchors,
self.obj_thresh,
self.nms_thresh,
len(self.labels),
tensorflow_model=False,
)

# purge bounding boxes with a low confidence score
aux = []
for b in boxes:
label = -1
for i in range(len(b.classes)):
if b.classes[i] > self.box_confidence_score:
label = i
if label >= 0:
aux.append(b)
boxes = aux
del aux

# if bounding boxes have been detected
dec_words = []
if len(boxes) > 0:
# create set of images of the detected license plates
lps = []
for b in boxes:
lp = image[b.ymin : b.ymax, b.xmin : b.xmax]
lps.append(lp)

# run batch inference
try:
prediction_groups = self.recognition_model_pipeline.recognize(lps)
except ValueError:
# exception can occur when the images are too small
prediction_groups = []

# process pipeline output
image_list = []
for img_predictions in prediction_groups:
boxes_per_image = []
for predictions in img_predictions:
boxes_per_image.append([predictions[0], predictions[1].tolist()])
image_list.append(boxes_per_image)

# reorder text within detected LPs based on horizontal position
dec_lps = preprocess_utils.reorder_recognized_words(image_list)
for dec_lp in dec_lps:
dec_words.append([word[0] for word in dec_lp])

# if there are no recognized LPs, then don't draw them
if len(dec_words) == 0:
dec_words = [[] for i in range(len(boxes))]

# draw predictions as overlays on the source image
draw_image = bbox_utils.draw_boxes(
image,
boxes,
overlay_text=dec_words,
labels=["LP"],
obj_thresh=self.box_confidence_score,
)

# image represented in bytes
byte_im = preprocess_utils.image_to_jpeg_bytes(draw_image)

# encode image
image_enc = base64.b64encode(byte_im).decode("utf-8")

# image with draw boxes overlayed
return image_enc
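
Before deploying, it can be handy to smoke-test the predictor locally. A hypothetical sketch, assuming the file above is saved as `predictor_lite.py` and `config.json` sits in the working directory (mirroring the `config` section of `cortex_lite.yaml`):

```python
# Hypothetical local smoke test for the PythonPredictor above.
import base64

from predictor_lite import PythonPredictor

predictor = PythonPredictor(
    {
        "yolov3": "s3://cortex-examples/tensorflow/license-plate-reader/yolov3_keras",
        "yolov3_model_config": "config.json",
    }
)

# the predictor returns the annotated image as a base64-encoded string
img_b64 = predictor.predict({"url": "https://i.imgur.com/r8xdI7P.png"})
with open("prediction.jpg", "wb") as f:
    f.write(base64.b64decode(img_b64))
```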
7 changes: 3 additions & 4 deletions examples/tensorflow/license-plate-reader/predictor_yolo.py
@@ -5,8 +5,7 @@
import numpy as np
import cv2
import pickle
from utils.utils import get_yolo_boxes
from utils.bbox import BoundBox
import utils.utils as utils


class TensorFlowPredictor:
@@ -26,9 +25,9 @@ def predict(self, payload):
image = cv2.imdecode(jpg_as_np, flags=cv2.IMREAD_COLOR)

# detect the bounding boxes
boxes = get_yolo_boxes(
boxes = utils.get_yolo_boxes(
self.client,
[image],
image,
self.net_h,
self.net_w,
self.anchors,
70 changes: 8 additions & 62 deletions examples/tensorflow/license-plate-reader/sample_inference.py
@@ -2,62 +2,8 @@

import click, cv2, requests, pickle, base64, json
import numpy as np
from utils.bbox import BoundBox, draw_boxes
from statistics import mean


def image_to_jpeg_nparray(image, quality=[int(cv2.IMWRITE_JPEG_QUALITY), 95]):
"""
Convert numpy image to jpeg numpy vector.
"""
is_success, im_buf_arr = cv2.imencode(".jpg", image, quality)
return im_buf_arr


def image_to_jpeg_bytes(image, quality=[int(cv2.IMWRITE_JPEG_QUALITY), 95]):
"""
Convert numpy image to bytes-encoded jpeg image.
"""
buf = image_to_jpeg_nparray(image, quality)
byte_im = buf.tobytes()
return byte_im


def get_url_image(url_image):
"""
Get numpy image from URL image.
"""
resp = requests.get(url_image, stream=True).raw
image = np.asarray(bytearray(resp.read()), dtype="uint8")
image = cv2.imdecode(image, cv2.IMREAD_COLOR)
return image


def reorder_recognized_words(detected_images):
"""
Reorder the detected words in each image based on the average horizontal position of each word.
Sorting them in ascending order.
"""

reordered_images = []
for detected_image in detected_images:

# computing the mean average position for each word
mean_horizontal_positions = []
for words in detected_image:
box = words[1]
y_positions = [point[0] for point in box]
mean_y_position = mean(y_positions)
mean_horizontal_positions.append(mean_y_position)
indexes = np.argsort(mean_horizontal_positions)

# and reordering them
reordered = []
for index, words in zip(indexes, detected_image):
reordered.append(detected_image[index])
reordered_images.append(reordered)

return reordered_images
import utils.bbox as bbox_utils
import utils.preprocess as preprocess_utils


@click.command(
@@ -81,8 +27,8 @@ def reorder_recognized_words(detected_images):
def main(img_url_src, yolov3_endpoint, crnn_endpoint, output):

# get the image in bytes representation
image = get_url_image(img_url_src)
image_bytes = image_to_jpeg_bytes(image)
image = preprocess_utils.get_url_image(img_url_src)
image_bytes = preprocess_utils.image_to_jpeg_bytes(image)

# encode image
image_enc = base64.b64encode(image_bytes).decode("utf-8")
@@ -97,7 +43,7 @@ def main(img_url_src, yolov3_endpoint, crnn_endpoint, output):
boxes_raw = resp.json()["boxes"]
boxes = []
for b in boxes_raw:
box = BoundBox(*b)
box = bbox_utils.BoundBox(*b)
boxes.append(box)

# purge bounding boxes with a low confidence score
@@ -119,7 +65,7 @@ def main(img_url_src, yolov3_endpoint, crnn_endpoint, output):
lps = []
for b in boxes:
lp = image[b.ymin : b.ymax, b.xmin : b.xmax]
jpeg = image_to_jpeg_nparray(lp)
jpeg = preprocess_utils.image_to_jpeg_nparray(lp)
lps.append(jpeg)

# encode the cropped license plates
@@ -134,15 +80,15 @@ def main(img_url_src, yolov3_endpoint, crnn_endpoint, output):

# parse the response
dec_lps = resp.json()["license-plates"]
dec_lps = reorder_recognized_words(dec_lps)
dec_lps = preprocess_utils.reorder_recognized_words(dec_lps)
for dec_lp in dec_lps:
dec_words.append([word[0] for word in dec_lp])

if len(dec_words) == 0:
dec_words = [[] for i in range(len(boxes))]

# draw predictions as overlays on the source image
draw_image = draw_boxes(
draw_image = bbox_utils.draw_boxes(
image, boxes, overlay_text=dec_words, labels=["LP"], obj_thresh=confidence_score
)

1 change: 1 addition & 0 deletions examples/tensorflow/license-plate-reader/utils/__init__.py
@@ -0,0 +1 @@
# WARNING: you are on the master branch, please refer to the examples on the branch that matches your `cortex version`