Skip to content

Commit

Permalink
🐛 fix: fix tts file saving in server mode (lobehub#3585)
Browse files Browse the repository at this point in the history
* 🐛 fix: fix tts in server mode

* 🐛 fix: fix tts not save to file

* ✅ test: fix test
  • Loading branch information
arvinxx authored Aug 24, 2024
1 parent 6b01996 commit ab1cb47
Show file tree
Hide file tree
Showing 8 changed files with 30 additions and 261 deletions.
5 changes: 3 additions & 2 deletions src/database/server/models/__tests__/message.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -248,9 +248,10 @@ describe('MessageModel', () => {

// 断言结果
expect(result[0].extra.translate).toEqual({ content: 'translated', from: 'en', to: 'zh' });
// TODO: 确认是否需要包含 tts 字段
expect(result[0].extra.tts).toEqual({
// contentMd5: 'md5', file: 'f1', voice: 'voice1'
contentMd5: 'md5',
file: 'f1',
voice: 'voice1',
});
});

Expand Down
25 changes: 8 additions & 17 deletions src/database/server/models/message.ts
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,9 @@ export class MessageModel {
},

ttsId: messageTTS.id,

// TODO: 确认下如何处理 TTS 的读取
// ttsContentMd5: messageTTS.contentMd5,
// ttsFile: messageTTS.fileId,
// ttsVoice: messageTTS.voice,
ttsContentMd5: messageTTS.contentMd5,
ttsFile: messageTTS.fileId,
ttsVoice: messageTTS.voice,
/* eslint-enable */
})
.from(messages)
Expand All @@ -113,7 +111,7 @@ export class MessageModel {

const messageIds = result.map((message) => message.id as string);

if (messageIds.length === 0) return result;
if (messageIds.length === 0) return [];

// 2. get relative files
const rawRelatedFileList = await serverDB
Expand Down Expand Up @@ -166,14 +164,7 @@ export class MessageModel {
.where(inArray(messageQueries.messageId, messageIds));

return result.map(
({
model,
provider,
translate,
ttsId,
// ttsFile, ttsId, ttsContentMd5, ttsVoice,
...item
}) => {
({ model, provider, translate, ttsId, ttsFile, ttsContentMd5, ttsVoice, ...item }) => {
const messageQuery = messageQueriesList.find((relation) => relation.messageId === item.id);
return {
...item,
Expand All @@ -185,9 +176,9 @@ export class MessageModel {
translate,
tts: ttsId
? {
// contentMd5: ttsContentMd5,
// file: ttsFile,
// voice: ttsVoice,
contentMd5: ttsContentMd5,
file: ttsFile,
voice: ttsVoice,
}
: undefined,
},
Expand Down
2 changes: 1 addition & 1 deletion src/server/routers/lambda/message.ts
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ export const messageRouter = router({
value: z
.object({
contentMd5: z.string().optional(),
fileId: z.string().optional(),
file: z.string().optional(),
voice: z.string().optional(),
})
.or(z.literal(false)),
Expand Down
72 changes: 0 additions & 72 deletions src/services/__tests__/upload_legacy.test.ts

This file was deleted.

104 changes: 0 additions & 104 deletions src/services/upload_legacy.ts

This file was deleted.

36 changes: 2 additions & 34 deletions src/store/file/slices/tts/action.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,38 +52,6 @@ describe('TTSFileAction', () => {
expect(fileService.removeFile).toHaveBeenCalledWith(fileId);
});

// Test for uploadTTSFile
it('uploadTTSFile should upload the file and return the file id', async () => {
const testFile = new File(['content'], 'test.mp3', { type: 'audio/mp3' });
const uploadedFileData = {
id: 'new-tts-file-id',
createdAt: testFile.lastModified,
data: await testFile.arrayBuffer(),
fileType: testFile.type,
name: testFile.name,
saveMode: 'local',
size: testFile.size,
};

// Mock the fileService.uploadFile to resolve with uploadedFileData
vi.spyOn(fileService, 'createFile').mockResolvedValue({ id: uploadedFileData.id, url: '' });

let fileId;
await act(async () => {
fileId = await useStore.getState().uploadTTSFile(testFile);
});

expect(fileService.createFile).toHaveBeenCalledWith({
createdAt: testFile.lastModified,
data: await testFile.arrayBuffer(),
fileType: testFile.type,
name: testFile.name,
saveMode: 'local',
size: testFile.size,
});
expect(fileId).toBe(uploadedFileData.id);
});

// Test for uploadTTSByArrayBuffers
it('uploadTTSByArrayBuffers should create a file and call uploadTTSFile', async () => {
const messageId = 'message-id';
Expand All @@ -93,8 +61,8 @@ describe('TTSFileAction', () => {

// Spy on uploadTTSFile to simulate a successful upload
const uploadTTSFileSpy = vi
.spyOn(useStore.getState(), 'uploadTTSFile')
.mockResolvedValue('new-tts-file-id');
.spyOn(useStore.getState(), 'uploadWithProgress')
.mockResolvedValue({ id: 'new-tts-file-id', url: '1' });

let fileId;
await act(async () => {
Expand Down
35 changes: 10 additions & 25 deletions src/store/file/slices/tts/action.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import useSWR, { SWRResponse } from 'swr';
import { SWRResponse } from 'swr';
import { StateCreator } from 'zustand/vanilla';

import { useClientDataSWR } from '@/libs/swr';
import { fileService } from '@/services/file';
import { legacyUploadService } from '@/services/upload_legacy';
import { FileItem } from '@/types/files';

import { FileStore } from '../../store';

const FETCH_TTS_FILE = 'fetchTTSFile';

export interface TTSFileAction {
removeTTSFile: (id: string) => Promise<void>;

Expand All @@ -15,8 +17,6 @@ export interface TTSFileAction {
arrayBuffers: ArrayBuffer[],
) => Promise<string | undefined>;

uploadTTSFile: (file: File) => Promise<string | undefined>;

useFetchTTSFile: (id: string | null) => SWRResponse<FileItem>;
}

Expand All @@ -38,26 +38,11 @@ export const createTTSFileSlice: StateCreator<
type: fileType,
};
const file = new File([blob], fileName, fileOptions);
return get().uploadTTSFile(file);
},
uploadTTSFile: async (file) => {
try {
const res = await legacyUploadService.uploadFile({
createdAt: file.lastModified,
data: await file.arrayBuffer(),
fileType: file.type,
name: file.name,
saveMode: 'local',
size: file.size,
});

const data = await fileService.createFile(res);

return data.id;
} catch (error) {
// 提示用户上传失败
console.error('upload error:', error);
}

const res = await get().uploadWithProgress({ file });

return res?.id;
},
useFetchTTSFile: (id) => useSWR(id, fileService.getFile),
useFetchTTSFile: (id) =>
useClientDataSWR(!!id ? [FETCH_TTS_FILE, id] : null, () => fileService.getFile(id!)),
});
Loading

0 comments on commit ab1cb47

Please sign in to comment.