Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions api/routers/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,19 +58,30 @@ async def unload_model(model_id: str):


@router.get("/hf-download")
async def hf_download(repo_id: str, model_id: str):
async def hf_download(repo_id: str, model_id: str, skip_prefixes: Optional[str] = None):
"""
Streams a HuggingFace Hub model download via SSE.
Downloads into MODELS_DIR / model_id applying the filtering
declared in the extension manifest (hf_skip_prefixes).

skip_prefixes: JSON-encoded list of path prefixes to exclude (passed from Electron).
Falls back to registry manifest if not provided.

SSE format: data: {"percent": 0-100, "file": "...", "status": "..."}
"""
import json as _json
dest_dir = str(MODELS_DIR / model_id)
try:
skip_list = generator_registry.get_manifest(model_id).get("hf_skip_prefixes", [])
except KeyError:
skip_list = []
# Prefer skip_prefixes passed directly from the client (authoritative, no registry dep)
if skip_prefixes:
try:
skip_list = _json.loads(skip_prefixes)
except Exception:
skip_list = []
else:
try:
skip_list = generator_registry.get_manifest(model_id).get("hf_skip_prefixes", [])
except KeyError:
skip_list = []

async def stream():
loop = asyncio.get_running_loop()
Expand Down
10 changes: 5 additions & 5 deletions electron/main/ipc-handlers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -175,11 +175,11 @@ export function setupIpcHandlers(pythonBridge: PythonBridge, getWindow: WindowGe
return isModelDownloaded(modelsDir, modelId)
})

ipcMain.handle('model:download', async (event, { repoId, modelId }: { repoId: string; modelId: string }) => {
ipcMain.handle('model:download', async (event, { repoId, modelId, skipPrefixes }: { repoId: string; modelId: string; skipPrefixes?: string[] }) => {
try {
await downloadModelFromHF(repoId, modelId, (progress) => {
event.sender.send('model:downloadProgress', { modelId, ...progress })
})
}, skipPrefixes)
return { success: true }
} catch (err) {
return { success: false, error: String(err) }
Expand Down Expand Up @@ -381,15 +381,15 @@ export function setupIpcHandlers(pythonBridge: PythonBridge, getWindow: WindowGe
description?: string; author?: string | { name?: string }
hf_repo?: string; source?: string; generator_class?: string
model?: { repoId?: string; modelId?: string }
models?: { id?: string; name?: string; hf_repo?: string; description?: string }[]
models?: { id?: string; name?: string; hf_repo?: string; description?: string; hf_skip_prefixes?: string[] }[]
}

function parseExtensionManifest(parsed: ParsedManifest, fallbackId: string, trustedRepos: Set<string>) {
let models: { id: string; name: string; repoId: string; description?: string }[] = []
let models: { id: string; name: string; repoId: string; description?: string; hfSkipPrefixes?: string[] }[] = []
if (parsed.models?.length) {
models = parsed.models
.filter(v => v.hf_repo && v.id)
.map(v => ({ id: v.id!, name: v.name ?? v.id!, repoId: v.hf_repo!, description: v.description }))
.map(v => ({ id: v.id!, name: v.name ?? v.id!, repoId: v.hf_repo!, description: v.description, hfSkipPrefixes: v.hf_skip_prefixes }))
} else {
const repoId = parsed.model?.repoId ?? parsed.hf_repo
const modelId = parsed.model?.modelId ?? parsed.id ?? fallbackId
Expand Down
12 changes: 8 additions & 4 deletions electron/main/model-downloader.ts
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,16 @@ export function listDownloadedModels(modelsDir: string): { id: string; name: str
* Reports progress (0–100) via the onProgress callback.
*/
export async function downloadModelFromHF(
repoId: string,
modelId: string,
onProgress: ProgressCallback
repoId: string,
modelId: string,
onProgress: ProgressCallback,
skipPrefixes?: string[],
): Promise<void> {
const { net } = require('electron')
const url = `${PYTHON_API_URL}/model/hf-download?repo_id=${encodeURIComponent(repoId)}&model_id=${encodeURIComponent(modelId)}`
let url = `${PYTHON_API_URL}/model/hf-download?repo_id=${encodeURIComponent(repoId)}&model_id=${encodeURIComponent(modelId)}`
if (skipPrefixes && skipPrefixes.length > 0) {
url += `&skip_prefixes=${encodeURIComponent(JSON.stringify(skipPrefixes))}`
}

const res = await net.fetch(url)
if (!res.ok) throw new Error(`HuggingFace download failed: HTTP ${res.status}`)
Expand Down
2 changes: 1 addition & 1 deletion electron/preload/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ contextBridge.exposeInMainWorld('electron', {
export: (args: { outputUrl: string; format: string }) => ipcRenderer.invoke('model:export', args),
listDownloaded: () => ipcRenderer.invoke('model:listDownloaded'),
isDownloaded: (modelId: string) => ipcRenderer.invoke('model:isDownloaded', modelId),
download: (repoId: string, modelId: string) => ipcRenderer.invoke('model:download', { repoId, modelId }),
download: (repoId: string, modelId: string, skipPrefixes?: string[]) => ipcRenderer.invoke('model:download', { repoId, modelId, skipPrefixes }),
delete: (modelId: string) => ipcRenderer.invoke('model:delete', modelId),
unloadAll: () => ipcRenderer.invoke('model:unloadAll'),
onProgress: (cb: (data: { modelId: string; percent: number; file?: string; fileIndex?: number; totalFiles?: number; status?: string }) => void) => {
Expand Down
25 changes: 18 additions & 7 deletions src/areas/models/ModelsPage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { useEffect, useState } from 'react'
import { useAppStore } from '@shared/stores/appStore'
import { useNavStore } from '@shared/stores/navStore'
import { useExtensionsStore } from '@shared/stores/extensionsStore'
import { useApi } from '@shared/hooks/useApi'
import { ConfirmModal } from '@shared/components/ui'
import { LocalModel } from './models'
import { formatModelName } from './utils'
Expand All @@ -26,12 +27,15 @@ export default function ModelsPage(): JSX.Element {
const reloadExtensions = useExtensionsStore((s) => s.reload)
const clearInstall = useExtensionsStore((s) => s.clearInstallState)

const { getAllModelsStatus } = useApi()

// HF models state
const [models, setModels] = useState<LocalModel[]>([])
const [downloading, setDownloading] = useState<Record<string, { percent: number; file?: string; fileIndex?: number; totalFiles?: number }>>({})
const [deleteTarget, setDeleteTarget] = useState<LocalModel | null>(null)
const [deleteError, setDeleteError] = useState<string | null>(null)
const [uninstallTarget, setUninstallTarget] = useState<string | null>(null)
const [models, setModels] = useState<LocalModel[]>([])
const [installedVariantIds, setInstalledVariantIds] = useState<string[]>([])
const [downloading, setDownloading] = useState<Record<string, { percent: number; file?: string; fileIndex?: number; totalFiles?: number }>>({})
const [deleteTarget, setDeleteTarget] = useState<LocalModel | null>(null)
const [deleteError, setDeleteError] = useState<string | null>(null)
const [uninstallTarget, setUninstallTarget] = useState<string | null>(null)

// GitHub extension install form
const [showGHForm, setShowGHForm] = useState(false)
Expand All @@ -43,6 +47,13 @@ export default function ModelsPage(): JSX.Element {
async function refresh() {
const list = await window.electron.model.listDownloaded()
setModels(list)
try {
const statuses = await getAllModelsStatus()
setInstalledVariantIds(statuses.filter((s) => s.downloaded).map((s) => s.id))
} catch {
// fallback: derive from directory list
setInstalledVariantIds(list.map((m) => m.id))
}
}

useEffect(() => {
Expand Down Expand Up @@ -261,7 +272,7 @@ export default function ModelsPage(): JSX.Element {
<ExtensionCard
key={ext.id}
ext={ext}
installedIds={models.map((m) => m.id)}
installedIds={installedVariantIds}
downloading={downloading}
disabled={isBusy}
loadError={
Expand All @@ -270,7 +281,7 @@ export default function ModelsPage(): JSX.Element {
}
onInstall={(variant: ExtensionVariant) => {
setDownloading((prev) => ({ ...prev, [variant.id]: { percent: 0 } }))
window.electron.model.download(variant.repoId, variant.id).then((result) => {
window.electron.model.download(variant.repoId, variant.id, variant.hfSkipPrefixes).then((result) => {
if (!result.success) {
setDownloading((prev) => { const n = { ...prev }; delete n[variant.id]; return n })
}
Expand Down
7 changes: 7 additions & 0 deletions src/shared/hooks/useApi.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ export function useApi() {
options: GenerationOptions,
collection: string = 'Default',
imageData?: string,
signal?: AbortSignal,
): Promise<{ jobId: string }> {
// Use provided base64 (drag & drop) or read from disk via IPC
const base64 = imageData ?? await window.electron.fs.readFileBase64(imagePath)
Expand All @@ -32,6 +33,7 @@ export function useApi() {
formData.append('num_inference_steps', String(options.numInferenceSteps))
const { data } = await client.post<{ job_id: string }>('/generate/from-image', formData, {
headers: { 'Content-Type': 'multipart/form-data' },
signal,
})

return { jobId: data.job_id }
Expand All @@ -58,6 +60,11 @@ export function useApi() {
return data
}

async function getAllModelsStatus(): Promise<{ id: string; downloaded: boolean }[]> {
const { data } = await client.get('/model/all')
return data
}

async function downloadModel(
onProgress?: (pct: number) => void
): Promise<void> {
Expand Down
11 changes: 9 additions & 2 deletions src/shared/hooks/useGeneration.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@ export function useGeneration() {
const activeCollectionId = useCollectionsStore((s) => s.activeCollectionId)
const { generateFromImage, pollJobStatus, cancelJob } = useApi()
const cancelledRef = useRef(false)
const abortControllerRef = useRef<AbortController | null>(null)

const startGeneration = useCallback(
async (imagePath: string) => {
cancelledRef.current = false
abortControllerRef.current = new AbortController()
const job = {
id: crypto.randomUUID(),
imageFile: imagePath,
Expand All @@ -25,18 +27,22 @@ export function useGeneration() {
setCurrentJob(job)

try {
const { jobId } = await generateFromImage(imagePath, generationOptions, activeCollectionId, selectedImageData ?? undefined)
const { jobId } = await generateFromImage(imagePath, generationOptions, activeCollectionId, selectedImageData ?? undefined, abortControllerRef.current.signal)

if (cancelledRef.current) {
await cancelJob(jobId)
setCurrentJob(null)
return
}

updateCurrentJob({ status: 'generating', progress: 0 })

await pollUntilDone(jobId)
} catch (err) {
if (cancelledRef.current) return
if (cancelledRef.current) {
setCurrentJob(null)
return
}
updateCurrentJob({
status: 'error',
error: err instanceof Error ? err.message : String(err)
Expand Down Expand Up @@ -84,6 +90,7 @@ export function useGeneration() {

const cancelGeneration = useCallback(() => {
cancelledRef.current = true
abortControllerRef.current?.abort()
}, [])

const reset = useCallback(() => setCurrentJob(null), [setCurrentJob])
Expand Down
9 changes: 5 additions & 4 deletions src/shared/stores/extensionsStore.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@ import { create } from 'zustand'
// ─── Types ────────────────────────────────────────────────────────────────────

export interface ExtensionVariant {
id: string
name: string
repoId: string
description?: string
id: string
name: string
repoId: string
description?: string
hfSkipPrefixes?: string[]
}

export interface Extension {
Expand Down
2 changes: 1 addition & 1 deletion src/shared/types/electron.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ declare global {
export: (args: { outputUrl: string; format: string }) => Promise<{ success: boolean; error?: string }>
listDownloaded: () => Promise<{ id: string; name: string; size_gb: number }[]>
isDownloaded: (modelId: string) => Promise<boolean>
download: (repoId: string, modelId: string) => Promise<{ success: boolean; error?: string }>
download: (repoId: string, modelId: string, skipPrefixes?: string[]) => Promise<{ success: boolean; error?: string }>
delete: (modelId: string) => Promise<{ success: boolean; error?: string }>
unloadAll: () => Promise<{ success: boolean; error?: string }>
onProgress: (cb: (data: { modelId: string; percent: number; file?: string; fileIndex?: number; totalFiles?: number; status?: string }) => void) => void
Expand Down