Skip to content

Commit 728bb98

Browse files
committed
Add example of using scalars
1 parent bc0fdab commit 728bb98

File tree

2 files changed

+28
-6
lines changed

2 files changed

+28
-6
lines changed

examples/iris/handlers/sklearn.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,28 @@
1+
from cortex.lib.log import get_logger
12
import numpy as np
3+
import boto3
4+
import json
5+
6+
logger = get_logger()
7+
s3 = boto3.client("s3")
28

39
iris_labels = ["Iris-setosa", "Iris-versicolor", "Iris-virginica"]
410

11+
scalars_obj = s3.get_object(Bucket="cortex-examples", Key="iris/scalars.json")
12+
scalars = json.loads(scalars_obj["Body"].read().decode("utf-8"))
13+
logger.info("downloaded scalars: {}".format(scalars))
14+
515

616
def pre_inference(sample, metadata):
7-
return [
8-
sample["sepal_length"],
9-
sample["sepal_width"],
10-
sample["petal_length"],
11-
sample["petal_width"],
12-
]
17+
x = np.array(
18+
[
19+
sample["sepal_length"],
20+
sample["sepal_width"],
21+
sample["petal_length"],
22+
sample["petal_width"],
23+
]
24+
)
25+
return ((x - scalars["mean"]) / scalars["stddev"]).astype(np.float32)
1326

1427

1528
def post_inference(prediction, metadata):

examples/iris/models/sklearn_model.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,20 @@
1+
import numpy as np
12
from sklearn.datasets import load_iris
23
from sklearn.model_selection import train_test_split
34
from sklearn.linear_model import LogisticRegression
5+
from sklearn.preprocessing import StandardScaler
46
from onnxmltools import convert_sklearn
57
from onnxconverter_common.data_types import FloatTensorType
68

79
iris = load_iris()
810
X, y = iris.data, iris.target
11+
12+
scaler = StandardScaler()
13+
X = scaler.fit_transform(X)
14+
print("mean:", scaler.mean_)
15+
print("standard deviation:", np.sqrt(scaler.var_))
16+
17+
918
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.8, random_state=42)
1019

1120
logreg_model = LogisticRegression(solver="lbfgs", multi_class="multinomial")

0 commit comments

Comments
 (0)