1616with open (os .path .join (data_dir , "list_jobs_0000.json" )) as f :
1717 list_jobs_0000 = f .read ()
1818
19- def test_predict ():
19+
20+ def test_ensemble_job_predict ():
2021 default_value = random .random ()
21- ensembler = tests .MyTestEnsembler (default_value )
22+ ensembler = tests .MyTestEnsemblerJob (default_value )
2223
2324 model_input = pandas .DataFrame (
2425 data = {
@@ -43,6 +44,98 @@ def test_predict():
4344
4445 assert_series_equal (expected , result )
4546
47+
48+ def test_ensemble_service_predict_with_enricher_response ():
49+ default_value = random .random ()
50+ ensembler = tests .MyTestEnsemblerService (default_value )
51+
52+ model_input = {
53+ "headers" : {
54+ "Host" : "example.com" ,
55+ "Referer" : "https://example.com/" ,
56+ },
57+ "body" : {
58+ # original request payload unmodified
59+ "request" : {
60+ "dummy" : "dummy"
61+ },
62+ "response" : {
63+ "route_responses" : [
64+ {
65+ "route" : "control" ,
66+ "data" : {
67+ "abc" : 1 ,
68+ },
69+ },
70+ {
71+ "route" : "xgboost-ordinal" ,
72+ "data" : {
73+ "def" : 2 ,
74+ }
75+ },
76+ ],
77+ "experiment" : {
78+ "configuration" : {
79+ "output" : "xyz" ,
80+ },
81+ },
82+ "enricher_response" : {
83+ "output" : "meow"
84+ }
85+ },
86+ },
87+ }
88+
89+ expected = {"output" : "meow" }
90+ result = ensembler .predict (context = None , model_input = model_input )
91+
92+ assert expected == result
93+
94+
95+ def test_ensemble_service_predict_without_enricher_response ():
96+ default_value = random .random ()
97+ ensembler = tests .MyTestEnsemblerService (default_value )
98+
99+ model_input = {
100+ "headers" : {
101+ "Host" : "example.com" ,
102+ "Referer" : "https://example.com/" ,
103+ },
104+ "body" : {
105+ # original request payload unmodified
106+ "request" : {
107+ "dummy" : "dummy"
108+ },
109+ "response" : {
110+ "route_responses" : [
111+ {
112+ "route" : "control" ,
113+ "data" : {
114+ "abc" : 1 ,
115+ },
116+ },
117+ {
118+ "route" : "xgboost-ordinal" ,
119+ "data" : {
120+ "def" : 2 ,
121+ }
122+ },
123+ ],
124+ "experiment" : {
125+ "configuration" : {
126+ "output" : "xyz" ,
127+ },
128+ },
129+ },
130+ },
131+ }
132+
133+ expected = None
134+ result = ensembler .predict (context = None , model_input = model_input )
135+
136+ assert expected == result
137+
138+
46139@pytest .mark .parametrize ("num_ensemblers" , [6 ])
47140def test_list_ensemblers (
48141 turing_api , project , generic_ensemblers , use_google_oauth , active_project_magic_mock
@@ -73,6 +166,7 @@ def test_list_ensemblers(
73166 for actual , expected in zip (actual , generic_ensemblers ):
74167 assert actual == turing .PyFuncEnsembler .from_open_api (expected )
75168
169+
76170@patch ("google.cloud.storage.Client" )
77171@patch ("requests.Session.request" )
78172@patch ("urllib3.PoolManager.request" )
@@ -103,15 +197,16 @@ def test_create_ensembler(
103197
104198 actual = turing .PyFuncEnsembler .create (
105199 name = pyfunc_ensembler .name ,
106- ensembler_instance = tests .MyTestEnsembler (0.01 ),
200+ ensembler_instance = tests .MyTestEnsemblerJob (0.01 ),
107201 conda_env = {
108202 "channels" : ["defaults" ],
109203 "dependencies" : ["python=3.9.0" , {"pip" : ["test-lib==0.0.1" ]}],
110204 },
111205 )
112206
113207 assert actual == turing .PyFuncEnsembler .from_open_api (pyfunc_ensembler )
114-
208+
209+
115210@pytest .mark .parametrize (("num_ensemblers" , "ensembler_name" ), [(3 , "updated" )])
116211@patch ("google.cloud.storage.Client" )
117212@patch ("requests.Session.request" )
@@ -177,7 +272,7 @@ def test_update_ensembler(
177272
178273 actual .update (
179274 name = pyfunc_ensembler .name ,
180- ensembler_instance = tests .MyTestEnsembler (0.06 ),
275+ ensembler_instance = tests .MyTestEnsemblerJob (0.06 ),
181276 conda_env = {
182277 "channels" : ["defaults" ],
183278 "dependencies" : ["python>=3.8.0" , {"pip" : ["test-lib==0.0.1" ]}],
@@ -190,7 +285,8 @@ def test_update_ensembler(
190285 )
191286
192287 assert actual == turing .PyFuncEnsembler .from_open_api (pyfunc_ensembler )
193-
288+
289+
194290@pytest .mark .parametrize (("num_ensemblers" , "ensembler_name" ), [(3 , "updated" )])
195291def test_update_ensembler_existing_router_version (
196292 turing_api ,
@@ -235,7 +331,7 @@ def test_update_ensembler_existing_router_version(
235331 with pytest .raises (ValueError ) as error :
236332 actual .update (
237333 name = pyfunc_ensembler .name ,
238- ensembler_instance = tests .MyTestEnsembler (0.06 ),
334+ ensembler_instance = tests .MyTestEnsemblerJob (0.06 ),
239335 conda_env = {
240336 "channels" : ["defaults" ],
241337 "dependencies" : ["python>=3.8.0" , {"pip" : ["test-lib==0.0.1" ]}],
@@ -251,7 +347,8 @@ def test_update_ensembler_existing_router_version(
251347 expected_error_message = "There is pending router version using this ensembler. Please wait for the router version to be deployed or undeploy it, before updating the ensembler."
252348 actual_error_message = str (error .value )
253349 assert expected_error_message == actual_error_message
254-
350+
351+
255352@pytest .mark .parametrize (("num_ensemblers" , "ensembler_name" ), [(3 , "updated" )])
256353def test_update_ensembler_existing_job (
257354 turing_api , project , generic_ensemblers , pyfunc_ensembler , use_google_oauth , active_project_magic_mock
@@ -297,7 +394,7 @@ def test_update_ensembler_existing_job(
297394 with pytest .raises (ValueError ) as error :
298395 actual .update (
299396 name = pyfunc_ensembler .name ,
300- ensembler_instance = tests .MyTestEnsembler (0.06 ),
397+ ensembler_instance = tests .MyTestEnsemblerJob (0.06 ),
301398 conda_env = {
302399 "channels" : ["defaults" ],
303400 "dependencies" : ["python>=3.8.0" , {"pip" : ["test-lib==0.0.1" ]}],
@@ -313,7 +410,8 @@ def test_update_ensembler_existing_job(
313410 expected_error_message = "There is pending ensembling job using this ensembler. Please wait for the ensembling job to be completed or terminate it, before updating the ensembler."
314411 actual_error_message = str (error .value )
315412 assert expected_error_message == actual_error_message
316-
413+
414+
317415@patch ("google.cloud.storage.Client" )
318416@patch ("requests.Session.request" )
319417@patch ("urllib3.PoolManager.request" )
0 commit comments