Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 47 additions & 14 deletions examples/fashion-mnist-parallel-coords-6d.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# Based on the Fashion MNIST Parallel Coordinates notebook by the marimo team:
# https://github.com/marimo-team/gallery-examples/blob/main/notebooks/wigglystuff/fashion-mnist-parallel-coords.py
#
# Extended with anywidget-vector 3D scatter view and 6D dimension mapping.
#
# /// script
# requires-python = ">=3.12"
# dependencies = [
Expand All @@ -14,7 +19,7 @@

import marimo

__generated_with = "0.20.2"
__generated_with = "0.19.7"
app = marimo.App(width="full")


Expand All @@ -36,7 +41,7 @@ def _():
@app.cell(hide_code=True)
def _(mo):
mo.md(r"""
# Fashion MNIST: Parallel Coordinates + 3D Vector View
# Fashion MNIST: Parallel Coordinates + 6D Vector View

Brush the parallel coordinates axes to filter. The 3D scatter updates in real time.
Use the dropdowns to map PCA dimensions to visual channels.
Expand Down Expand Up @@ -122,9 +127,9 @@ def _(mo, n_components_slider):
x_dim = mo.ui.dropdown(options=pcs, value="PC1", label="X")
y_dim = mo.ui.dropdown(options=pcs, value="PC2", label="Y")
z_dim = mo.ui.dropdown(options=pcs, value="PC3", label="Z")
size_dim = mo.ui.dropdown(options=["none", *pcs], value="none", label="Size")
size_dim = mo.ui.dropdown(options=["none", *pcs], value="PC4", label="Size")
color_dim = mo.ui.dropdown(options=["label", *pcs], value="label", label="Color")
shape_dim = mo.ui.dropdown(options=["label", "none"], value="none", label="Shape")
shape_dim = mo.ui.dropdown(options=["label", "none", *pcs], value="PC5", label="Shape")
mo.hstack([x_dim, y_dim, z_dim, size_dim, color_dim, shape_dim], gap=0.5)
return color_dim, shape_dim, size_dim, x_dim, y_dim, z_dim

Expand All @@ -140,6 +145,7 @@ def _(
label_names,
labels,
mo,
np,
shape_dim,
size_dim,
x_dim,
Expand All @@ -152,8 +158,23 @@ def _pc(name):
_xi, _yi, _zi = _pc(x_dim.value), _pc(y_dim.value), _pc(z_dim.value)
_color_pc = None if color_dim.value == "label" else _pc(color_dim.value)
_size_pc = None if size_dim.value == "none" else _pc(size_dim.value)
_shape_mode = shape_dim.value # "label", "none", or a PC name

# Shape mapping: label-based or PC-based (binned into 6 shape buckets)
_unique_labels = sorted(set(label_names.values()))
_shape_map = {_n: SHAPE_NAMES[_j % len(SHAPE_NAMES)] for _j, _n in enumerate(_unique_labels)}
_shape_bins = None
if _shape_mode not in ("label", "none"):
_shape_col = components[:, _pc(_shape_mode)]
_quantiles = np.percentile(_shape_col, np.linspace(0, 100, len(SHAPE_NAMES) + 1))
_shape_bins = list(zip(_quantiles[:-1], _quantiles[1:], SHAPE_NAMES, strict=False))
_shape_map = {_s: _s for _s in SHAPE_NAMES}

def _get_shape_bin(_val):
for _lo, _hi, _s in _shape_bins:
if _val <= _hi:
return _s
return SHAPE_NAMES[-1]

vs_points = []
for _i in range(len(components)):
Expand All @@ -171,8 +192,10 @@ def _pc(name):
_p["color_val"] = float(components[_i, _color_pc])
if _size_pc is not None:
_p["size_val"] = float(components[_i, _size_pc])
if shape_dim.value == "label":
if _shape_mode == "label":
_p["shape_cat"] = _name
elif _shape_bins is not None:
_p["shape_cat"] = _get_shape_bin(components[_i, _pc(_shape_mode)])
vs_points.append(_p)

_vs_kwargs = {
Expand All @@ -183,25 +206,23 @@ def _pc(name):
_vs_kwargs["color_scale"] = "viridis"
if _size_pc is not None:
_vs_kwargs["size_field"] = "size_val"
_vs_kwargs["size_range"] = [0.01, 0.06]
if shape_dim.value == "label":
_vs_kwargs["size_range"] = [0.01, 0.03]
if _shape_mode != "none":
_vs_kwargs["shape_field"] = "shape_cat"
_vs_kwargs["shape_map"] = _shape_map

vs_widget = VectorSpace(
points=vs_points,
width=1200,
width=1600,
height=500,
dark_mode=False,
background="#fafafa",
show_toolbar=True,
show_settings=True,
show_properties=False,
**_vs_kwargs,
)
vs = mo.ui.anywidget(vs_widget)
vs
return vs, vs_points, vs_widget
return vs_points, vs_widget


@app.cell
Expand All @@ -225,17 +246,29 @@ def _(LABEL_COLORS, color_dim, mo, vs_points, vs_widget, widget):


@app.cell
def _(idx, images, label_names, labels, np, plt, widget):
def _(LABEL_COLORS, idx, images, label_names, labels, np, plt, vs_widget, widget):
_filtered = widget.widget.filtered_indices
_sample_idx = np.array(_filtered[:10]) if len(_filtered) >= 10 else np.array(_filtered)

_fig, _axes = plt.subplots(1, len(_sample_idx), figsize=(2 * len(_sample_idx), 2))
# Which points are selected in the 3D view?
_selected_ids = set(vs_widget.selected_points or [])

_fig, _axes = plt.subplots(1, len(_sample_idx), figsize=(2 * len(_sample_idx), 2.4))
if len(_sample_idx) == 1:
_axes = [_axes]
for _ax, _si in zip(_axes, _sample_idx, strict=False):
_name = label_names[labels[idx[_si]]]
_ax.imshow(images[idx[_si]].reshape(28, 28), cmap="gray")
_ax.set_title(label_names[labels[idx[_si]]], fontsize=9)
_ax.set_title(_name, fontsize=9)
_ax.axis("off")
# Highlight if selected in 3D view
if f"p_{_si}" in _selected_ids:
_color = LABEL_COLORS.get(_name, "#0880ea")
for _spine in _ax.spines.values():
_spine.set_visible(True)
_spine.set_color(_color)
_spine.set_linewidth(3)
_ax.set_title(_name, fontsize=9, fontweight="bold", color=_color)
plt.tight_layout()
_fig
return
Expand Down
39 changes: 35 additions & 4 deletions src/anywidget_vector/ui/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,16 +122,47 @@ def get_esm() -> str:
function render({{ model, el }}) {{
const wrapper = document.createElement("div");
wrapper.className = "avs-wrapper";
if (model.get("dark_mode")) wrapper.classList.add("avs-dark");
wrapper.style.width = model.get("width") + "px";
wrapper.style.height = model.get("height") + "px";
el.appendChild(wrapper);

model.on("change:dark_mode", () => {{
const dark = model.get("dark_mode");
// Auto-detect host theme (marimo uses Tailwind class="dark" on <html>)
function detectHostDark() {{
const html = document.documentElement;
if (html.classList.contains("dark")) return true;
if (html.dataset.theme === "dark") return true;
if (window.matchMedia && window.matchMedia("(prefers-color-scheme: dark)").matches) return true;
return false;
}}

function applyTheme(dark) {{
wrapper.classList.toggle("avs-dark", dark);
model.set("background", dark ? "#1a1a2e" : "#fafafa");
model.set("background", dark ? "#181c1a" : "#ffffff");
model.save_changes();
}}

// Detect if we're inside a themed host (marimo, jupyter, etc.)
const hasHostTheme = !!el.getRootNode()?.host?.tagName?.startsWith("MARIMO-");
if (hasHostTheme) {{
wrapper.classList.add("avs-auto-theme");
model.set("dark_mode", detectHostDark());
applyTheme(detectHostDark());
// Watch host for live theme changes
const themeObserver = new MutationObserver(() => {{
const dark = detectHostDark();
if (model.get("dark_mode") !== dark) {{
model.set("dark_mode", dark);
applyTheme(dark);
}}
}});
themeObserver.observe(document.documentElement, {{ attributes: true, attributeFilter: ["class", "data-theme"] }});
}} else {{
applyTheme(model.get("dark_mode"));
}}

// Manual toggle from settings panel still works
model.on("change:dark_mode", () => {{
applyTheme(model.get("dark_mode"));
}});

let sidebar = null;
Expand Down
16 changes: 12 additions & 4 deletions src/anywidget_vector/ui/canvas.js
Original file line number Diff line number Diff line change
Expand Up @@ -175,12 +175,20 @@ export function createCanvas(model, container, callbacks) {
const points = model.get("points") || [];
if (points.length === 0) return;

// Auto-scale point sizes relative to data extent
const box = new THREE.Box3();
points.forEach(p => box.expandByPoint(new THREE.Vector3(p.x ?? 0, p.y ?? 0, p.z ?? 0)));
const dataSize = box.getSize(new THREE.Vector3()).length() || 1;
const scaleFactor = dataSize / 10;
const rawRange = model.get("size_range") || [0.02, 0.06];
const sizeRange = [rawRange[0] * scaleFactor, rawRange[1] * scaleFactor];

const opts = {
colorField: model.get("color_field"),
colorScale: model.get("color_scale") || "viridis",
colorDomain: model.get("color_domain"),
sizeField: model.get("size_field"),
sizeRange: model.get("size_range") || [0.02, 0.1],
sizeRange: sizeRange,
shapeField: model.get("shape_field"),
shapeMap: model.get("shape_map") || {},
};
Expand Down Expand Up @@ -460,7 +468,7 @@ export function createCanvas(model, container, callbacks) {
const size = box.getSize(new THREE.Vector3()).length();
const distance = size / (2 * Math.tan(Math.PI * camera.fov / 360));
controls.target.copy(center);
camera.position.copy(center.clone().add(new THREE.Vector3(0, 0, distance * 1.2)));
camera.position.copy(center.clone().add(new THREE.Vector3(0, 0, distance * 0.55)));
camera.near = Math.max(0.01, distance * 0.001);
camera.far = Math.max(1000, distance * 10);
camera.updateProjectionMatrix();
Expand All @@ -485,8 +493,8 @@ export function createCanvas(model, container, callbacks) {
zoomOutBtn.innerHTML = ICONS.zoomOut;
zoomOutBtn.title = "Zoom out";
zoomOutBtn.addEventListener("click", () => {
const dir = camera.position.clone().sub(controls.target).normalize();
camera.position.add(dir.multiplyScalar(0.5));
const offset = camera.position.clone().sub(controls.target);
camera.position.add(offset.multiplyScalar(0.3));
controls.update();
});

Expand Down
3 changes: 2 additions & 1 deletion src/anywidget_vector/ui/settings.js
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@ export function createSettingsPanel(model, callbacks) {
header.appendChild(closeBtn);
inner.appendChild(header);

// Dark mode toggle
// Dark mode toggle (hidden when host provides theme)
const themeGroup = createFormGroup("Theme");
themeGroup.classList.add("avs-theme-toggle");
const toggle = document.createElement("label");
toggle.className = "avs-toggle";
const checkbox = document.createElement("input");
Expand Down
122 changes: 122 additions & 0 deletions src/anywidget_vector/ui/sidebar.js
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,22 @@ export function createSidebar(model, callbacks) {
addStatRow(dimContent, "X range", rangeStr(xs));
addStatRow(dimContent, "Y range", rangeStr(ys));
addStatRow(dimContent, "Z range", rangeStr(zs));

// Visual mapping dimensions
const colorField = model.get("color_field");
const sizeField = model.get("size_field");
const shapeField = model.get("shape_field");
if (colorField) addStatRow(dimContent, "Color", colorField);
if (sizeField) {
const sizeRange = model.get("size_range") || [0.02, 0.1];
addStatRow(dimContent, "Size", sizeField + " [" + sizeRange[0] + "," + sizeRange[1] + "]");
}
if (shapeField) {
const shapeMap = model.get("shape_map") || {};
const shapes = Object.values(shapeMap);
const unique = [...new Set(shapes)];
addStatRow(dimContent, "Shape", shapeField + " (" + unique.length + " shapes)");
}
}

function updateClusters() {
Expand Down Expand Up @@ -159,6 +175,107 @@ export function createSidebar(model, callbacks) {
});
}

// === Distance Section ===
const distSection = document.createElement("div");
distSection.className = "avs-sidebar-section";

const distHeader = document.createElement("div");
distHeader.className = "avs-section-header";
distHeader.textContent = "Distance";
distSection.appendChild(distHeader);

const distContent = document.createElement("div");
distContent.className = "avs-dimension-content";

// Metric selector
const metricGroup = document.createElement("div");
metricGroup.className = "avs-form-group";
const metricLabel = document.createElement("label");
metricLabel.className = "avs-label";
metricLabel.textContent = "Metric";
metricGroup.appendChild(metricLabel);
const metricSelect = document.createElement("select");
metricSelect.className = "avs-select";
metricSelect.style.width = "100%";
["euclidean", "cosine", "manhattan", "dot_product"].forEach(m => {
const opt = document.createElement("option");
opt.value = m;
opt.textContent = m;
opt.selected = m === (model.get("distance_metric") || "euclidean");
metricSelect.appendChild(opt);
});
metricSelect.addEventListener("change", () => {
model.set("distance_metric", metricSelect.value);
model.save_changes();
});
metricGroup.appendChild(metricSelect);
distContent.appendChild(metricGroup);

// K neighbors
const kGroup = document.createElement("div");
kGroup.className = "avs-form-group";
const kLabel = document.createElement("label");
kLabel.className = "avs-label";
kLabel.textContent = "K neighbors";
kGroup.appendChild(kLabel);
const kInput = document.createElement("input");
kInput.type = "number";
kInput.className = "avs-input";
kInput.min = "0";
kInput.max = "50";
kInput.value = model.get("k_neighbors") || 0;
kInput.addEventListener("input", () => {
const k = parseInt(kInput.value, 10) || 0;
model.set("k_neighbors", k);
model.set("show_connections", k > 0);
model.save_changes();
});
kGroup.appendChild(kInput);
distContent.appendChild(kGroup);

// Distance info (updates when points are selected)
const distInfo = document.createElement("div");
distInfo.className = "avs-note";
distInfo.textContent = "Select a point to see distances";
distContent.appendChild(distInfo);

distSection.appendChild(distContent);
inner.appendChild(distSection);

function updateDistanceInfo() {
const selected = model.get("selected_points") || [];
if (selected.length === 0) {
distInfo.textContent = "Select a point to see distances";
return;
}
const points = model.get("points") || [];
const ref = points.find(p => p.id === selected[0]);
if (!ref) return;

const metric = metricSelect.value;
const others = points.filter(p => p.id !== ref.id);

function dist(a, b) {
const dx = (a.x ?? 0) - (b.x ?? 0), dy = (a.y ?? 0) - (b.y ?? 0), dz = (a.z ?? 0) - (b.z ?? 0);
if (metric === "manhattan") return Math.abs(dx) + Math.abs(dy) + Math.abs(dz);
if (metric === "cosine") {
const dot = (a.x??0)*(b.x??0) + (a.y??0)*(b.y??0) + (a.z??0)*(b.z??0);
const ma = Math.sqrt((a.x??0)**2 + (a.y??0)**2 + (a.z??0)**2);
const mb = Math.sqrt((b.x??0)**2 + (b.y??0)**2 + (b.z??0)**2);
return (ma === 0 || mb === 0) ? 1 : 1 - dot / (ma * mb);
}
if (metric === "dot_product") return -((a.x??0)*(b.x??0) + (a.y??0)*(b.y??0) + (a.z??0)*(b.z??0));
return Math.sqrt(dx*dx + dy*dy + dz*dz);
}

const sorted = others.map(p => ({ id: p.id, label: p.label || p.id, d: dist(ref, p) }))
.sort((a, b) => a.d - b.d).slice(0, 5);

distInfo.innerHTML = "<strong>" + escapeHtml(ref.label || ref.id) + "</strong><br>" +
sorted.map(n => '<span class="avs-property-key">' + escapeHtml(n.label) +
'</span> <span class="avs-property-value">' + n.d.toFixed(2) + "</span>").join("<br>");
}

// === Event Bindings ===
model.on("change:points", () => {
updateDimensions();
Expand All @@ -167,6 +284,11 @@ export function createSidebar(model, callbacks) {
model.on("change:color_field", updateClusters);
model.on("change:backend_config", updateCollections);
model.on("change:backend", updateCollections);
model.on("change:selected_points", updateDistanceInfo);
model.on("change:distance_metric", () => {
metricSelect.value = model.get("distance_metric") || "euclidean";
updateDistanceInfo();
});

// Initial render
updateCollections();
Expand Down
Loading
Loading