Commit a65cd51

add tfserving client
1 parent 1f9684d commit a65cd51

1 file changed: +215 −0

TFserving-benchmark-imagenet.py

Lines changed: 215 additions & 0 deletions
@@ -0,0 +1,215 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Created on Wed May 15 15:51:42 2019

@author: vinhngx
"""
from __future__ import print_function
import argparse
import numpy as np
import time
import pdb

from multiprocessing import Process, cpu_count

import cv2
import os
import tensorflow as tf
import logging

logging.getLogger("tensorflow").setLevel(logging.ERROR)

from grpc.beta import implementations
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2

parser = argparse.ArgumentParser(description='inception grpc client flags.')
parser.add_argument('--host', default='localhost', help='inception serving host')
parser.add_argument('--port', default='8500', help='inception serving port')
parser.add_argument('--image', default='/code/data/img.png', help='path to JPEG image file')
FLAGS = parser.parse_args()


def deserialize_image_record(record):
    # Parse a single ImageNet TFRecord into image bytes, label, bounding box and class text.
    feature_map = {
        'image/encoded': tf.FixedLenFeature([], tf.string, ''),
        'image/class/label': tf.FixedLenFeature([1], tf.int64, -1),
        'image/class/text': tf.FixedLenFeature([], tf.string, ''),
        'image/object/bbox/xmin': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/ymin': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/xmax': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/ymax': tf.VarLenFeature(dtype=tf.float32)
    }
    with tf.name_scope('deserialize_image_record'):
        obj = tf.parse_single_example(record, feature_map)
        imgdata = obj['image/encoded']
        label = tf.cast(obj['image/class/label'], tf.int32)
        bbox = tf.stack([obj['image/object/bbox/%s' % x].values
                         for x in ['ymin', 'xmin', 'ymax', 'xmax']])
        bbox = tf.transpose(tf.expand_dims(bbox, 0), [0, 2, 1])
        text = obj['image/class/text']
        return imgdata, label, bbox, text


VALIDATION_DATA_DIR = "/data"
BATCH_SIZE = 8


def get_files(data_dir, filename_pattern):
    if data_dir is None:
        return []
    files = tf.gfile.Glob(os.path.join(data_dir, filename_pattern))
    if not files:
        raise ValueError('Cannot find any files in {} with '
                         'pattern "{}"'.format(data_dir, filename_pattern))
    return files


calibration_files = get_files(VALIDATION_DATA_DIR, 'validation*')

print('There are %d calibration files. \n%s\n%s\n...' % (len(calibration_files), calibration_files[0], calibration_files[-1]))

import vgg_preprocessing


def preprocess(record):
    # Parse TFRecord
    imgdata, label, bbox, text = deserialize_image_record(record)
    label -= 1  # Change to 0-based (don't use background class)
    try:
        image = tf.image.decode_jpeg(imgdata, channels=3, fancy_upscaling=False, dct_method='INTEGER_FAST')
    except:
        image = tf.image.decode_png(imgdata, channels=3)

    image = vgg_preprocessing.preprocess_image(image, 224, 224, is_training=False)
    return image, label


dataset = tf.data.TFRecordDataset(calibration_files)
dataset = dataset.apply(tf.contrib.data.map_and_batch(map_func=preprocess, batch_size=BATCH_SIZE, num_parallel_calls=8))


def main():
    # create prediction service client stub
    channel = implementations.insecure_channel(FLAGS.host, int(FLAGS.port))
    stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)

    # create request
    request = predict_pb2.PredictRequest()
    request.model_spec.name = 'resnet'
    request.model_spec.signature_name = 'serving_default'

    start_time = time.time()
    with tf.Session(graph=tf.Graph()) as sess:
        # prepare dataset iterator
        iterator = dataset.make_one_shot_iterator()
        next_element = iterator.get_next()

        num_hits = 0
        num_predict = 0
        try:
            while True:
                image_data = sess.run(next_element)
                img = image_data[0]
                label = image_data[1].squeeze()

                # convert to tensor proto and make request
                # shape is in NHWC (num_samples x height x width x channels) format
                tensor = tf.contrib.util.make_tensor_proto(img, shape=list(img.shape))
                request.inputs['input'].CopyFrom(tensor)
                resp = stub.Predict(request, 30.0)  # timeout in seconds
                # print("Response", resp)

                prediction = tf.make_ndarray(resp.outputs['classes'])
                num_hits += np.sum(prediction == label)
                num_predict += len(prediction)
        except tf.errors.OutOfRangeError:
            pass

    print('Accuracy: %.2f%%' % (100 * num_hits / num_predict))
    print('Inference speed: %.2f samples/s' % (num_predict / (time.time() - start_time)))


def run_benchmark(filelist, id, perf_list):
    # Benchmark one shard of the validation files against the TF Serving endpoint
    # and record (num_hits, num_predict) for this worker in the shared perf_list dict.
    dataset = tf.data.TFRecordDataset(filelist)
    dataset = dataset.apply(tf.contrib.data.map_and_batch(map_func=preprocess, batch_size=BATCH_SIZE, num_parallel_calls=8))

    # create prediction service client stub
    channel = implementations.insecure_channel(FLAGS.host, int(FLAGS.port))
    stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)

    # create request
    request = predict_pb2.PredictRequest()
    request.model_spec.name = 'resnet'
    request.model_spec.signature_name = 'serving_default'

    with tf.Session(graph=tf.Graph()) as sess:
        # prepare dataset iterator
        iterator = dataset.make_one_shot_iterator()
        next_element = iterator.get_next()

        num_hits = 0
        num_predict = 0
        try:
            while True:
                image_data = sess.run(next_element)
                img = image_data[0]
                label = image_data[1].squeeze()

                # convert to tensor proto and make request
                # shape is in NHWC (num_samples x height x width x channels) format
                tensor = tf.contrib.util.make_tensor_proto(img, shape=list(img.shape))
                request.inputs['input'].CopyFrom(tensor)
                resp = stub.Predict(request, 30.0)  # timeout in seconds
                # print("Response", resp)

                prediction = tf.make_ndarray(resp.outputs['classes'])
                num_hits += np.sum(prediction == label)
                num_predict += len(prediction)
        except tf.errors.OutOfRangeError:
            pass
    print('Thread %d of %d done' % (id, len(perf_list)))
    perf_list[id] = (num_hits, num_predict)
    print("Thread %d performance: " % id, perf_list)


from multiprocessing.managers import BaseManager, DictProxy


def main_parallel():
    NUM_JOBS = 8
    print('Benchmarking with %d threads...' % NUM_JOBS)

    total = len(calibration_files)
    chunk_size = total // NUM_JOBS + 1

    # Shared dict so each worker process can report its (hits, predictions) tally.
    BaseManager.register('dict', dict, DictProxy)
    manager = BaseManager()
    manager.start()
    perf_list = manager.dict()

    processes = []
    start_time = time.time()

    # Split the validation TFRecord files evenly across NUM_JOBS worker processes.
    id = 0
    for i in range(0, total, chunk_size):
        print('Thread %d of %d start' % (id, len(perf_list)))
        proc = Process(
            target=run_benchmark,
            args=[
                calibration_files[i:i + chunk_size],
                id,
                perf_list
            ]
        )
        id += 1
        processes.append(proc)
    for proc in processes:
        proc.start()
    for proc in processes:
        proc.join()

    print("Thread performance: ", perf_list)
    num_hits = 0
    num_predict = 0
    for key, entry in perf_list.items():
        num_hits += entry[0]
        num_predict += entry[1]

    print('Total samples: %d' % num_predict)
    print('Accuracy: %.2f%%' % (100 * num_hits / num_predict))
    print('Inference speed: %.2f samples/s' % (num_predict / (time.time() - start_time)))


if __name__ == '__main__':
    # main()
    main_parallel()
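Note: the client above uses the deprecated grpc.beta API (implementations.insecure_channel and beta_create_PredictionService_stub). Below is a minimal sketch of the same Predict call with the current gRPC stubs; it assumes a tensorflow-serving-api release that ships prediction_service_pb2_grpc, and the helper name predict_batch is illustrative only, not part of the committed script.

# Sketch only: equivalent request path using the non-beta gRPC API.
# Assumes tensorflow-serving-api provides prediction_service_pb2_grpc;
# predict_batch is a hypothetical helper, not defined in the script above.
import grpc
import tensorflow as tf
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2_grpc

def predict_batch(host, port, batch, timeout=30.0):
    # One channel/stub per call keeps the sketch short; reuse them across batches in practice.
    channel = grpc.insecure_channel('%s:%s' % (host, port))
    stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)

    request = predict_pb2.PredictRequest()
    request.model_spec.name = 'resnet'
    request.model_spec.signature_name = 'serving_default'
    request.inputs['input'].CopyFrom(
        tf.make_tensor_proto(batch, shape=list(batch.shape)))

    response = stub.Predict(request, timeout)
    return tf.make_ndarray(response.outputs['classes'])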
