Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 34 additions & 2 deletions api/routers/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import uuid
from typing import Dict
from fastapi import APIRouter, File, Form, UploadFile, HTTPException, BackgroundTasks
from services.generators.base import smooth_progress
from services.generators.base import smooth_progress, GenerationCancelled

import re as _re
from services.generator_registry import generator_registry, WORKSPACE_DIR
Expand All @@ -13,6 +13,8 @@
router = APIRouter(tags=["generation"])

_jobs: Dict[str, JobStatus] = {}
_cancelled: set = set()
_cancel_events: Dict[str, threading.Event] = {}


@router.post("/from-image")
Expand Down Expand Up @@ -71,6 +73,7 @@ async def generate_from_image(

job = JobStatus(job_id=job_id, status="pending", progress=0)
_jobs[job_id] = job
_cancel_events[job_id] = threading.Event()

background_tasks.add_task(_run_generation, job_id, image_bytes, params, collection)

Expand All @@ -86,6 +89,19 @@ async def job_status(job_id: str):
return job


@router.post("/cancel/{job_id}")
async def cancel_job(job_id: str):
job = _jobs.get(job_id)
if not job:
raise HTTPException(404, f"Job {job_id} not found")
_cancelled.add(job_id)
if job_id in _cancel_events:
_cancel_events[job_id].set()
if job.status in ("pending", "running"):
job.status = "cancelled"
return {"cancelled": True}


async def _run_generation(job_id: str, image_bytes: bytes, params: dict, collection: str = "Default") -> None:
job = _jobs[job_id]
job.status = "running"
Expand Down Expand Up @@ -118,20 +134,36 @@ def progress_cb(pct: int, step: str = "") -> None:
else:
gen = await loop.run_in_executor(None, generator_registry.get_active)

if job_id in _cancelled:
return

# Direct output to the collection subfolder
coll_dir = WORKSPACE_DIR / collection
coll_dir.mkdir(parents=True, exist_ok=True)
gen.outputs_dir = coll_dir

cancel_event = _cancel_events.get(job_id)
import inspect
supports_cancel = "cancel_event" in inspect.signature(gen.generate).parameters
output_path = await loop.run_in_executor(
None,
lambda: gen.generate(image_bytes, params, progress_cb),
lambda: gen.generate(image_bytes, params, progress_cb, cancel_event)
if supports_cancel
else gen.generate(image_bytes, params, progress_cb),
)

if job_id in _cancelled:
return

job.status = "done"
job.progress = 100
job.output_url = f"/workspace/{collection}/{output_path.name}"

except GenerationCancelled:
job.status = "cancelled"
except Exception as exc:
if job_id in _cancelled:
return
tb = traceback.format_exc()
print(f"[Generation ERROR] {exc}\n{tb}")
job.status = "error"
Expand Down
2 changes: 1 addition & 1 deletion api/schemas/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

class JobStatus(BaseModel):
job_id: str
status: Literal["pending", "running", "done", "error"]
status: Literal["pending", "running", "done", "error", "cancelled"]
progress: int = 0 # 0–100
step: Optional[str] = None # Human-readable current step
output_url: Optional[str] = None
Expand Down
11 changes: 11 additions & 0 deletions api/services/generators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
from typing import Callable, Optional


class GenerationCancelled(Exception):
"""Raised by generators when a cancel_event is set mid-generation."""


def smooth_progress(
progress_cb: Callable[[int, str], None],
start: int,
Expand Down Expand Up @@ -87,14 +91,21 @@ def generate(
image_bytes: bytes,
params: dict,
progress_cb: Optional[Callable[[int, str], None]] = None,
cancel_event: Optional[threading.Event] = None,
) -> Path:
"""
Starts 3D generation from an image.
Returns the path to the generated .glb file.
progress_cb(percent: int, step_label: str)
cancel_event: set this to interrupt generation between steps.
"""
...

def _check_cancelled(self, cancel_event: Optional[threading.Event]) -> None:
"""Raises GenerationCancelled if cancel_event is set."""
if cancel_event and cancel_event.is_set():
raise GenerationCancelled()

# ------------------------------------------------------------------ #
# Parameter schema (for the UI)
# ------------------------------------------------------------------ #
Expand Down
29 changes: 19 additions & 10 deletions src/areas/generate/GeneratePage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import Viewer3D from './components/Viewer3D'
export default function GeneratePage(): JSX.Element {
const selectedImagePath = useAppStore((s) => s.selectedImagePath)
const modelId = useAppStore((s) => s.generationOptions.modelId)
const { currentJob, startGeneration } = useGeneration()
const { currentJob, startGeneration, cancelGeneration } = useGeneration()
const isGenerating = currentJob?.status === 'uploading' || currentJob?.status === 'generating'

const [unloadStatus, setUnloadStatus] = useState<'idle' | 'done'>('idle')
Expand All @@ -33,16 +33,25 @@ export default function GeneratePage(): JSX.Element {
<GenerationOptions />
</div>

{/* Sticky bottom: Generate button */}
{/* Sticky bottom: Generate / Stop button */}
<div className="p-4 border-t border-zinc-800">
<button
onClick={() => canGenerate && startGeneration(selectedImagePath!)}
disabled={!canGenerate}
title={disabledReason}
className="w-full py-2.5 rounded-lg text-sm font-semibold bg-accent hover:bg-accent-dark disabled:opacity-40 disabled:cursor-not-allowed text-white transition-colors"
>
{isGenerating ? 'Generating…' : 'Generate 3D Model'}
</button>
{isGenerating ? (
<button
onClick={cancelGeneration}
className="w-full py-2.5 rounded-lg text-sm font-semibold bg-red-600 hover:bg-red-700 text-white transition-colors"
>
Stop
</button>
) : (
<button
onClick={() => canGenerate && startGeneration(selectedImagePath!)}
disabled={!canGenerate}
title={disabledReason}
className="w-full py-2.5 rounded-lg text-sm font-semibold bg-accent hover:bg-accent-dark disabled:opacity-40 disabled:cursor-not-allowed text-white transition-colors"
>
Generate 3D Model
</button>
)}
</div>
</div>

Expand Down
6 changes: 5 additions & 1 deletion src/shared/hooks/useApi.ts
Original file line number Diff line number Diff line change
Expand Up @@ -95,5 +95,9 @@ export function useApi() {
return { url: data.url, faceCount: data.face_count }
}

return { generateFromImage, pollJobStatus, getModelStatus, downloadModel, optimizeMesh }
async function cancelJob(jobId: string): Promise<void> {
await client.post(`/generate/cancel/${jobId}`).catch(() => {})
}

return { generateFromImage, pollJobStatus, cancelJob, getModelStatus, downloadModel, optimizeMesh }
}
33 changes: 28 additions & 5 deletions src/shared/hooks/useGeneration.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { useCallback } from 'react'
import { useCallback, useRef } from 'react'
import { useAppStore } from '@shared/stores/appStore'
import { useCollectionsStore } from '@shared/stores/collectionsStore'
import { useApi } from './useApi'
Expand All @@ -7,10 +7,12 @@ export function useGeneration() {
const { currentJob, setCurrentJob, updateCurrentJob, generationOptions, selectedImageData } = useAppStore()
const addToWorkspace = useCollectionsStore((s) => s.addToWorkspace)
const activeCollectionId = useCollectionsStore((s) => s.activeCollectionId)
const { generateFromImage, pollJobStatus } = useApi()
const { generateFromImage, pollJobStatus, cancelJob } = useApi()
const cancelledRef = useRef(false)

const startGeneration = useCallback(
async (imagePath: string) => {
cancelledRef.current = false
const job = {
id: crypto.randomUUID(),
imageFile: imagePath,
Expand All @@ -25,25 +27,42 @@ export function useGeneration() {
try {
const { jobId } = await generateFromImage(imagePath, generationOptions, activeCollectionId, selectedImageData ?? undefined)

if (cancelledRef.current) {
await cancelJob(jobId)
return
}

updateCurrentJob({ status: 'generating', progress: 0 })

// Poll until done
await pollUntilDone(jobId)
} catch (err) {
if (cancelledRef.current) return
updateCurrentJob({
status: 'error',
error: err instanceof Error ? err.message : String(err)
})
}
},
[generateFromImage, pollJobStatus, setCurrentJob, updateCurrentJob, addToWorkspace, activeCollectionId]
[generateFromImage, pollJobStatus, cancelJob, setCurrentJob, updateCurrentJob, addToWorkspace, activeCollectionId]
)

const pollUntilDone = async (jobId: string) => {
while (true) {
await new Promise((r) => setTimeout(r, 1000))

if (cancelledRef.current) {
await cancelJob(jobId)
setCurrentJob(null)
break
}

const result = await pollJobStatus(jobId)

if (result.status === 'cancelled') {
setCurrentJob(null)
break
}

if (result.status === 'done') {
updateCurrentJob({ status: 'done', progress: 100, outputUrl: result.outputUrl, originalOutputUrl: result.outputUrl })
const finalJob = useAppStore.getState().currentJob
Expand All @@ -63,7 +82,11 @@ export function useGeneration() {
}
}

const cancelGeneration = useCallback(() => {
cancelledRef.current = true
}, [])

const reset = useCallback(() => setCurrentJob(null), [setCurrentJob])

return { currentJob, startGeneration, reset }
return { currentJob, startGeneration, cancelGeneration, reset }
}