Commit 34477a8

Clean up code for review
1 parent f729b22 commit 34477a8

File tree: 8 files changed (+125 −78 lines)

assets/demo.css

Lines changed: 1 addition & 0 deletions

@@ -460,6 +460,7 @@ progress::-webkit-progress-bar {
   transition: filter 0.1s linear, opacity 0.1s linear;
 }
 
+/* Overwriting bootstrap tooltip styling */
 .tooltip {
   --bs-tooltip-arrow-height: 1.5rem;
   --bs-tooltip-arrow-width: 1.5rem;

demo_callbacks.py

Lines changed: 73 additions & 45 deletions

@@ -52,6 +52,10 @@
     generate_model_fig,
 )
 
+from src.utils.global_vars import (
+    STEP_4_FILE, STEP_5_FILE, LATENT_ENCODED_FILE, STEP_2_FILE, LATENT_QPU_FILE
+)
+
 
 @dash.callback(
     Output({"type": "to-collapse-class", "index": MATCH}, "className"),
@@ -116,8 +120,8 @@ def update_model_diagram_imgs(
     progress: int,
     fig_qpu: go.Figure,
     fig_encoded: go.Figure,
-    latent_mapping: list
-) -> tuple[str, str, str, str, go.Figure, go.Figure, list]:
+    latent_mapping: list,
+) -> tuple[str, str, str, go.Figure, go.Figure, list]:
     """Force refresh images to get around Dash caching. Updates image src with a incrementing
     query string.
 
@@ -138,10 +142,10 @@ def update_model_diagram_imgs(
     fig_qpu = go.Figure(fig_qpu)
     fig_encoded = go.Figure(fig_encoded)
 
-    with open("static/model_diagram/latent_qpu.json", "r") as f:
+    with open(LATENT_QPU_FILE, "r") as f:
         latent_qpu = json.load(f)
 
-    with open("static/model_diagram/latent_encoded.json", "r") as f:
+    with open(LATENT_ENCODED_FILE, "r") as f:
         latent_encoded = json.load(f)
 
     color_mapping_qpu = [GRAPH_COLORS[int(latent_qpu[i] > 0)] for i in latent_mapping]
@@ -152,9 +156,9 @@ def update_model_diagram_imgs(
 
     if GENERATE_NEW_MODEL_DIAGRAM:
         return (
-            f"static/model_diagram/step_2_encode.png?interval={progress}",
-            f"static/model_diagram/step_4_decode.png?interval={progress}",
-            f"static/model_diagram/step_5_output.png?interval={progress}",
+            f"{STEP_2_FILE}?interval={progress}",
+            f"{STEP_4_FILE}?interval={progress}",
+            f"{STEP_5_FILE}?interval={progress}",
             fig_qpu,
             fig_encoded,
             generate_latent_vector(latent_encoded[:5], latent_encoded[-1]),
@@ -163,6 +167,20 @@ def update_model_diagram_imgs(
     raise PreventUpdate
 
 
+class CheckQpuAndUpdateModelReturn(NamedTuple):
+    """Return type for the ``check_qpu_and_update_model`` callback function."""
+
+    popup_classname: str = "display-none"
+    generate_button_disabled: bool = False
+    model_details: dict = dash.no_update
+    fig_qpu_graph: go.Figure = dash.no_update
+    fig_encoded_graph: go.Figure = dash.no_update
+    latent_diagram_size: int = dash.no_update
+    latent_mapping: list[int] = dash.no_update
+    step_2_img: str = dash.no_update
+    step_4_img: str = dash.no_update
+    step_5_img: str = dash.no_update
+
 @dash.callback(
     Output("popup", "className"),
     Output("generate-button", "disabled"),
@@ -188,7 +206,7 @@ def check_qpu_and_update_model(
     n_latents: int,
     setting_tabs_value: str,
     example_image: list,
-) -> tuple[str, bool, dict, go.Figure, go.Figure, int, list[int]]:
+) -> CheckQpuAndUpdateModelReturn:
     """Checks whether user has access to QPU associated with model and updates the model details
     when model changes.
 
@@ -200,61 +218,71 @@ def check_qpu_and_update_model(
         example_image: The example image to show all the steps for in the UI.
 
     Returns:
-        popup-classname: The class name to hide the popup.
-        generate-button-disabled: Whether to disable or enable the Generate button.
-        model-details-children: The model details to display.
-        fig-qpu-graph: The QPU graph figure.
-        fig-encoded-graph: The not QPU graph figure.
-        latent-diagram-size: The dimension of the latent space.
-        latent-mapping: The mapping of the nodes to latent space indices.
-        step-2-encode-img: The src url for the encode image.
-        step-4-decode-img: The src url for the decode image.
-        step-5-output-img: The src url for the output image.
+        CheckQpuAndUpdateModelReturn named tuple:
+            popup_classname: The class name to hide the popup.
+            generate_button_disabled: Whether to disable or enable the Generate button.
+            model_details: The model details to display.
+            fig_qpu_graph: The QPU graph figure.
+            fig_encoded_graph: The not QPU graph figure.
+            latent_diagram_size: The dimension of the latent space.
+            latent_mapping: The mapping of the nodes to latent space indices.
+            step_2_img: The src url for the encode image.
+            step_4_img: The src url for the decode image.
+            step_5_img: The src url for the output image.
     """
-    model_data = None
-    model_details = None
+    switched_to_generate_tab = ctx.triggered_id == "setting-tabs" and setting_tabs_value == "generate-tab"
 
-    if not ctx.triggered_id or ctx.triggered_id == "model-file-name" or (ctx.triggered_id == "setting-tabs" and setting_tabs_value == "generate-tab"):
+    # If first load, or a new model is chosen, or the settings tab is changed to "generate"
+    if not ctx.triggered_id or ctx.triggered_id == "model-file-name" or switched_to_generate_tab:
         with open(MODEL_PATH / model_file_name / "parameters.json") as file:
             model_data = json.load(file)
 
         model_details = generate_model_data(model_data)
 
+        # If model_data has a QPU that is no longer available, show warning popup
+        if model_data["qpu"] and not (len(SOLVERS) and model_data["qpu"] in SOLVERS):
+            return CheckQpuAndUpdateModelReturn(
+                popup_classname="",
+                generate_button_disabled=True,
+                model_details=model_details,
+            )
+
         # Create model instance to generate model diagram images
        model = ModelWrapper(qpu=model_data["qpu"], n_latents=model_data["n_latents"])
         model.load(file_path=MODEL_PATH / model_file_name)
+
         # Dash converts the tensor to a list of floats, convert back to tensor
         example_image = torch.tensor(example_image, dtype=torch.float32)
         example_image = example_image.unsqueeze(0)
-
         generate_model_diagram(model, example_image)
+
+        fig_qpu, fig_encoded, latent_mapping = generate_model_fig(
+            model_data["qpu"],
+            model_data["n_latents"],
+            model_data["random_seed"],
+        )
+
         force_refresh = random.randint(1, 9999999)
 
-        if model_data["qpu"] and not (len(SOLVERS) and model_data["qpu"] in SOLVERS):
-            return (
-                "", True, model_details, dash.no_update,
-                f"static/model_diagram/step_2_encode.png?force_refresh={force_refresh}",
-                f"static/model_diagram/step_4_decode.png?force_refresh={force_refresh}",
-                f"static/model_diagram/step_5_output.png?force_refresh={force_refresh}",
-            )
+        return CheckQpuAndUpdateModelReturn(
+            model_details=model_details,
+            fig_qpu_graph=fig_qpu,
+            fig_encoded_graph=fig_encoded,
+            latent_diagram_size=model_data["n_latents"],
+            latent_mapping=latent_mapping,
+            step_2_img=f"{STEP_2_FILE}?force_refresh={force_refresh}",
+            step_4_img=f"{STEP_4_FILE}?force_refresh={force_refresh}",
+            step_5_img=f"{STEP_5_FILE}?force_refresh={force_refresh}",
+        )
 
-    fig_qpu, fig_encoded, latent_mapping = generate_model_fig(
-        model_data["qpu"] if model_data else qpu,
-        model_data["n_latents"] if model_data else n_latents,
-        model_data["random_seed"] if model_data else 4,
-    )
+    # No model data, proceed with defaults
+    fig_qpu, fig_encoded, latent_mapping = generate_model_fig(qpu, n_latents, 4)
 
-    return (
-        "display-none",
-        False,
-        model_details if model_details else dash.no_update,
-        fig_qpu,
-        fig_encoded,
-        n_latents,
-        latent_mapping,
-        dash.no_update,
-        dash.no_update,
-        dash.no_update,
+    return CheckQpuAndUpdateModelReturn(
+        fig_qpu_graph=fig_qpu,
+        fig_encoded_graph=fig_encoded,
+        latent_diagram_size=n_latents,
+        latent_mapping=latent_mapping,
     )
 
 

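The main structural change in this file is CheckQpuAndUpdateModelReturn: every field defaults either to a concrete value or to dash.no_update, so each return branch names only the outputs it actually changes while still matching the callback's Output order positionally. A minimal, self-contained sketch of the same pattern follows; the app, component ids, and field names are hypothetical, not the demo's.

# Minimal sketch of the "NamedTuple of dash.no_update defaults" pattern used above.
# The app, ids, and field names are hypothetical; only the technique mirrors the commit.
from typing import NamedTuple

import dash
from dash import Dash, Input, Output, html


class ToggleReturn(NamedTuple):
    """Hypothetical return type; one field per Output, in Output order."""

    popup_classname: str = "display-none"   # concrete default, like the real callback
    button_disabled: bool = False
    status_text: str = dash.no_update       # left untouched unless a branch sets it


app = Dash(__name__)
app.layout = html.Div(
    [
        html.Div("QPU unavailable", id="popup", className="display-none"),
        html.Button("Generate", id="generate-button"),
        html.Div(id="status"),
    ]
)


@app.callback(
    Output("popup", "className"),
    Output("generate-button", "disabled"),
    Output("status", "children"),
    Input("generate-button", "n_clicks"),
    prevent_initial_call=True,
)
def on_generate(n_clicks: int) -> ToggleReturn:
    # A NamedTuple is a tuple, so Dash unpacks it positionally into the three Outputs.
    if n_clicks % 2:
        # Only the popup and button change; "status" keeps its no_update default.
        return ToggleReturn(popup_classname="", button_disabled=True)
    return ToggleReturn(status_text=f"Clicked {n_clicks} times")


if __name__ == "__main__":
    app.run(debug=True)

Compared with the old positional tuples, an early-return branch (such as the QPU-unavailable popup) can no longer silently misalign values with Outputs.
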
demo_interface.py

Lines changed: 6 additions & 5 deletions

@@ -34,6 +34,7 @@
     THUMBNAIL,
 )
 from src.utils.callback_helpers import get_example_image
+from src.utils.global_vars import LATENT_ENCODED_FILE, STEP_1_FILE, STEP_2_FILE, STEP_4_FILE, STEP_5_FILE_DEFAULT
 
 # Initialize available QPUs
 try:
@@ -48,7 +49,7 @@
 
 # Initialize the latent diagram with either the available file or random +/- 1s
 try:
-    with open("static/model_diagram/latent_encoded.json", "r") as f:
+    with open(LATENT_ENCODED_FILE, "r") as f:
         latent_qpu = json.load(f)
 
     LATENT_DIAGRAM_START = latent_qpu[:5]
@@ -564,13 +565,13 @@ def create_interface():
     html.Div(
         [
             html.Img(
-                src="static/model_diagram/step_1_input.png",
+                src=STEP_1_FILE,
                 id="step-1-input-img",
             ),
             html.Div([
                 html.Div(className="forward-arrow"),
                 html.Img(
-                    src="static/model_diagram/step_2_encode.png",
+                    src=STEP_2_FILE,
                     id="step-2-encode-img",
                 ),
             ], className="graph-model-itermediate-step"),
@@ -595,9 +596,9 @@ def create_interface():
             ),
             html.Div([
                 html.Div(className="forward-arrow"),
-                html.Img(src="static/model_diagram/step_4_decode.png", id="step-4-decode-img"),
+                html.Img(src=STEP_4_FILE, id="step-4-decode-img"),
             ], className="graph-model-itermediate-step"),
-            html.Img(src="static/model_diagram/step_5_output_default.png", id="step-5-output-img"),
+            html.Img(src=STEP_5_FILE_DEFAULT, id="step-5-output-img"),
         ],
         className="graph-model-wrapper"
     ),

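Both demo_callbacks.py and demo_interface.py (and src/model_wrapper.py below) now take their static asset paths from src.utils.global_vars, a module that is not included in this diff. Based on the literal strings the constants replace, a plausible sketch of that module might look like this; every value here is inferred, not confirmed by the commit.

# Hypothetical reconstruction of src/utils/global_vars.py -- the module itself is not in
# this commit. Each value is inferred from the hardcoded path its constant replaces above.
STEP_1_FILE = "static/model_diagram/step_1_input.png"
STEP_2_FILE = "static/model_diagram/step_2_encode.png"
STEP_4_FILE = "static/model_diagram/step_4_decode.png"
STEP_5_FILE = "static/model_diagram/step_5_output.png"
STEP_5_FILE_DEFAULT = "static/model_diagram/step_5_output_default.png"
LATENT_QPU_FILE = "static/model_diagram/latent_qpu.json"
LATENT_ENCODED_FILE = "static/model_diagram/latent_encoded.json"

Centralizing the paths means a future move of the static/model_diagram directory touches one module instead of every call site.
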
src/model_wrapper.py

Lines changed: 2 additions & 1 deletion

@@ -17,6 +17,7 @@
 from pathlib import Path
 from typing import Optional
 
+from src.utils.global_vars import LATENT_QPU_FILE
 import numpy as np
 import plotly.express as px
 import torch
@@ -368,7 +369,7 @@ def generate_output(self, sharpen: bool = False, save_to_file: str = "") -> go.F
             sample_params=self.sampler_kwargs,
         )
 
-        with open("static/model_diagram/latent_qpu.json", "w") as f:
+        with open(LATENT_QPU_FILE, "w") as f:
             json.dump(samples[0].tolist(), f)
 
         images = self._dvae.decoder(samples.unsqueeze(1)).squeeze(1).clip(0.0, 1.0).detach().cpu()

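generate_output() writes the first sampled latent vector to LATENT_QPU_FILE, and update_model_diagram_imgs() in demo_callbacks.py reads it back to color the latent nodes. A stripped-down sketch of that hand-off follows; the colors, path, and sample tensor are made up for illustration, and the real callback additionally reorders the values through latent_mapping.

# Illustrative only: the JSON hand-off between the model side (writer) and the callback
# side (reader). GRAPH_COLORS, the path, and the sample tensor are placeholders.
import json

import torch

LATENT_QPU_FILE = "latent_qpu.json"      # assumed path; see the global_vars sketch above
GRAPH_COLORS = ["#cccccc", "#2a7de1"]    # placeholder palette, not the demo's constants

samples = torch.tensor([[-1.0, 1.0, 1.0, -1.0, 1.0]])   # fake sampler output, shape (1, n_latents)

# Writer side (as in generate_output): persist the first sample as a plain JSON list.
with open(LATENT_QPU_FILE, "w") as f:
    json.dump(samples[0].tolist(), f)

# Reader side: reload the list and map the sign of each latent to a color.
with open(LATENT_QPU_FILE, "r") as f:
    latent_qpu = json.load(f)

color_mapping_qpu = [GRAPH_COLORS[int(value > 0)] for value in latent_qpu]
print(color_mapping_qpu)   # one color per latent node
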
src/training_parameters.yaml

Lines changed: 1 addition & 1 deletion

@@ -2,7 +2,7 @@ ANNEALING_TIME: 1
 NUM_READS: 256
 
 IMAGE_SIZE: 32
-DATASET_SIZE: 60000 # full MNIST is 60_000
+DATASET_SIZE: null # full MNIST is 60_000
 BATCH_SIZE: 128
 RANDOM_SEED: 775321899904
 LOSS_FUNCTION: mmd

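DATASET_SIZE switches from an explicit cap of 60000 to null. The training code that consumes this file is not part of the diff; the sketch below only illustrates the likely intent, with YAML null loading as Python None and acting as a "use the full dataset" sentinel.

# Hypothetical illustration; the loader that reads training_parameters.yaml is not shown
# in this commit. YAML "null" parses to None, while the old value 60000 was an explicit cap.
import yaml

with open("src/training_parameters.yaml") as f:
    params = yaml.safe_load(f)

dataset_size = params["DATASET_SIZE"]                            # None after this change, 60000 before
num_images = 60_000 if dataset_size is None else dataset_size    # 60_000 = full MNIST train split
print(f"Training on {num_images} MNIST images")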