forked from svpino/ml.school
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_inference.py
207 lines (153 loc) · 6.29 KB
/
test_inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
import os
import sqlite3
import tempfile
from pathlib import Path
from unittest.mock import Mock
import numpy as np
import pandas as pd
import pytest
from pipelines.inference import Model
@pytest.fixture
def mock_keras_model(monkeypatch):
    """Patch ``keras.saving.load_model`` so it hands back a stub model.

    The stub's ``predict`` always returns one row of class scores,
    ``[[0.6, 0.3, 0.1]]``. Any code that loads a Keras model during the test
    receives this object.

    Returns:
        The ``Mock`` standing in for the loaded Keras model.
    """
    fake_scores = np.array([[0.6, 0.3, 0.1]])
    keras_model = Mock()
    keras_model.predict = Mock(return_value=fake_scores)

    # The replacement loader ignores its single argument (the model path).
    monkeypatch.setattr("keras.saving.load_model", lambda _: keras_model)
    return keras_model
@pytest.fixture
def mock_transformers(monkeypatch):
    """Patch ``joblib.load`` to return stub feature and target transformers.

    ``joblib.load("features_transformer")`` yields the features stub; any
    other path yields the target stub. The target stub's ``species``
    sub-transformer exposes the three class labels through ``categories_``.

    Returns:
        A ``(features_transformer, target_transformer)`` tuple of mocks.
    """
    species = Mock()
    species.categories_ = [["Adelie", "Chinstrap", "Gentoo"]]

    target = Mock()
    target.named_transformers_ = {"species": species}

    features = Mock()
    features.transform = Mock()

    def fake_joblib_load(path):
        # Only the features artifact is special-cased; everything else
        # resolves to the target transformer.
        if path == "features_transformer":
            return features
        return target

    monkeypatch.setattr("joblib.load", fake_joblib_load)
    return features, target
@pytest.fixture
def model(mock_keras_model, mock_transformers, tmp_path):
    """Return a ``Model`` wired to the mocked Keras model and transformers.

    The SQLite capture database is placed in pytest's managed, per-test
    ``tmp_path`` directory. The previous version used an unmanaged
    ``tempfile.mkdtemp()`` directory that was never removed. The model is
    constructed with ``data_capture=False``.
    """
    data_collection_uri = tmp_path / "database.db"
    model = Model(data_collection_uri=data_collection_uri, data_capture=False)

    # Artifact "paths" are plain strings. The patched loaders (keras and
    # joblib) map them onto the mocks; joblib keys on "features_transformer".
    mock_context = Mock()
    mock_context.artifacts = {
        "model": "model",
        "features_transformer": "features_transformer",
        "target_transformer": "target_transformer",
    }
    model.load_context(mock_context)

    # Identity checks: load_context must hold exactly the patched objects.
    features_transformer, target_transformer = mock_transformers
    assert model.model is mock_keras_model
    assert model.features_transformer is features_transformer
    assert model.target_transformer is target_transformer
    return model
def fetch_data(model):
    """Return one captured row from the model's SQLite database.

    Args:
        model: Object exposing ``data_collection_uri``, the path to the SQLite
            file that data capture writes to.

    Returns:
        The first ``(island, prediction, confidence)`` row returned by the
        query on the ``data`` table, or ``None`` if the table is empty.

    Raises:
        sqlite3.OperationalError: If the ``data`` table does not exist.
    """
    connection = sqlite3.connect(model.data_collection_uri)
    try:
        cursor = connection.execute(
            "SELECT island, prediction, confidence FROM data;",
        )
        return cursor.fetchone()
    finally:
        # Close even when the query fails (e.g. the table was never created),
        # so a failing test does not leak an open connection.
        connection.close()
def test_process_input(model):
    """``process_input`` delegates to the features transformer."""
    transformer = model.features_transformer
    transformer.transform = Mock(return_value=np.array([[0.1, 0.2]]))
    frame = pd.DataFrame({"island": ["Torgersen"]})

    transformed = model.process_input(frame)

    # The transformer must receive the exact input frame, exactly once.
    transformer.transform.assert_called_once_with(frame)
    # The transformer's result is what the method hands back.
    expected = np.array([[0.1, 0.2]])
    assert np.array_equal(transformed, expected)
def test_process_input_return_none_on_exception(model):
    """A transformer failure makes ``process_input`` return ``None``."""
    failing_transform = Mock(side_effect=Exception("Invalid input"))
    model.features_transformer.transform = failing_transform
    frame = pd.DataFrame({"island": ["Torgersen"]})

    outcome = model.process_input(frame)

    # The transformer is still invoked with the input before it raises.
    failing_transform.assert_called_once_with(frame)
    # The exception is absorbed and surfaced as a None result.
    assert outcome is None
def test_process_output(model):
    """Each score row maps to its top species label and that label's score."""
    scores = np.array(
        [
            [0.6, 0.3, 0.1],
            [0.2, 0.7, 0.1],
        ],
    )
    expected = [
        {"prediction": "Adelie", "confidence": 0.6},
        {"prediction": "Chinstrap", "confidence": 0.7},
    ]

    assert model.process_output(scores) == expected
def test_process_output_return_empty_list_on_none(model):
    """A ``None`` model output is turned into an empty result list."""
    processed = model.process_output(None)
    assert processed == []
def test_predict_return_empty_list_on_invalid_input(model, monkeypatch):
    """``predict`` returns ``[]`` when input processing yields ``None``."""
    monkeypatch.setattr(model, "process_input", Mock(return_value=None))
    request = [{"island": "Torgersen", "culmen_length_mm": 39.1}]

    predictions = model.predict(context=None, model_input=request)

    assert predictions == []
def test_predict_return_empty_list_on_invalid_prediction(model, monkeypatch):
    """``predict`` returns ``[]`` when the underlying model returns ``None``."""
    processed_features = np.array([[0.1, 0.2, 0.3]])
    monkeypatch.setattr(
        model,
        "process_input",
        Mock(return_value=processed_features),
    )
    # Simulate the wrapped model producing no prediction at all.
    model.model.predict = Mock(return_value=None)
    request = [{"island": "Torgersen", "culmen_length_mm": 39.1}]

    predictions = model.predict(context=None, model_input=request)

    assert predictions == []
def test_predict(model, monkeypatch):
    """``predict`` chains process_input -> model.predict -> process_output."""
    fake_process_input = Mock(return_value=np.array([[0.1, 0.2, 0.3]]))
    fake_process_output = Mock(
        return_value=[{"prediction": "Adelie", "confidence": 0.6}],
    )
    model.model.predict = Mock(return_value=np.array([[0.6, 0.3, 0.1]]))
    monkeypatch.setattr(model, "process_input", fake_process_input)
    monkeypatch.setattr(model, "process_output", fake_process_output)
    request = [{"island": "Torgersen", "culmen_length_mm": 39.1}]

    predictions = model.predict(context=None, model_input=request)

    assert predictions == [{"prediction": "Adelie", "confidence": 0.6}]
    # Every stage of the pipeline runs exactly once.
    for stage in (fake_process_input, fake_process_output, model.model.predict):
        stage.assert_called_once()
@pytest.mark.parametrize(
    "default_data_capture,request_data_capture,database_exists",
    [
        # The per-request flag decides whether anything is written,
        # regardless of the model-level default.
        (False, False, False),
        (True, False, False),
        (False, True, True),
        (True, True, True),
    ],
)
def test_data_capture(
    model,
    default_data_capture,
    request_data_capture,
    database_exists,
):
    """The capture database exists only when the request enables capture."""
    model.data_capture = default_data_capture
    request_params = {"data_capture": request_data_capture}

    model.predict(
        context=None,
        model_input=[{"island": "Torgersen"}],
        params=request_params,
    )

    database_path = Path(model.data_collection_uri)
    assert database_path.exists() == database_exists
def test_capture_stores_data_in_database(model):
    """A captured request persists its input and the resulting prediction."""
    request = [{"island": "Torgersen"}]
    model.predict(context=None, model_input=request, params={"data_capture": True})

    stored_row = fetch_data(model)

    assert stored_row == ("Torgersen", "Adelie", 0.6)
def test_capture_on_invalid_output(model, monkeypatch):
    """The input is still captured when output processing yields ``None``."""
    monkeypatch.setattr(model, "process_output", Mock(return_value=None))
    request = [{"island": "Torgersen"}]

    model.predict(context=None, model_input=request, params={"data_capture": True})
    stored_row = fetch_data(model)

    # process_output returned None, so there is no prediction or confidence
    # to record; only the input column is populated.
    assert stored_row == ("Torgersen", None, None)