Skip to content

Commit

Permalink
Model container testing function to clipper admin (Vanilla python) (#394
Browse files Browse the repository at this point in the history
)

* Vanilla python model

* PR edits and integration test

* integration test

* deleting extra python files

* Addressed more PR comments: check for input_batch, batch input for integration test

* Ran formatting script

* fix formatting issues

* Fixing import error failing jenkins

* merge issue

* trying to fix import

* Removed import

* Registered app in integration test

* Removing reformatted files

* Fixed style errors

* Fixing integration test

* fixed url

* fixing integration test

* linked model to app

* Playing around with connected model for query

* trying to fix connection issues

* connection issues

* Fixed connection issue in integration test

* Uncommented tests

* retrigger

* retry

* retrigger jenkins

* reverted accidental changes
  • Loading branch information
rohsuresh authored and dcrankshaw committed Mar 13, 2018
1 parent 42a4680 commit bd57920
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 8 deletions.
79 changes: 79 additions & 0 deletions clipper_admin/clipper_admin/clipper_admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
import os
import tarfile
import six
from cloudpickle import CloudPickler
import pickle
import numpy as np

from .container_manager import CONTAINERLESS_MODEL_IMAGE
from .exceptions import ClipperException, UnconnectedException
Expand Down Expand Up @@ -1187,3 +1190,79 @@ def stop_all(self):
"""
self.cm.stop_all()
logger.info("Stopped all Clipper cluster and all model containers")

def test_predict_function(self, query, func, input_type):
"""Tests that the user's function has the correct signature and can be properly saved and loaded.
The function should take a dict request object like the query frontend expects JSON,
the predict function, and the input type for the model.
For example, the function can be called like: clipper_conn.test_predict_function({"input": [1.0, 2.0, 3.0]}, predict_func, "doubles")
Parameters
----------
query: JSON or list of dicts
Inputs to test the prediction function on.
func: function
Predict function to test.
input_type: str
The input_type to be associated with the registered app and deployed model.
One of "integers", "floats", "doubles", "bytes", or "strings".
"""
if not self.connected:
self.connect()
query_data = list(x for x in list(query.values()))
query_key = list(query.keys())

if query_key[0] == "input_batch":
query_data = query_data[0]

try:
flattened_data = [
item for sublist in query_data for item in sublist
]
except TypeError as e:
return "Invalid input type or JSON key"

numpy_data = None

if input_type == "bytes":
numpy_data = list(np.int8(x) for x in query_data)
for x in flattened_data:
if type(x) != bytes:
return "Invalid input type"

if input_type == "integers":
numpy_data = list(np.int32(x) for x in query_data)
for x in flattened_data:
if type(x) != int:
return "Invalid input type"

if input_type == "floats" or input_type == "doubles":
if input_type == "floats":
numpy_data = list(np.float32(x) for x in query_data)
else:
numpy_data = list(np.float64(x) for x in query_data)
for x in flattened_data:
if type(x) != float:
return "Invalid input type"

if input_type == "string":
numpy_data = list(np.str_(x) for x in query_data)
for x in flattened_data:
if type(x) != str:
return "Invalid input type"

s = six.StringIO()
c = CloudPickler(s, 2)
c.dump(func)
serialized_func = s.getvalue()
reloaded_func = pickle.loads(serialized_func)

try:
assert reloaded_func
except AssertionError as e:
logger.error("Function does not properly serialize and reload")
return "Function does not properly serialize and reload"

return reloaded_func(numpy_data)
67 changes: 59 additions & 8 deletions integration-tests/clipper_admin_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import os
import json
import time
import numpy as np
import requests
import tempfile
import shutil
Expand Down Expand Up @@ -343,6 +344,59 @@ def predict_func(inputs):
})
self.assertEqual(len(containers), 1)

def test_test_predict_function(self):
def predict_func(xs):
return [sum(x) for x in xs]

self.clipper_conn.register_application(
name="hello-world",
input_type="doubles",
default_output="-1.0",
slo_micros=100000)

deploy_python_closure(
self.clipper_conn,
name="sum-model",
version=1,
input_type="doubles",
func=predict_func)
self.clipper_conn.link_model_to_app(
app_name="hello-world", model_name="sum-model")
time.sleep(60)

addr = self.clipper_conn.get_query_addr()
url = "http://{addr}/hello-world/predict".format(
addr=addr, app='hello-world')

headers = {"Content-type": "application/json"}
test_input = [1.1, 2.2, 3.3]
pred = requests.post(
url, headers=headers, data=json.dumps({
"input": test_input
})).json()
test_predict_result = self.clipper_conn.test_predict_function(
query={"input": test_input},
func=predict_func,
input_type="doubles")
self.assertEqual([pred['output']],
test_predict_result) # tests single input

test_batch_input = [[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]]
batch_pred = requests.post(
url,
headers=headers,
data=json.dumps({
"input_batch": test_batch_input
})).json()
test_batch_predict_result = self.clipper_conn.test_predict_function(
query={"input_batch": test_batch_input},
func=predict_func,
input_type="doubles")
batch_predictions = batch_pred['batch_predictions']
batch_pred_outputs = [batch['output'] for batch in batch_predictions]
self.assertEqual(batch_pred_outputs,
test_batch_predict_result) # tests batch input


class ClipperManagerTestCaseLong(unittest.TestCase):
@classmethod
Expand Down Expand Up @@ -520,23 +574,20 @@ def predict_func(inputs):


SHORT_TEST_ORDERING = [
'test_register_model_correct',
'test_register_application_correct',
'test_register_model_correct', 'test_register_application_correct',
'test_link_not_registered_model_to_app_fails',
'test_get_model_links_when_none_exist_returns_empty_list',
'test_link_registered_model_to_app_succeeds',
'get_app_info_for_registered_app_returns_info_dictionary',
'get_app_info_for_nonexistent_app_returns_none',
'test_set_num_replicas_for_external_model_fails',
'test_model_version_sets_correctly',
'test_get_logs_creates_log_files',
'test_model_version_sets_correctly', 'test_get_logs_creates_log_files',
'test_inspect_instance_returns_json_dict',
'test_model_deploys_successfully',
'test_set_num_replicas_for_deployed_model_succeeds',
'test_remove_inactive_containers_succeeds',
'test_stop_models',
'test_python_closure_deploys_successfully',
'test_register_py_endpoint',
'test_remove_inactive_containers_succeeds', 'test_stop_models',
'test_python_closure_deploys_successfully', 'test_register_py_endpoint',
'test_test_predict_function'
]

LONG_TEST_ORDERING = [
Expand Down

0 comments on commit bd57920

Please sign in to comment.