Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions api/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,16 @@ def main() -> None:
send({"type": "ready", "params_schema": schema})

# Support both flat manifest (legacy) and nodes[] format.
# Node-level fields take precedence; fall back to top-level for compatibility.
node = (manifest.get("nodes") or [{}])[0]
# Use MODEL_DIR to find the correct node for multi-node extensions:
# MODEL_DIR is set by ExtensionProcess to MODELS_DIR/ext_id/node_id,
# so its last component matches the node id.
nodes = manifest.get("nodes") or []
node = {}
if nodes and _MODEL_DIR_OVERRIDE:
node_id = Path(_MODEL_DIR_OVERRIDE).name
node = next((n for n in nodes if n.get("id") == node_id), nodes[0])
elif nodes:
node = nodes[0]

# Use MODEL_DIR env var (set by ExtensionProcess) when available so the
# generator uses the exact same path that is_downloaded() checks against.
Expand Down
11 changes: 6 additions & 5 deletions api/services/generator_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,11 +227,12 @@ def get_active(self) -> BaseGenerator:
if not gen.is_loaded():
if not gen.is_downloaded():
if isinstance(gen, ExtensionProcess):
raise RuntimeError(
f"Model '{self._active_id}' is not downloaded. "
"Please install it from the Models page first."
)
gen._auto_download()
# Let the subprocess handle its own download logic during
# load() — some extensions (e.g. mv-adapter) need custom
# multi-repo downloads that the standard HF endpoint can't do.
pass
else:
gen._auto_download()
gen.load()
return gen

Expand Down
2 changes: 2 additions & 0 deletions electron/main/ipc-handlers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,7 @@ export function setupIpcHandlers(pythonBridge: PythonBridge, getWindow: WindowGe
id: string
name?: string
input?: 'mesh' | 'image' | 'text'
inputs?: ('mesh' | 'image' | 'text')[]
output?: 'mesh' | 'image' | 'text'
params_schema?: unknown[]
hf_repo?: string
Expand All @@ -534,6 +535,7 @@ export function setupIpcHandlers(pythonBridge: PythonBridge, getWindow: WindowGe
id: n.id,
name: n.name ?? n.id,
input: n.input ?? 'image' as const,
inputs: n.inputs,
output: n.output ?? 'mesh' as const,
paramsSchema: n.params_schema ?? [],
hfRepo: n.hf_repo,
Expand Down
30 changes: 20 additions & 10 deletions src/areas/generate/components/WorkflowPanel.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -420,20 +420,30 @@ function EmbeddedCanvas({ workflow, allExtensions }: {
if (out) updateNodeData(out.id, { params: { outputUrl: runState.outputUrl } })
}, [runState.status, runState.outputUrl])

// Type mismatch detection
// Type mismatch detection — edge-based to support multi-input nodes
const typeMismatch = useMemo(() => {
const sorted = topoSortNodes(workflow.nodes, workflow.edges)
const extNodes = sorted.filter((n) => n.type === 'extensionNode')
// Determine initial type from the actual source node in the graph
const firstSource = sorted.find((n) => n.type === 'imageNode' || n.type === 'meshNode' || n.type === 'textNode')
let prev: string = firstSource?.type === 'meshNode' ? 'mesh'
: firstSource?.type === 'textNode' ? 'text'
: 'image'
// Build a map of what type each node produces
const nodeOutput = new Map<string, string>()
for (const node of workflow.nodes) {
if (node.type === 'imageNode') { nodeOutput.set(node.id, 'image'); continue }
if (node.type === 'meshNode') { nodeOutput.set(node.id, 'mesh'); continue }
if (node.type === 'textNode') { nodeOutput.set(node.id, 'text'); continue }
if (node.type === 'extensionNode') {
const ext = getWorkflowExtension(node.data.extensionId ?? '', allExtensions)
if (ext) nodeOutput.set(node.id, ext.output)
}
}
// For each extension node, check that every incoming edge carries an accepted type
const extNodes = workflow.nodes.filter((n) => n.type === 'extensionNode')
for (const node of extNodes) {
const ext = getWorkflowExtension(node.data.extensionId ?? '', allExtensions)
if (!ext) continue
if (prev !== ext.input) return true
prev = ext.output
const accepted = ext.inputs ?? [ext.input]
for (const edge of workflow.edges) {
if (edge.target !== node.id) continue
const srcType = nodeOutput.get(edge.source)
if (srcType && !accepted.includes(srcType as any)) return true
}
}
return false
}, [workflow, allExtensions])
Expand Down
69 changes: 53 additions & 16 deletions src/areas/workflows/WorkflowsPage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,19 @@ import type { Workflow, WFNode, WFEdge, WFNodeData } from '@shared/types/electro
import { buildAllWorkflowExtensions, getWorkflowExtension } from './mockExtensions'
import type { WorkflowExtension } from './mockExtensions'
import { useWorkflowRunStore } from './workflowRunStore'
import ExtensionNode from './nodes/ExtensionNode'
import ImageNode from './nodes/ImageNode'
import TextNode from './nodes/TextNode'
import AddToSceneNode from './nodes/AddToSceneNode'
import Load3DMeshNode from './nodes/Load3DMeshNode'
import WorkflowEdge from './nodes/WorkflowEdge'
import ExtensionNode from './nodes/ExtensionNode'
import ImageNode from './nodes/ImageNode'
import TextNode from './nodes/TextNode'
import AddToSceneNode from './nodes/AddToSceneNode'
import Load3DMeshNode from './nodes/Load3DMeshNode'
import PreviewImageNode from './nodes/PreviewImageNode'
import WorkflowEdge from './nodes/WorkflowEdge'

// ─── Constants ────────────────────────────────────────────────────────────────

const DRAG_KEY = 'modly/extension-id'
const DRAG_NODE_KEY = 'modly/node-type'
const NODE_TYPES = { extensionNode: ExtensionNode, imageNode: ImageNode, textNode: TextNode, outputNode: AddToSceneNode, meshNode: Load3DMeshNode }
const NODE_TYPES = { extensionNode: ExtensionNode, imageNode: ImageNode, textNode: TextNode, outputNode: AddToSceneNode, meshNode: Load3DMeshNode, previewNode: PreviewImageNode }
const EDGE_TYPES = { workflowEdge: WorkflowEdge }

const DEFAULT_EDGE_OPTS = { type: 'workflowEdge' }
Expand Down Expand Up @@ -149,10 +150,11 @@ const PANEL_MIN = 240
const PANEL_MAX = 860

const PANEL_BUILTIN_NODES = [
{ type: 'imageNode', label: 'Image', color: '#38bdf8', icon: <><rect x="3" y="3" width="18" height="18" rx="2"/><circle cx="8.5" cy="8.5" r="1.5"/><polyline points="21 15 16 10 5 21"/></> },
{ type: 'textNode', label: 'Text', color: '#fbbf24', icon: <><path d="M17 6.1H3M21 12.1H3M15.1 18H3"/></> },
{ type: 'meshNode', label: 'Load 3D Mesh', color: '#a78bfa', icon: <><path d="M12 2L2 7l10 5 10-5-10-5zM2 17l10 5 10-5M2 12l10 5 10-5"/></> },
{ type: 'outputNode', label: 'Add to Scene', color: '#a78bfa', icon: <><path d="M21 16V8a2 2 0 0 0-1-1.73l-7-4a2 2 0 0 0-2 0l-7 4A2 2 0 0 0 3 8v8a2 2 0 0 0 1 1.73l7 4a2 2 0 0 0 2 0l7-4A2 2 0 0 0 21 16z"/></> },
{ type: 'imageNode', label: 'Image', color: '#38bdf8', icon: <><rect x="3" y="3" width="18" height="18" rx="2"/><circle cx="8.5" cy="8.5" r="1.5"/><polyline points="21 15 16 10 5 21"/></> },
{ type: 'textNode', label: 'Text', color: '#fbbf24', icon: <><path d="M17 6.1H3M21 12.1H3M15.1 18H3"/></> },
{ type: 'meshNode', label: 'Load 3D Mesh', color: '#a78bfa', icon: <><path d="M12 2L2 7l10 5 10-5-10-5zM2 17l10 5 10-5M2 12l10 5 10-5"/></> },
{ type: 'outputNode', label: 'Add to Scene', color: '#a78bfa', icon: <><path d="M21 16V8a2 2 0 0 0-1-1.73l-7-4a2 2 0 0 0-2 0l-7 4A2 2 0 0 0 3 8v8a2 2 0 0 0 1 1.73l7 4a2 2 0 0 0 2 0l7-4A2 2 0 0 0 21 16z"/></> },
{ type: 'previewNode', label: 'Preview Views', color: '#38bdf8', icon: <><rect x="3" y="3" width="8" height="8" rx="1"/><rect x="13" y="3" width="8" height="8" rx="1"/><rect x="3" y="13" width="8" height="8" rx="1"/><rect x="13" y="13" width="8" height="8" rx="1"/></> },
]

function ExtGroupHeader({ title, author, expanded, onToggle, count }: { title: string; author?: string; expanded: boolean; onToggle: () => void; count: number }) {
Expand Down Expand Up @@ -391,10 +393,11 @@ function PanelToggleIcon({ open }: { open: boolean }) {
// ─── Node palette (Space to open) ────────────────────────────────────────────

const BUILTIN_NODES = [
{ type: 'imageNode', label: 'Image', color: '#38bdf8', description: 'Image input' },
{ type: 'textNode', label: 'Text', color: '#fbbf24', description: 'Text input' },
{ type: 'meshNode', label: 'Load 3D Mesh', color: '#a78bfa', description: 'Load a 3D mesh file or use current model' },
{ type: 'outputNode', label: 'Add to Scene', color: '#a78bfa', description: 'Output node' },
{ type: 'imageNode', label: 'Image', color: '#38bdf8', description: 'Image input' },
{ type: 'textNode', label: 'Text', color: '#fbbf24', description: 'Text input' },
{ type: 'meshNode', label: 'Load 3D Mesh', color: '#a78bfa', description: 'Load a 3D mesh file or use current model' },
{ type: 'outputNode', label: 'Add to Scene', color: '#a78bfa', description: 'Output node — adds the mesh to the 3D scene' },
{ type: 'previewNode', label: 'Preview Views', color: '#38bdf8', description: 'Displays multi-view image outputs in a 2×3 grid' },
]

type PaletteItem =
Expand Down Expand Up @@ -732,6 +735,32 @@ function HelpModal({ onClose }: { onClose: () => void }) {
)
}

// ─── Connection type helpers ──────────────────────────────────────────────────

function getNodeOutputType(node: Node | undefined, allExts: WorkflowExtension[]): string | undefined {
if (!node) return undefined
if (node.type === 'imageNode') return 'image'
if (node.type === 'meshNode') return 'mesh'
if (node.type === 'textNode') return 'text'
return allExts.find((e) => e.id === (node.data as WFNodeData)?.extensionId)?.output
}

function getNodeInputType(
node: Node | undefined,
targetHandle: string | null | undefined,
allExts: WorkflowExtension[],
): string | undefined {
if (!node) return undefined
if (node.type === 'outputNode') return 'mesh'
if (node.type === 'previewNode') return 'image'
const ext = allExts.find((e) => e.id === (node.data as WFNodeData)?.extensionId)
if (ext?.inputs && ext.inputs.length > 1 && targetHandle) {
const idx = parseInt(targetHandle.replace('input-', ''), 10)
return ext.inputs[isNaN(idx) ? 0 : idx] ?? ext.input
}
return ext?.input
}

// ─── Workflow canvas (inner, requires ReactFlowProvider) ──────────────────────

function WorkflowCanvasInner({
Expand All @@ -747,7 +776,7 @@ function WorkflowCanvasInner({
onNew: () => void
onImport: () => void
}) {
const { screenToFlowPosition, updateNodeData } = useReactFlow()
const { screenToFlowPosition, updateNodeData, getNode } = useReactFlow()
const { runState, run: runWorkflow, cancel } = useWorkflowRunStore()
const isRunning = runState.status === 'running'

Expand Down Expand Up @@ -838,6 +867,13 @@ function WorkflowCanvasInner({
const canUndo = histIdx > 0
const canRedo = histIdx < historyRef.current.length - 1

const isValidConnection = useCallback((connection: Connection) => {
const srcType = getNodeOutputType(getNode(connection.source) as Node, allExtensions)
const tgtType = getNodeInputType(getNode(connection.target) as Node, connection.targetHandle, allExtensions)
if (!srcType || !tgtType) return true // unknown type — allow
return srcType === tgtType
}, [getNode, allExtensions])

const onConnectStart = useCallback((_: React.MouseEvent | React.TouchEvent, params: OnConnectStartParams) => {
pendingConnectionRef.current = params
connectionCompletedRef.current = false
Expand Down Expand Up @@ -1133,6 +1169,7 @@ function WorkflowCanvasInner({
onEdgesChange={onEdgesChange}
onConnectStart={onConnectStart}
onConnect={onConnect}
isValidConnection={isValidConnection}
onConnectEnd={onConnectEnd}
onEdgeContextMenu={(e, edge) => { e.preventDefault(); setEdges((eds) => eds.filter((ed) => ed.id !== edge.id)) }}
defaultEdgeOptions={DEFAULT_EDGE_OPTS}
Expand Down
3 changes: 3 additions & 0 deletions src/areas/workflows/mockExtensions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ export interface WorkflowExtension {
name: string
description: string
input: 'image' | 'text' | 'mesh'
inputs?: ('image' | 'text' | 'mesh')[] // multi-input; overrides input when set
output: 'image' | 'text' | 'mesh'
params: ParamSchema[]
builtin: boolean
Expand All @@ -34,6 +35,7 @@ export function buildAllWorkflowExtensions(
name: node.name,
description: ext.description ?? '',
input: node.input,
inputs: node.inputs,
output: node.output,
params: node.paramsSchema as ParamSchema[],
builtin: ext.builtin,
Expand All @@ -53,6 +55,7 @@ export function buildAllWorkflowExtensions(
name: node.name,
description: ext.description ?? '',
input: node.input,
inputs: node.inputs,
output: node.output,
params: node.paramsSchema as ParamSchema[],
builtin: ext.builtin,
Expand Down
Loading