Skip to content

Commit 911bb84

Browse files
committed
audo-select model after download
1 parent 0d329dc commit 911bb84

File tree

1 file changed

+39
-11
lines changed

1 file changed

+39
-11
lines changed

apps/desktop/src/components/settings/ai/stt/configure.tsx

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@ import { Icon } from "@iconify-icon/react";
22
import { useForm } from "@tanstack/react-form";
33
import { useQuery } from "@tanstack/react-query";
44
import { openPath } from "@tauri-apps/plugin-opener";
5-
import { useEffect, useState } from "react";
5+
import { useCallback, useEffect, useState } from "react";
66
import { Streamdown } from "streamdown";
77
import { useManager } from "tinytick/ui-react";
88

9-
import { commands as localSttCommands } from "@hypr/plugin-local-stt";
10-
import type { SupportedSttModel } from "@hypr/plugin-local-stt";
9+
import { commands as localSttCommands, type SupportedSttModel } from "@hypr/plugin-local-stt";
1110
import { Accordion, AccordionContent, AccordionItem, AccordionTrigger } from "@hypr/ui/components/ui/accordion";
1211
import { Button } from "@hypr/ui/components/ui/button";
1312
import { cn } from "@hypr/utils";
13+
import { useListener } from "../../../../contexts/listener";
1414
import * as main from "../../../../store/tinybase/main";
1515
import { aiProviderSchema } from "../../../../store/tinybase/main";
1616
import {
@@ -289,7 +289,15 @@ function LocalModelAction({
289289
}
290290

291291
function HyprProviderLocalRow({ model, displayName }: { model: SupportedSttModel; displayName: string }) {
292-
const { progress, isDownloaded, showProgress, handleDownload, handleCancel } = useLocalModelDownload(model);
292+
const handleSelectModel = useSafeSelectModel();
293+
294+
const {
295+
progress,
296+
isDownloaded,
297+
showProgress,
298+
handleDownload,
299+
handleCancel,
300+
} = useLocalModelDownload(model, handleSelectModel);
293301

294302
const handleOpen = () =>
295303
localSttCommands.modelsDir().then((result) => {
@@ -313,17 +321,17 @@ function HyprProviderLocalRow({ model, displayName }: { model: SupportedSttModel
313321
);
314322
}
315323

316-
function useLocalModelDownload(model: SupportedSttModel) {
324+
function useLocalModelDownload(
325+
model: SupportedSttModel,
326+
onDownloadComplete?: (model: SupportedSttModel) => void,
327+
) {
317328
const manager = useManager();
318329
const [progress, setProgress] = useState<number>(0);
319330
const [taskRunId, setTaskRunId] = useState<string | null>(null);
320331

321332
const isDownloaded = useQuery(sttModelQueries.isDownloaded(model));
322333
const isDownloading = useQuery(sttModelQueries.isDownloading(model));
323334

324-
const taskRunInfo = taskRunId && manager ? manager.getTaskRunInfo(taskRunId) : null;
325-
const isTaskRunning = taskRunInfo?.running ?? false;
326-
327335
useEffect(() => {
328336
registerDownloadProgressCallback(model, setProgress);
329337
return () => {
@@ -335,8 +343,9 @@ function useLocalModelDownload(model: SupportedSttModel) {
335343
if (isDownloaded.data && taskRunId) {
336344
setTaskRunId(null);
337345
setProgress(0);
346+
onDownloadComplete?.(model);
338347
}
339-
}, [isDownloaded.data, taskRunId]);
348+
}, [isDownloaded.data, taskRunId, onDownloadComplete]);
340349

341350
useEffect(() => {
342351
const isNotDownloading = !isDownloading.data;
@@ -367,8 +376,7 @@ function useLocalModelDownload(model: SupportedSttModel) {
367376
setProgress(0);
368377
};
369378

370-
const showProgress = !isDownloaded.data && taskRunId !== null
371-
&& (isDownloading.data || isTaskRunning);
379+
const showProgress = !isDownloaded.data && taskRunId !== null;
372380

373381
return {
374382
progress,
@@ -401,3 +409,23 @@ function ProviderContext({ providerId }: { providerId: ProviderId }) {
401409
</Streamdown>
402410
);
403411
}
412+
413+
function useSafeSelectModel() {
414+
const handleSelectModel = main.UI.useSetValueCallback(
415+
"current_stt_model",
416+
(model: SupportedSttModel) => model,
417+
[],
418+
main.STORE_ID,
419+
);
420+
421+
const active = useListener((state) => state.status === "running_active");
422+
423+
const handler = useCallback((model: SupportedSttModel) => {
424+
if (active) {
425+
return;
426+
}
427+
handleSelectModel(model);
428+
}, [active, handleSelectModel]);
429+
430+
return handler;
431+
}

0 commit comments

Comments
 (0)