Skip to content

Commit 2b9c680

Browse files
committed
Update iris example
1 parent dea174b commit 2b9c680

File tree

11 files changed

+38
-27
lines changed

11 files changed

+38
-27
lines changed

docs/deployments/request-handlers.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def post_inference(prediction, metadata):
4444
```python
4545
import numpy as np
4646

47-
iris_labels = ["Iris-setosa", "Iris-versicolor", "Iris-virginica"]
47+
labels = ["Iris-setosa", "Iris-versicolor", "Iris-virginica"]
4848

4949
def pre_inference(sample, metadata):
5050
# Convert a dictionary of features to a flattened in list in the order expected by the model
@@ -63,7 +63,7 @@ def post_inference(prediction, metadata):
6363
probabilites = prediction[0][0]
6464
predicted_class_id = int(np.argmax(probabilites))
6565
return {
66-
"class_label": iris_labels[predicted_class_id],
66+
"class_label": labels[predicted_class_id],
6767
"class_index": predicted_class_id,
6868
"probabilities": probabilites,
6969
}

examples/iris/cortex.yaml

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,18 @@
44
- kind: api
55
name: tensorflow
66
model: s3://cortex-examples/iris/tensorflow.zip
7+
request_handler: handlers/tensorflow.py
78

89
- kind: api
910
name: pytorch
1011
model: s3://cortex-examples/iris/pytorch.onnx
1112
request_handler: handlers/pytorch.py
1213

14+
- kind: api
15+
name: keras
16+
model: s3://cortex-examples/iris/keras.onnx
17+
request_handler: handlers/keras.py
18+
1319
- kind: api
1420
name: xgboost
1521
model: s3://cortex-examples/iris/xgboost.onnx
@@ -19,8 +25,3 @@
1925
name: sklearn
2026
model: s3://cortex-examples/iris/sklearn.onnx
2127
request_handler: handlers/sklearn.py
22-
23-
- kind: api
24-
name: keras
25-
model: s3://cortex-examples/iris/keras.onnx
26-
request_handler: handlers/keras.py

examples/iris/handlers/keras.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import numpy as np
22

3-
iris_labels = ["Iris-setosa", "Iris-versicolor", "Iris-virginica"]
3+
labels = ["Iris-setosa", "Iris-versicolor", "Iris-virginica"]
44

55

66
def pre_inference(sample, metadata):
@@ -18,7 +18,7 @@ def post_inference(prediction, metadata):
1818
probabilites = prediction[0][0]
1919
predicted_class_id = int(np.argmax(probabilites))
2020
return {
21-
"class_label": iris_labels[predicted_class_id],
21+
"class_label": labels[predicted_class_id],
2222
"class_index": predicted_class_id,
2323
"probabilities": probabilites,
2424
}

examples/iris/handlers/pytorch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import numpy as np
22

3-
iris_labels = ["Iris-setosa", "Iris-versicolor", "Iris-virginica"]
3+
labels = ["Iris-setosa", "Iris-versicolor", "Iris-virginica"]
44

55

66
def pre_inference(sample, metadata):
@@ -17,7 +17,7 @@ def pre_inference(sample, metadata):
1717
def post_inference(prediction, metadata):
1818
predicted_class_id = int(np.argmax(prediction[0][0]))
1919
return {
20-
"class_label": iris_labels[predicted_class_id],
20+
"class_label": labels[predicted_class_id],
2121
"class_index": predicted_class_id,
2222
"probabilites": prediction[0][0],
2323
}

examples/iris/handlers/sklearn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
logger = get_logger()
88
s3 = boto3.client("s3")
99

10-
iris_labels = ["Iris-setosa", "Iris-versicolor", "Iris-virginica"]
10+
labels = ["Iris-setosa", "Iris-versicolor", "Iris-virginica"]
1111

1212
scalars_obj = s3.get_object(Bucket="cortex-examples", Key="iris/scalars.json")
1313
scalars = json.loads(scalars_obj["Body"].read().decode("utf-8"))
@@ -28,4 +28,4 @@ def pre_inference(sample, metadata):
2828

2929
def post_inference(prediction, metadata):
3030
predicted_class_id = prediction[0][0]
31-
return {"class_label": iris_labels[predicted_class_id], "class_index": predicted_class_id}
31+
return {"class_label": labels[predicted_class_id], "class_index": predicted_class_id}

examples/iris/handlers/tensorflow.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
labels = ["Iris-setosa", "Iris-versicolor", "Iris-virginica"]
2+
3+
4+
def post_inference(prediction, metadata):
5+
label_index = int(prediction["response"]["class_ids"][0])
6+
return labels[label_index]

examples/iris/handlers/xgboost.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import numpy as np
22

3-
iris_labels = ["Iris-setosa", "Iris-versicolor", "Iris-virginica"]
3+
labels = ["Iris-setosa", "Iris-versicolor", "Iris-virginica"]
44

55

66
def pre_inference(sample, metadata):
@@ -16,4 +16,4 @@ def pre_inference(sample, metadata):
1616

1717
def post_inference(prediction, metadata):
1818
predicted_class_id = prediction[0][0]
19-
return {"class_label": iris_labels[predicted_class_id], "class_index": predicted_class_id}
19+
return {"class_label": labels[predicted_class_id], "class_index": predicted_class_id}

examples/iris/models/keras_model.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
scores = model.evaluate(X_test, y_test)
2121
print("\n%s: %.2f%%" % (model.metrics_names[1], scores[1] * 100))
2222

23-
# Convert to ONNX model format
2423
onnx_model = keras2onnx.convert_keras(model)
2524
with open("keras.onnx", "wb") as f:
2625
f.write(onnx_model.SerializeToString())

examples/iris/models/pytorch_model.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@ def forward(self, X):
5656

5757
print("prediction accuracy {}".format(accuracy_score(test_y.data, predict_y.data)))
5858

59-
# Convert to ONNX model format
6059
placeholder = torch.randn(1, 4)
6160
torch.onnx.export(
6261
model, placeholder, "pytorch.onnx", input_names=["input"], output_names=["species"]

examples/iris/models/sklearn_model.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,13 @@
1414
print("mean:", scaler.mean_)
1515
print("standard deviation:", np.sqrt(scaler.var_))
1616

17-
1817
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.8, random_state=42)
1918

2019
logreg_model = LogisticRegression(solver="lbfgs", multi_class="multinomial")
2120
logreg_model.fit(X_train, y_train)
2221

2322
print("Test data accuracy: {:.2f}".format(logreg_model.score(X_test, y_test)))
2423

25-
# Convert to ONNX model format
2624
onnx_model = convert_sklearn(logreg_model, initial_types=[("input", FloatTensorType([1, 4]))])
2725
with open("sklearn.onnx", "wb") as f:
2826
f.write(onnx_model.SerializeToString())

examples/iris/models/xgboost_model.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,18 @@
1-
iris_labels = ["Iris-setosa", "Iris-versicolor", "Iris-virginica"]
1+
import xgboost as xgb
2+
from sklearn.datasets import load_iris
3+
from sklearn.model_selection import train_test_split
4+
from onnxmltools.convert import convert_xgboost
5+
from onnxconverter_common.data_types import FloatTensorType
26

7+
iris = load_iris()
8+
X, y = iris.data, iris.target
9+
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.8, random_state=42)
310

4-
def post_inference(prediction, metadata):
5-
predicted_class_id = prediction[0][0]
6-
return {
7-
"class_label": iris_labels[predicted_class_id],
8-
"class_index": predicted_class_id,
9-
"probabilities": prediction[1][0],
10-
}
11+
xgb_model = xgb.XGBClassifier()
12+
xgb_model = xgb_model.fit(X_train, y_train)
13+
14+
print("Test data accuracy of the xgb classifier is {:.2f}".format(xgb_model.score(X_test, y_test)))
15+
16+
onnx_model = convert_xgboost(xgb_model, initial_types=[("input", FloatTensorType([1, 4]))])
17+
with open("xgboost.onnx", "wb") as f:
18+
f.write(onnx_model.SerializeToString())

0 commit comments

Comments
 (0)