Skip to content

Commit 9a53d9f

Browse files
fix(sdk): Fix ensembler ensemble method (caraml-dev#413)
* Fix turing ensembler sdk * Add sdk tests * Fix broken tests
1 parent 86a4db5 commit 9a53d9f

File tree

3 files changed

+134
-16
lines changed

3 files changed

+134
-16
lines changed

sdk/tests/__init__.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from datetime import date, datetime
22
from dateutil.tz import tzutc
3+
from typing import Optional, Union, Any, Dict
34
from turing import generated as client
45
import turing.ensembler
56

@@ -17,9 +18,8 @@ def utc_date(date_str: str):
1718
return datetime.strptime(date_str, "%Y-%m-%dT%H:%M:%S.%fZ").replace(tzinfo=tzutc())
1819

1920

20-
class MyTestEnsembler(turing.ensembler.PyFunc):
21+
class MyTestEnsemblerJob(turing.ensembler.PyFunc):
2122
import pandas
22-
from typing import Any, Optional
2323

2424
def __init__(self, default: float):
2525
self._default = default
@@ -29,11 +29,26 @@ def initialize(self, artifacts: dict):
2929

3030
def ensemble(
3131
self,
32-
input: pandas.Series,
33-
predictions: pandas.Series,
34-
treatment_config: Optional[dict],
32+
input: Union[pandas.Series, Dict[str, Any]],
33+
predictions: Union[pandas.Series, Dict[str, Any]],
34+
**kwargs,
3535
) -> Any:
3636
if input["treatment"] in predictions:
3737
return predictions[input["treatment"]]
3838
else:
3939
return self._default
40+
41+
42+
class MyTestEnsemblerService(turing.ensembler.PyFunc):
43+
def __init__(self, default: float):
44+
self._default = default
45+
46+
def initialize(self, artifacts: dict):
47+
pass
48+
49+
def ensemble(
50+
self,
51+
enricher_response,
52+
**kwargs,
53+
) -> Any:
54+
return enricher_response

sdk/tests/ensembler_test.py

Lines changed: 108 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@
1616
with open(os.path.join(data_dir, "list_jobs_0000.json")) as f:
1717
list_jobs_0000 = f.read()
1818

19-
def test_predict():
19+
20+
def test_ensemble_job_predict():
2021
default_value = random.random()
21-
ensembler = tests.MyTestEnsembler(default_value)
22+
ensembler = tests.MyTestEnsemblerJob(default_value)
2223

2324
model_input = pandas.DataFrame(
2425
data={
@@ -43,6 +44,98 @@ def test_predict():
4344

4445
assert_series_equal(expected, result)
4546

47+
48+
def test_ensemble_service_predict_with_enricher_response():
49+
default_value = random.random()
50+
ensembler = tests.MyTestEnsemblerService(default_value)
51+
52+
model_input = {
53+
"headers": {
54+
"Host": "example.com",
55+
"Referer": "https://example.com/",
56+
},
57+
"body": {
58+
# original request payload unmodified
59+
"request": {
60+
"dummy": "dummy"
61+
},
62+
"response": {
63+
"route_responses": [
64+
{
65+
"route": "control",
66+
"data": {
67+
"abc": 1,
68+
},
69+
},
70+
{
71+
"route": "xgboost-ordinal",
72+
"data": {
73+
"def": 2,
74+
}
75+
},
76+
],
77+
"experiment": {
78+
"configuration": {
79+
"output": "xyz",
80+
},
81+
},
82+
"enricher_response": {
83+
"output": "meow"
84+
}
85+
},
86+
},
87+
}
88+
89+
expected = {"output": "meow"}
90+
result = ensembler.predict(context=None, model_input=model_input)
91+
92+
assert expected == result
93+
94+
95+
def test_ensemble_service_predict_without_enricher_response():
96+
default_value = random.random()
97+
ensembler = tests.MyTestEnsemblerService(default_value)
98+
99+
model_input = {
100+
"headers": {
101+
"Host": "example.com",
102+
"Referer": "https://example.com/",
103+
},
104+
"body": {
105+
# original request payload unmodified
106+
"request": {
107+
"dummy": "dummy"
108+
},
109+
"response": {
110+
"route_responses": [
111+
{
112+
"route": "control",
113+
"data": {
114+
"abc": 1,
115+
},
116+
},
117+
{
118+
"route": "xgboost-ordinal",
119+
"data": {
120+
"def": 2,
121+
}
122+
},
123+
],
124+
"experiment": {
125+
"configuration": {
126+
"output": "xyz",
127+
},
128+
},
129+
},
130+
},
131+
}
132+
133+
expected = None
134+
result = ensembler.predict(context=None, model_input=model_input)
135+
136+
assert expected == result
137+
138+
46139
@pytest.mark.parametrize("num_ensemblers", [6])
47140
def test_list_ensemblers(
48141
turing_api, project, generic_ensemblers, use_google_oauth, active_project_magic_mock
@@ -73,6 +166,7 @@ def test_list_ensemblers(
73166
for actual, expected in zip(actual, generic_ensemblers):
74167
assert actual == turing.PyFuncEnsembler.from_open_api(expected)
75168

169+
76170
@patch("google.cloud.storage.Client")
77171
@patch("requests.Session.request")
78172
@patch("urllib3.PoolManager.request")
@@ -103,15 +197,16 @@ def test_create_ensembler(
103197

104198
actual = turing.PyFuncEnsembler.create(
105199
name=pyfunc_ensembler.name,
106-
ensembler_instance=tests.MyTestEnsembler(0.01),
200+
ensembler_instance=tests.MyTestEnsemblerJob(0.01),
107201
conda_env={
108202
"channels": ["defaults"],
109203
"dependencies": ["python=3.9.0", {"pip": ["test-lib==0.0.1"]}],
110204
},
111205
)
112206

113207
assert actual == turing.PyFuncEnsembler.from_open_api(pyfunc_ensembler)
114-
208+
209+
115210
@pytest.mark.parametrize(("num_ensemblers", "ensembler_name"), [(3, "updated")])
116211
@patch("google.cloud.storage.Client")
117212
@patch("requests.Session.request")
@@ -177,7 +272,7 @@ def test_update_ensembler(
177272

178273
actual.update(
179274
name=pyfunc_ensembler.name,
180-
ensembler_instance=tests.MyTestEnsembler(0.06),
275+
ensembler_instance=tests.MyTestEnsemblerJob(0.06),
181276
conda_env={
182277
"channels": ["defaults"],
183278
"dependencies": ["python>=3.8.0", {"pip": ["test-lib==0.0.1"]}],
@@ -190,7 +285,8 @@ def test_update_ensembler(
190285
)
191286

192287
assert actual == turing.PyFuncEnsembler.from_open_api(pyfunc_ensembler)
193-
288+
289+
194290
@pytest.mark.parametrize(("num_ensemblers", "ensembler_name"), [(3, "updated")])
195291
def test_update_ensembler_existing_router_version(
196292
turing_api,
@@ -235,7 +331,7 @@ def test_update_ensembler_existing_router_version(
235331
with pytest.raises(ValueError) as error:
236332
actual.update(
237333
name=pyfunc_ensembler.name,
238-
ensembler_instance=tests.MyTestEnsembler(0.06),
334+
ensembler_instance=tests.MyTestEnsemblerJob(0.06),
239335
conda_env={
240336
"channels": ["defaults"],
241337
"dependencies": ["python>=3.8.0", {"pip": ["test-lib==0.0.1"]}],
@@ -251,7 +347,8 @@ def test_update_ensembler_existing_router_version(
251347
expected_error_message = "There is pending router version using this ensembler. Please wait for the router version to be deployed or undeploy it, before updating the ensembler."
252348
actual_error_message = str(error.value)
253349
assert expected_error_message == actual_error_message
254-
350+
351+
255352
@pytest.mark.parametrize(("num_ensemblers", "ensembler_name"), [(3, "updated")])
256353
def test_update_ensembler_existing_job(
257354
turing_api, project, generic_ensemblers, pyfunc_ensembler, use_google_oauth, active_project_magic_mock
@@ -297,7 +394,7 @@ def test_update_ensembler_existing_job(
297394
with pytest.raises(ValueError) as error:
298395
actual.update(
299396
name=pyfunc_ensembler.name,
300-
ensembler_instance=tests.MyTestEnsembler(0.06),
397+
ensembler_instance=tests.MyTestEnsemblerJob(0.06),
301398
conda_env={
302399
"channels": ["defaults"],
303400
"dependencies": ["python>=3.8.0", {"pip": ["test-lib==0.0.1"]}],
@@ -313,7 +410,8 @@ def test_update_ensembler_existing_job(
313410
expected_error_message = "There is pending ensembling job using this ensembler. Please wait for the ensembling job to be completed or terminate it, before updating the ensembler."
314411
actual_error_message = str(error.value)
315412
assert expected_error_message == actual_error_message
316-
413+
414+
317415
@patch("google.cloud.storage.Client")
318416
@patch("requests.Session.request")
319417
@patch("urllib3.PoolManager.request")

sdk/turing/ensembler.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,12 +113,17 @@ def _ensemble_request(self, model_input: Dict[str, Any]) -> Any:
113113
# Deletes route from the dictionary as it is a duplicate of the key
114114
del routes_to_response[prediction["route"]]["route"]
115115

116+
# This is for older Turing routers which do not pass the enricher_response to the ensembler
117+
enricher_response = None
118+
if "enricher_response" in request_body["response"]:
119+
enricher_response = request_body["response"]["enricher_response"]
120+
116121
try:
117122
return self.ensemble(
118123
input=request_body["request"],
119124
predictions=routes_to_response,
120125
treatment_config=request_body["response"]["experiment"],
121-
enricher_response=request_body["response"]["enricher_response"],
126+
enricher_response=enricher_response,
122127
headers=model_input["headers"],
123128
)
124129
except TypeError as e:

0 commit comments

Comments
 (0)