|
| 1 | +"""A client that performs inferences on a ResNet model using the REST API. |
| 2 | +
|
| 3 | +The client downloads a test image of a cat, queries the server over the REST API |
| 4 | +with the test image repeatedly and measures how long it takes to respond. |
| 5 | +
|
| 6 | +The client expects a TensorFlow Serving ModelServer running a ResNet SavedModel |
| 7 | +from: |
| 8 | +
|
| 9 | +https://github.com/tensorflow/models/tree/master/official/resnet#pre-trained-model |
| 10 | +
|
| 11 | +The SavedModel must be one that can take JPEG images as inputs. |
| 12 | +
|
| 13 | +Typical usage example: |
| 14 | +
|
| 15 | + python client.py <http://host:port> |
| 16 | +""" |
| 17 | + |
| 18 | +import sys |
| 19 | +import json |
| 20 | + |
| 21 | +import io |
| 22 | +from tensorflow.keras.preprocessing import image |
| 23 | +from tensorflow.keras.applications import resnet50 |
| 24 | +from PIL import Image |
| 25 | +import requests |
| 26 | +import numpy as np |
| 27 | + |
# IMAGE_URL is the location of the test image we send to the server.
| 29 | +IMAGE_URL = "https://tensorflow.org/images/blogs/serving/cat.jpg" |
| 30 | + |
| 31 | + |
def main():
    """Query a TF Serving ResNet endpoint with a test image and report latency.

    Expects the model server address as the sole command-line argument, e.g.
    ``python client.py http://host:port``. Downloads the ImageNet labels and a
    test cat image, sends a few warm-up requests, then measures the average
    round-trip latency over 10 prediction requests and prints the top class.

    Raises:
        SystemExit: if the server address argument is missing.
        requests.HTTPError: if any HTTP request returns an error status.
    """
    # Parse the single required argument: the model server base address.
    if len(sys.argv) != 2:
        print("usage: python client.py <http://host:port>")
        sys.exit(1)
    address = sys.argv[1]
    server_url = f"{address}/v1/models/resnet50_neuron:predict"

    # Download the human-readable class labels. The first line of the file is
    # the "background" class, which the model's output indices skip, so drop it.
    # Fix: check the status so a failed download doesn't yield garbage labels,
    # and bound the request with a timeout so the client cannot hang forever.
    labels_response = requests.get(
        "https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt",
        timeout=30,
    )
    labels_response.raise_for_status()
    labels = labels_response.text.split("\n")[1:]

    # Download the test image and resize it to the network's 224x224 input size.
    response = requests.get(IMAGE_URL, timeout=30)
    response.raise_for_status()
    img = Image.open(io.BytesIO(response.content))
    img = img.resize((224, 224))

    # Preprocess: HWC float array -> batch of one -> ResNet50 normalization.
    # Fix: the original np.repeat(..., 1, axis=0) was a no-op that only made
    # an extra full copy of the array; expand_dims alone produces the batch.
    img_arr = image.img_to_array(img)
    batch = resnet50.preprocess_input(np.expand_dims(img_arr, axis=0))
    request_payload = {"signature_name": "serving_default", "inputs": batch.tolist()}

    # Send a few requests to warm up the model so one-time compilation or
    # caching cost does not skew the latency measurement.
    # Fix: use requests' json= parameter (serializes and sets the content-type
    # header itself) and add a timeout.
    for _ in range(3):
        response = requests.post(server_url, json=request_payload, timeout=60)
        response.raise_for_status()

    # Send the measured requests and accumulate server-reported elapsed time.
    total_time = 0.0
    num_requests = 10
    for _ in range(num_requests):
        response = requests.post(server_url, json=request_payload, timeout=60)
        response.raise_for_status()
        total_time += response.elapsed.total_seconds()

    # Classify from the last response; argmax picks the top-scoring class index.
    label_idx = int(np.argmax(response.json()["outputs"][0]))
    prediction = labels[label_idx]

    print(
        "Prediction class: {}, avg latency: {} ms".format(
            prediction, (total_time * 1000) / num_requests
        )
    )
| 86 | + |
| 87 | + |
# Run the client only when executed as a script (not when imported).
if __name__ == "__main__":
    main()
0 commit comments