Skip to content

Commit 9ab2abf

Browse files
committed
Add Streamlit app for predicting Iris flower species
1 parent a64b6f9 commit 9ab2abf

File tree

1 file changed

+71
-0
lines changed

1 file changed

+71
-0
lines changed

app/pages/4_Classification.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
"""
2+
Create a Streamlit app that takes in features about the iris flower and returns the species of the flower.
3+
"""
4+
import streamlit as st
5+
import matplotlib.pyplot as plt
6+
from sklearn.decomposition import PCA
7+
from sklearn.datasets import load_iris
8+
import requests
9+
10+
def load_data():
11+
iris = load_iris()
12+
X = iris.data
13+
y = iris.target
14+
return X, y
15+
16+
17+
def app():
18+
19+
# Add title and description
20+
st.title("Predicting Iris Flower Species")
21+
22+
# Sidebar with input fields
23+
sepal_length = st.slider("Sepal Length", 0.0, 10.0, 5.0)
24+
sepal_width = st.slider("Sepal Width", 0.0, 10.0, 5.0)
25+
petal_length = st.slider("Petal Length", 0.0, 10.0, 5.0)
26+
petal_width = st.slider("Petal Width", 0.0, 10.0, 5.0)
27+
28+
# Create input data (dictionary)
29+
input_data = {
30+
"sepal_length": sepal_length,
31+
"sepal_width": sepal_width,
32+
"petal_length": petal_length,
33+
"petal_width": petal_width
34+
}
35+
36+
# interact with FastAPI endpoint
37+
response = requests.post("http://127.0.0.1:8000/predict", json=input_data)
38+
39+
target_names = ['Setosa', 'Versicolor', 'Virginica']
40+
if st.button("Predict"):
41+
# make a post request to the FastAPI endpoint
42+
prediction = int(response.json()[0])
43+
pred_prob = response.json()[1]
44+
45+
# print results
46+
st.write(f"Species predicted: {target_names[prediction]} with {pred_prob:.2f}% confidence")
47+
48+
# We will plot how the train data clusters in 2D space and then see how the test data fits in it.
49+
# First apply PCA to reduce the dimensionality of the data to 2D
50+
X, y = load_data()
51+
pca = PCA(n_components=2)
52+
X_pca = pca.fit_transform(X)
53+
54+
data_test = [[sepal_length, sepal_width, petal_length, petal_width]]
55+
data_test = pca.transform(data_test)
56+
57+
# add class labels as legend
58+
fig, ax = plt.subplots()
59+
for i in range(3):
60+
ax.scatter(X_pca[y==i, 0], X_pca[y==i, 1], label=target_names[i])
61+
62+
# Use "test_data" as label for the test data
63+
ax.scatter(data_test[:, 0], data_test[:, 1], c='red', marker='x', s=100, label='test_data')
64+
ax.set_xlabel('First Principal Component')
65+
ax.set_ylabel('Second Principal Component')
66+
ax.set_title('Train data')
67+
ax.legend()
68+
st.pyplot(fig)
69+
70+
if __name__ == "__main__":
71+
app()

0 commit comments

Comments
 (0)