Skip to content

Commit 42b677c

Browse files
authored
Added "Remove Model" button & existing button alignment, rename
1 parent dd7605f commit 42b677c

File tree

1 file changed

+62
-31
lines changed

1 file changed

+62
-31
lines changed

Horizontal View - app.py

Lines changed: 62 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,24 @@
44

55
st.set_page_config(page_title="LLM Comparison", layout="wide")
66

7+
st.markdown("""
8+
<style>
9+
.stButton button {
10+
padding: 0px 5px !important;
11+
min-width: unset !important;
12+
font-size: 10px !important;
13+
height: 25px !important;
14+
line-height: 1 !important;
15+
margin-top: 28px !important;
16+
}
17+
div[data-testid="stSelectbox"] > div {
18+
margin-right: 0px !important;
19+
}
20+
</style>
21+
""", unsafe_allow_html=True)
22+
723
st.title("Running LLMs in parallel")
824

9-
# Get available models
1025
@st.cache_data
1126
def get_models():
1227
try:
@@ -27,42 +42,58 @@ def get_models():
2742
if "model_count" not in st.session_state:
2843
st.session_state.model_count = 2
2944
if "selected_models" not in st.session_state:
30-
st.session_state.selected_models = ["", ""] # initial 2 models
45+
st.session_state.selected_models = ["", ""]
3146

32-
# Add model dynamically
33-
if st.button("Add new model"):
34-
st.session_state.model_count += 1
35-
st.session_state.selected_models.append("")
47+
def remove_model(index):
48+
if st.session_state.model_count > 1:
49+
st.session_state.model_count -= 1
50+
st.session_state.selected_models.pop(index)
3651

37-
# Display model selectors
3852
for i in range(st.session_state.model_count):
39-
st.session_state.selected_models[i] = st.selectbox(
40-
f"Model {i+1}",
41-
models_available,
42-
index=0 if i >= len(st.session_state.selected_models) or not st.session_state.selected_models[i] else models_available.index(st.session_state.selected_models[i]),
43-
key=f"model_select_{i}"
44-
)
45-
46-
# Background colors
47-
box_colors = ["#e6f0ff", "#ffe6e6", "#e6ffe6", "#fff0e6", "#f0e6ff"]
48-
49-
# Run button
50-
if st.button("Generate") and prompt.strip():
51-
model_inputs = st.session_state.selected_models
53+
col1, col2 = st.columns([0.97, 0.02])
54+
with col1:
55+
st.session_state.selected_models[i] = st.selectbox(
56+
f"Model {i+1}",
57+
models_available,
58+
index=0 if i >= len(st.session_state.selected_models) or not st.session_state.selected_models[i] else models_available.index(st.session_state.selected_models[i]),
59+
key=f"model_select_{i}"
60+
)
61+
with col2:
62+
st.button("✖", key=f"remove_model_{i}", on_click=remove_model, args=(i,))
63+
64+
selected_models_filtered = [model for model in st.session_state.selected_models if model]
65+
66+
_, _, spacer, col_add, col_run = st.columns([0.5, 0.2, 0.1, 0.1, 0.1])
67+
with col_add:
68+
if st.button("Add new model"):
69+
st.session_state.model_count += 1
70+
st.session_state.selected_models.append("")
71+
st.rerun()
72+
with col_run:
73+
run_clicked = st.button("Run Models", type="primary")
74+
75+
if run_clicked and prompt and selected_models_filtered:
5276
responses = []
5377

54-
for model in model_inputs:
55-
start = time.time()
78+
response_placeholders = [st.empty() for _ in selected_models_filtered]
79+
80+
for i, model in enumerate(selected_models_filtered):
5681
try:
57-
response = requests.post(
58-
"http://localhost:11434/api/generate",
59-
json={"model": model, "prompt": prompt, "stream": False},
60-
).json()
61-
62-
duration = round(time.time() - start, 2)
63-
content = response.get("response", "").strip()
64-
eval_count = response.get("eval_count", len(content.split()))
65-
eval_rate = response.get("eval_rate", round(eval_count / duration, 2))
82+
with st.spinner(f"Running {model}..."):
83+
start_time = time.time()
84+
res = requests.post(
85+
"http://localhost:11434/api/generate",
86+
json={"model": model, "prompt": prompt, "stream": False},
87+
headers={"Content-Type": "application/json"},
88+
)
89+
res.raise_for_status()
90+
response_data = res.json()
91+
end_time = time.time()
92+
93+
duration = round(end_time - start_time, 2)
94+
content = response_data.get("response", "")
95+
eval_count = response_data.get("eval_count", len(content.split()))
96+
eval_rate = response_data.get("eval_rate", round(eval_count / duration, 2))
6697

6798
responses.append({
6899
"model": model,

0 commit comments

Comments
 (0)