diff --git a/src/app/api/chat/[provider]/agentRuntime.ts b/src/app/api/chat/[provider]/agentRuntime.ts index 8d541a401d16..4f8e3a7abe7f 100644 --- a/src/app/api/chat/[provider]/agentRuntime.ts +++ b/src/app/api/chat/[provider]/agentRuntime.ts @@ -1,6 +1,7 @@ import { getServerConfig } from '@/config/server'; import { JWTPayload } from '@/const/auth'; import { + ChatCompetitionOptions, ChatStreamPayload, LobeAzureOpenAI, LobeBedrockAI, @@ -29,8 +30,8 @@ class AgentRuntime { this._runtime = runtime; } - async chat(payload: ChatStreamPayload) { - return this._runtime.chat(payload); + async chat(payload: ChatStreamPayload, options?: ChatCompetitionOptions) { + return this._runtime.chat(payload, options); } static async initializeWithUserPayload( @@ -123,14 +124,14 @@ class AgentRuntime { const { ZHIPU_API_KEY } = getServerConfig(); const apiKey = apiKeyManager.pick(payload?.apiKey || ZHIPU_API_KEY); - return LobeZhipuAI.fromAPIKey(apiKey); + return LobeZhipuAI.fromAPIKey({ apiKey }); } private static initMoonshot(payload: JWTPayload) { const { MOONSHOT_API_KEY, MOONSHOT_PROXY_URL } = getServerConfig(); const apiKey = apiKeyManager.pick(payload?.apiKey || MOONSHOT_API_KEY); - return new LobeMoonshotAI(apiKey, MOONSHOT_PROXY_URL); + return new LobeMoonshotAI({ apiKey, baseURL: MOONSHOT_PROXY_URL }); } private static initGoogle(payload: JWTPayload) { @@ -158,16 +159,16 @@ class AgentRuntime { private static initOllama(payload: JWTPayload) { const { OLLAMA_PROXY_URL } = getServerConfig(); - const baseUrl = payload?.endpoint || OLLAMA_PROXY_URL; + const baseURL = payload?.endpoint || OLLAMA_PROXY_URL; - return new LobeOllamaAI(baseUrl); + return new LobeOllamaAI({ baseURL }); } private static initPerplexity(payload: JWTPayload) { const { PERPLEXITY_API_KEY } = getServerConfig(); const apiKey = apiKeyManager.pick(payload?.apiKey || PERPLEXITY_API_KEY); - return new LobePerplexityAI(apiKey); + return new LobePerplexityAI({ apiKey }); } } diff --git a/src/app/api/chat/[provider]/route.ts b/src/app/api/chat/[provider]/route.ts index 35ee9d474359..66b071936829 100644 --- a/src/app/api/chat/[provider]/route.ts +++ b/src/app/api/chat/[provider]/route.ts @@ -19,6 +19,7 @@ export const preferredRegion = getPreferredRegion(); export const POST = async (req: Request, { params }: { params: { provider: string } }) => { let agentRuntime: AgentRuntime; + const { provider } = params; // ============ 1. 
init chat model ============ // @@ -34,7 +35,7 @@ export const POST = async (req: Request, { params }: { params: { provider: strin checkAuthMethod(payload.accessCode, payload.apiKey, oauthAuthorized); const body = await req.clone().json(); - agentRuntime = await AgentRuntime.initializeWithUserPayload(params.provider, payload, { + agentRuntime = await AgentRuntime.initializeWithUserPayload(provider, payload, { apiVersion: payload.azureApiVersion, model: body.model, useAzure: payload.useAzure, @@ -44,10 +45,7 @@ export const POST = async (req: Request, { params }: { params: { provider: strin const err = e as AgentInitErrorPayload; return createErrorResponse( (err.errorType || ChatErrorType.InternalServerError) as ILobeAgentRuntimeErrorType, - { - error: err.error || e, - provider: params.provider, - }, + { error: err.error || e, provider }, ); } diff --git a/src/libs/agent-runtime/BaseAI.ts b/src/libs/agent-runtime/BaseAI.ts index bb51c79fb26d..82c321ee1de7 100644 --- a/src/libs/agent-runtime/BaseAI.ts +++ b/src/libs/agent-runtime/BaseAI.ts @@ -1,9 +1,12 @@ import { StreamingTextResponse } from 'ai'; -import { ChatStreamPayload } from '@/types/openai/chat'; +import { ChatCompetitionOptions, ChatStreamPayload } from './types'; export interface LobeRuntimeAI { baseURL?: string; - chat(payload: ChatStreamPayload): Promise<StreamingTextResponse>; + chat( + payload: ChatStreamPayload, + options?: ChatCompetitionOptions, + ): Promise<StreamingTextResponse>; } diff --git a/src/libs/agent-runtime/moonshot/index.test.ts b/src/libs/agent-runtime/moonshot/index.test.ts new file mode 100644 index 000000000000..aaea9e35b4fe --- /dev/null +++ b/src/libs/agent-runtime/moonshot/index.test.ts @@ -0,0 +1,320 @@ +// @vitest-environment node +import OpenAI from 'openai'; +import { Mock, afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; + +import { ChatStreamCallbacks } from '@/libs/agent-runtime'; + +import * as debugStreamModule from '../utils/debugStream'; +import { LobeMoonshotAI } from './index'; + +const provider = 'moonshot'; +const defaultBaseURL = 'https://api.moonshot.cn/v1'; +const bizErrorType = 'MoonshotBizError'; +const invalidErrorType = 'InvalidMoonshotAPIKey'; + +// Mock the console.error to avoid polluting test output +vi.spyOn(console, 'error').mockImplementation(() => {}); + +let instance: LobeMoonshotAI; + +beforeEach(() => { + instance = new LobeMoonshotAI({ apiKey: 'test' }); + + // Use vi.spyOn to mock the chat.completions.create method + vi.spyOn(instance['client'].chat.completions, 'create').mockResolvedValue( + new ReadableStream() as any, + ); +}); + +afterEach(() => { + vi.clearAllMocks(); +}); + +describe('LobeMoonshotAI', () => { + describe('init', () => { + it('should correctly initialize with an API key', async () => { + const instance = new LobeMoonshotAI({ apiKey: 'test_api_key' }); + expect(instance).toBeInstanceOf(LobeMoonshotAI); + expect(instance.baseURL).toEqual(defaultBaseURL); + }); + }); + + describe('chat', () => { + it('should return a StreamingTextResponse on successful API call', async () => { + // Arrange + const mockStream = new ReadableStream(); + const mockResponse = Promise.resolve(mockStream); + + (instance['client'].chat.completions.create as Mock).mockResolvedValue(mockResponse); + + // Act + const result = await instance.chat({ + messages: [{ content: 'Hello', role: 'user' }], + model: 'text-davinci-003', + temperature: 0, + }); + + // Assert + expect(result).toBeInstanceOf(Response); + }); + + describe('Error', () => { + it('should return MoonshotBizError with an openai error response when OpenAI.APIError is thrown', async () => { + // Arrange + const apiError = new OpenAI.APIError( + 400, + { + status: 400, + error: { + message: 'Bad Request', + }, + }, + 'Error message', + {}, + ); + + vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(apiError); + + // Act + try { + await instance.chat({ + messages: [{ content: 'Hello', role: 'user' }], + model: 'text-davinci-003', + temperature: 0, + }); + } catch (e) { + expect(e).toEqual({ + endpoint: defaultBaseURL, + error: { + error: { message: 'Bad Request' }, + status: 400, + }, + errorType: bizErrorType, + provider, + }); + } + }); + + it('should throw AgentRuntimeError with InvalidMoonshotAPIKey if no apiKey is provided', async () => { + try { + new LobeMoonshotAI({}); + } catch (e) { + expect(e).toEqual({ errorType: invalidErrorType }); + } + }); + + it('should return MoonshotBizError with the cause when OpenAI.APIError is thrown with cause', async () => { + // Arrange + const errorInfo = { + stack: 'abc', + cause: { + message: 'api is undefined', + }, + }; + const apiError = new OpenAI.APIError(400, errorInfo, 'module error', {}); + + vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(apiError); + + // Act + try { + await instance.chat({ + messages: [{ content: 'Hello', role: 'user' }], + model: 'text-davinci-003', + temperature: 0, + }); + } catch (e) { + expect(e).toEqual({ + endpoint: defaultBaseURL, + error: { + cause: { message: 'api is undefined' }, + stack: 'abc', + }, + errorType: bizErrorType, + provider, + }); + } + }); + + it('should return MoonshotBizError with a cause response with desensitized URL', async () => { + // Arrange + const errorInfo = { + stack: 'abc', + cause: { message: 'api is undefined' }, + }; + const apiError = new OpenAI.APIError(400, errorInfo, 'module error', {}); + + instance = new LobeMoonshotAI({ + apiKey: 'test', + + baseURL: 'https://api.abc.com/v1', + }); + + vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(apiError); + + // Act + try { + await instance.chat({ + messages: [{ content: 'Hello', role: 'user' }], + model: 'gpt-3.5-turbo', + temperature: 0, + }); + } catch (e) { + expect(e).toEqual({ + endpoint: 'https://api.***.com/v1', + error: { + cause: { message: 'api is undefined' }, + stack: 'abc', + }, + errorType: bizErrorType, + provider, + }); + } + }); + + it('should throw an InvalidMoonshotAPIKey error type on 401 status code', async () => { + // Mock the API call to simulate a 401 error + const error = new Error('Unauthorized') as any; + error.status = 401; + vi.mocked(instance['client'].chat.completions.create).mockRejectedValue(error); + + try { + await instance.chat({ + messages: [{ content: 'Hello', role: 'user' }], + model: 'gpt-3.5-turbo', + temperature: 0, + }); + } catch (e) { + // Expect the chat method to throw an error with InvalidMoonshotAPIKey + expect(e).toEqual({ + endpoint: defaultBaseURL, + error: new Error('Unauthorized'), + errorType: invalidErrorType, + provider, + }); + } + }); + + it('should return AgentRuntimeError for non-OpenAI errors', async () => { + // Arrange + const genericError = new Error('Generic Error'); + + vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(genericError); + + // Act + try { + await instance.chat({ + messages: [{ content: 'Hello', role: 'user' }], + model: 'text-davinci-003', + temperature: 0, + }); + } catch (e) { + expect(e).toEqual({ + endpoint: defaultBaseURL, + errorType: 'AgentRuntimeError', + provider, + error: { + name: genericError.name, + cause: genericError.cause, + message: genericError.message, + stack: genericError.stack, + }, + }); + } + }); + }); + + describe('LobeMoonshotAI chat with callback and headers', () => { + it('should handle callback and headers correctly', async () => { + // Mock chat.completions.create to return a readable stream + const mockCreateMethod = vi + .spyOn(instance['client'].chat.completions, 'create') + .mockResolvedValue( + new ReadableStream({ + start(controller) { + controller.enqueue({ + id: 'chatcmpl-8xDx5AETP8mESQN7UB30GxTN2H1SO', + object: 'chat.completion.chunk', + created: 1709125675, + model: 'gpt-3.5-turbo-0125', + system_fingerprint: 'fp_86156a94a0', + choices: [ + { index: 0, delta: { content: 'hello' }, logprobs: null, finish_reason: null }, + ], + }); + controller.close(); + }, + }) as any, + ); + + // Prepare the callback and headers + const mockCallback: ChatStreamCallbacks = { + onStart: vi.fn(), + onToken: vi.fn(), + }; + const mockHeaders = { 'Custom-Header': 'TestValue' }; + + // Run the test + const result = await instance.chat( + { + messages: [{ content: 'Hello', role: 'user' }], + model: 'text-davinci-003', + temperature: 0, + }, + { callback: mockCallback, headers: mockHeaders }, + ); + + // Verify the callback was invoked + await result.text(); // make sure the stream is consumed + expect(mockCallback.onStart).toHaveBeenCalled(); + expect(mockCallback.onToken).toHaveBeenCalledWith('hello'); + + // Verify the headers were passed through correctly + expect(result.headers.get('Custom-Header')).toEqual('TestValue'); + + // Clean up + mockCreateMethod.mockRestore(); + }); + }); + + describe('DEBUG', () => { + it('should call debugStream and return StreamingTextResponse when DEBUG_MOONSHOT_CHAT_COMPLETION is 1', async () => { + // Arrange + const mockProdStream = new ReadableStream() as any; // mocked prod stream + const mockDebugStream = new ReadableStream({ + start(controller) { + controller.enqueue('Debug stream content'); + controller.close(); + }, + }) as any; + mockDebugStream.toReadableStream = () => mockDebugStream; // attach a toReadableStream method + + // Mock the chat.completions.create return value, including a mocked tee method + (instance['client'].chat.completions.create as Mock).mockResolvedValue({ + tee: () => [mockProdStream, { toReadableStream: () => mockDebugStream }], + }); + + // Save the original environment variable value + const originalDebugValue = process.env.DEBUG_MOONSHOT_CHAT_COMPLETION; + + // Stub the environment variable + process.env.DEBUG_MOONSHOT_CHAT_COMPLETION = '1'; + vi.spyOn(debugStreamModule, 'debugStream').mockImplementation(() => Promise.resolve()); + + // Run the test + // Call the function under test and make sure it invokes debugStream when the condition is met + // Example invocation; adjust to your actual setup if needed + await instance.chat({ + messages: [{ content: 'Hello', role: 'user' }], + model: 'text-davinci-003', + temperature: 0, + }); + + // Verify debugStream was called + expect(debugStreamModule.debugStream).toHaveBeenCalled(); + + // Restore the original environment variable value + process.env.DEBUG_MOONSHOT_CHAT_COMPLETION = originalDebugValue; + }); + }); + }); +}); diff --git a/src/libs/agent-runtime/moonshot/index.ts b/src/libs/agent-runtime/moonshot/index.ts index 07e36abe87e6..d7067af19860 100644 --- a/src/libs/agent-runtime/moonshot/index.ts +++ b/src/libs/agent-runtime/moonshot/index.ts @@ -1,44 +1,42 @@ import { OpenAIStream, StreamingTextResponse } from 'ai'; -import OpenAI from 'openai'; +import OpenAI, { ClientOptions } from 'openai'; import { LobeRuntimeAI } from '../BaseAI'; import { AgentRuntimeErrorType } from '../error'; -import { ChatStreamPayload, ModelProvider } from '../types'; +import { ChatCompetitionOptions, ChatStreamPayload, ModelProvider } from '../types'; import { AgentRuntimeError } from '../utils/createError'; import { debugStream } from
'../utils/debugStream'; import { desensitizeUrl } from '../utils/desensitizeUrl'; -import { DEBUG_CHAT_COMPLETION } from '../utils/env'; import { handleOpenAIError } from '../utils/handleOpenAIError'; const DEFAULT_BASE_URL = 'https://api.moonshot.cn/v1'; export class LobeMoonshotAI implements LobeRuntimeAI { - private _llm: OpenAI; + private client: OpenAI; baseURL: string; - constructor(apiKey?: string, baseURL: string = DEFAULT_BASE_URL) { + constructor({ apiKey, baseURL = DEFAULT_BASE_URL, ...res }: ClientOptions) { if (!apiKey) throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidMoonshotAPIKey); - this._llm = new OpenAI({ apiKey, baseURL }); - this.baseURL = this._llm.baseURL; + this.client = new OpenAI({ apiKey, baseURL, ...res }); + this.baseURL = this.client.baseURL; } - async chat(payload: ChatStreamPayload) { + async chat(payload: ChatStreamPayload, options?: ChatCompetitionOptions) { try { - const response = await this._llm.chat.completions.create( + const response = await this.client.chat.completions.create( payload as unknown as OpenAI.ChatCompletionCreateParamsStreaming, ); + const [prod, debug] = response.tee(); - const stream = OpenAIStream(response); - - const [debug, returnStream] = stream.tee(); - - if (DEBUG_CHAT_COMPLETION) { - debugStream(debug).catch(console.error); + if (process.env.DEBUG_MOONSHOT_CHAT_COMPLETION === '1') { + debugStream(debug.toReadableStream()).catch(console.error); } - return new StreamingTextResponse(returnStream); + return new StreamingTextResponse(OpenAIStream(prod, options?.callback), { + headers: options?.headers, + }); } catch (error) { let desensitizedEndpoint = this.baseURL; diff --git a/src/libs/agent-runtime/ollama/index.test.ts b/src/libs/agent-runtime/ollama/index.test.ts new file mode 100644 index 000000000000..be450c0b9e9d --- /dev/null +++ b/src/libs/agent-runtime/ollama/index.test.ts @@ -0,0 +1,320 @@ +// @vitest-environment node +import OpenAI from 'openai'; +import { Mock, afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; + +import { ChatStreamCallbacks } from '@/libs/agent-runtime'; + +import * as debugStreamModule from '../utils/debugStream'; +import { LobeOllamaAI } from './index'; + +const provider = 'ollama'; +const defaultBaseURL = 'http://127.0.0.1:11434/v1'; +const bizErrorType = 'OllamaBizError'; +const invalidErrorType = 'InvalidOllamaArgs'; + +// Mock the console.error to avoid polluting test output +vi.spyOn(console, 'error').mockImplementation(() => {}); + +let instance: LobeOllamaAI; + +beforeEach(() => { + instance = new LobeOllamaAI({ apiKey: 'test' }); + + // Use vi.spyOn to mock the chat.completions.create method + vi.spyOn(instance['client'].chat.completions, 'create').mockResolvedValue( + new ReadableStream() as any, + ); +}); + +afterEach(() => { + vi.clearAllMocks(); +}); + +describe('LobeOllamaAI', () => { + describe('init', () => { + it('should correctly initialize with an API key', async () => { + const instance = new LobeOllamaAI({ apiKey: 'test_api_key' }); + expect(instance).toBeInstanceOf(LobeOllamaAI); + expect(instance.baseURL).toEqual(defaultBaseURL); + }); + }); + + describe('chat', () => { + it('should return a StreamingTextResponse on successful API call', async () => { + // Arrange + const mockStream = new ReadableStream(); + const mockResponse = Promise.resolve(mockStream); + + (instance['client'].chat.completions.create as Mock).mockResolvedValue(mockResponse); + + // Act + const result = await instance.chat({ + messages: [{ content: 'Hello', role: 'user' }], + model:
'text-davinci-003', + temperature: 0, + }); + + // Assert + expect(result).toBeInstanceOf(Response); + }); + + describe('Error', () => { + it('should return OllamaBizError with an openai error response when OpenAI.APIError is thrown', async () => { + // Arrange + const apiError = new OpenAI.APIError( + 400, + { + status: 400, + error: { + message: 'Bad Request', + }, + }, + 'Error message', + {}, + ); + + vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(apiError); + + // Act + try { + await instance.chat({ + messages: [{ content: 'Hello', role: 'user' }], + model: 'text-davinci-003', + temperature: 0, + }); + } catch (e) { + expect(e).toEqual({ + endpoint: defaultBaseURL, + error: { + error: { message: 'Bad Request' }, + status: 400, + }, + errorType: bizErrorType, + provider, + }); + } + }); + + it('should throw AgentRuntimeError with InvalidOllamaArgs if no apiKey is provided', async () => { + try { + new LobeOllamaAI({}); + } catch (e) { + expect(e).toEqual({ errorType: invalidErrorType }); + } + }); + + it('should return OllamaBizError with the cause when OpenAI.APIError is thrown with cause', async () => { + // Arrange + const errorInfo = { + stack: 'abc', + cause: { + message: 'api is undefined', + }, + }; + const apiError = new OpenAI.APIError(400, errorInfo, 'module error', {}); + + vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(apiError); + + // Act + try { + await instance.chat({ + messages: [{ content: 'Hello', role: 'user' }], + model: 'text-davinci-003', + temperature: 0, + }); + } catch (e) { + expect(e).toEqual({ + endpoint: defaultBaseURL, + error: { + cause: { message: 'api is undefined' }, + stack: 'abc', + }, + errorType: bizErrorType, + provider, + }); + } + }); + + it('should return OllamaBizError with a cause response with desensitized URL', async () => { + // Arrange + const errorInfo = { + stack: 'abc', + cause: { message: 'api is undefined' }, + }; + const apiError = new OpenAI.APIError(400, errorInfo, 'module error', {}); + + instance = new LobeOllamaAI({ + apiKey: 'test', + + baseURL: 'https://api.abc.com/v1', + }); + + vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(apiError); + + // Act + try { + await instance.chat({ + messages: [{ content: 'Hello', role: 'user' }], + model: 'gpt-3.5-turbo', + temperature: 0, + }); + } catch (e) { + expect(e).toEqual({ + endpoint: 'https://api.***.com/v1', + error: { + cause: { message: 'api is undefined' }, + stack: 'abc', + }, + errorType: bizErrorType, + provider, + }); + } + }); + + it('should throw an InvalidOllamaArgs error type on 401 status code', async () => { + // Mock the API call to simulate a 401 error + const error = new Error('Unauthorized') as any; + error.status = 401; + vi.mocked(instance['client'].chat.completions.create).mockRejectedValue(error); + + try { + await instance.chat({ + messages: [{ content: 'Hello', role: 'user' }], + model: 'gpt-3.5-turbo', + temperature: 0, + }); + } catch (e) { + // Expect the chat method to throw an error with InvalidOllamaArgs + expect(e).toEqual({ + endpoint: defaultBaseURL, + error: new Error('Unauthorized'), + errorType: invalidErrorType, + provider, + }); + } + }); + + it('should return AgentRuntimeError for non-OpenAI errors', async () => { + // Arrange + const genericError = new Error('Generic Error'); + + vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(genericError); + + // Act + try { + await instance.chat({ + messages: [{ content: 'Hello', role: 'user' }], +
model: 'text-davinci-003', + temperature: 0, + }); + } catch (e) { + expect(e).toEqual({ + endpoint: defaultBaseURL, + errorType: 'AgentRuntimeError', + provider, + error: { + name: genericError.name, + cause: genericError.cause, + message: genericError.message, + stack: genericError.stack, + }, + }); + } + }); + }); + + describe('LobeOllamaAI chat with callback and headers', () => { + it('should handle callback and headers correctly', async () => { + // Mock chat.completions.create to return a readable stream + const mockCreateMethod = vi + .spyOn(instance['client'].chat.completions, 'create') + .mockResolvedValue( + new ReadableStream({ + start(controller) { + controller.enqueue({ + id: 'chatcmpl-8xDx5AETP8mESQN7UB30GxTN2H1SO', + object: 'chat.completion.chunk', + created: 1709125675, + model: 'gpt-3.5-turbo-0125', + system_fingerprint: 'fp_86156a94a0', + choices: [ + { index: 0, delta: { content: 'hello' }, logprobs: null, finish_reason: null }, + ], + }); + controller.close(); + }, + }) as any, + ); + + // Prepare the callback and headers + const mockCallback: ChatStreamCallbacks = { + onStart: vi.fn(), + onToken: vi.fn(), + }; + const mockHeaders = { 'Custom-Header': 'TestValue' }; + + // Run the test + const result = await instance.chat( + { + messages: [{ content: 'Hello', role: 'user' }], + model: 'text-davinci-003', + temperature: 0, + }, + { callback: mockCallback, headers: mockHeaders }, + ); + + // Verify the callback was invoked + await result.text(); // make sure the stream is consumed + expect(mockCallback.onStart).toHaveBeenCalled(); + expect(mockCallback.onToken).toHaveBeenCalledWith('hello'); + + // Verify the headers were passed through correctly + expect(result.headers.get('Custom-Header')).toEqual('TestValue'); + + // Clean up + mockCreateMethod.mockRestore(); + }); + }); + + describe('DEBUG', () => { + it('should call debugStream and return StreamingTextResponse when DEBUG_OLLAMA_CHAT_COMPLETION is 1', async () => { + // Arrange + const mockProdStream = new ReadableStream() as any; // mocked prod stream + const mockDebugStream = new ReadableStream({ + start(controller) { + controller.enqueue('Debug stream content'); + controller.close(); + }, + }) as any; + mockDebugStream.toReadableStream = () => mockDebugStream; // attach a toReadableStream method + + // Mock the chat.completions.create return value, including a mocked tee method + (instance['client'].chat.completions.create as Mock).mockResolvedValue({ + tee: () => [mockProdStream, { toReadableStream: () => mockDebugStream }], + }); + + // Save the original environment variable value + const originalDebugValue = process.env.DEBUG_OLLAMA_CHAT_COMPLETION; + + // Stub the environment variable + process.env.DEBUG_OLLAMA_CHAT_COMPLETION = '1'; + vi.spyOn(debugStreamModule, 'debugStream').mockImplementation(() => Promise.resolve()); + + // Run the test + // Call the function under test and make sure it invokes debugStream when the condition is met + // Example invocation; adjust to your actual setup if needed + await instance.chat({ + messages: [{ content: 'Hello', role: 'user' }], + model: 'text-davinci-003', + temperature: 0, + }); + + // Verify debugStream was called + expect(debugStreamModule.debugStream).toHaveBeenCalled(); + + // Restore the original environment variable value + process.env.DEBUG_OLLAMA_CHAT_COMPLETION = originalDebugValue; + }); + }); + }); +}); diff --git a/src/libs/agent-runtime/ollama/index.ts b/src/libs/agent-runtime/ollama/index.ts index 531e4bfda5ea..c0ebd70ba05e 100644 --- a/src/libs/agent-runtime/ollama/index.ts +++ b/src/libs/agent-runtime/ollama/index.ts @@ -1,44 +1,42 @@ import { OpenAIStream, StreamingTextResponse } from 'ai'; -import OpenAI from 'openai'; +import OpenAI, { ClientOptions } from 'openai'; import { LobeRuntimeAI } from '../BaseAI'; import { AgentRuntimeErrorType } from '../error'; -import { ChatStreamPayload, ModelProvider } from '../types'; +import { ChatCompetitionOptions, ChatStreamPayload, ModelProvider } from '../types'; import { AgentRuntimeError } from '../utils/createError'; import { debugStream } from '../utils/debugStream'; import { desensitizeUrl } from '../utils/desensitizeUrl'; -import { DEBUG_CHAT_COMPLETION } from '../utils/env'; import { handleOpenAIError } from '../utils/handleOpenAIError'; const DEFAULT_BASE_URL = 'http://127.0.0.1:11434/v1'; export class LobeOllamaAI implements LobeRuntimeAI { - private _llm: OpenAI; + private client: OpenAI; baseURL: string; - constructor(baseURL: string = DEFAULT_BASE_URL) { + constructor({ apiKey = 'ollama', baseURL = DEFAULT_BASE_URL, ...res }: ClientOptions) { if (!baseURL) throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidOllamaArgs); - this._llm = new OpenAI({ apiKey: 'ollama', baseURL }); + this.client = new OpenAI({ apiKey, baseURL, ...res }); + this.baseURL = baseURL; } - async chat(payload: ChatStreamPayload) { + async chat(payload: ChatStreamPayload, options?: ChatCompetitionOptions) { try { - const response = await this._llm.chat.completions.create( + const response = await this.client.chat.completions.create( payload as unknown as OpenAI.ChatCompletionCreateParamsStreaming, ); + const [prod, debug] = response.tee(); - const stream = OpenAIStream(response); - - const [debug, returnStream] = stream.tee(); - - if (DEBUG_CHAT_COMPLETION) { - debugStream(debug).catch(console.error); + if (process.env.DEBUG_OLLAMA_CHAT_COMPLETION === '1') { + debugStream(debug.toReadableStream()).catch(console.error); } - return new StreamingTextResponse(returnStream); + return new StreamingTextResponse(OpenAIStream(prod, options?.callback), { + headers: options?.headers, + }); } catch (error) { let desensitizedEndpoint = this.baseURL; diff --git a/src/libs/agent-runtime/openai/index.test.ts b/src/libs/agent-runtime/openai/index.test.ts index f41d14417f42..5eb9138bfce5 100644 --- a/src/libs/agent-runtime/openai/index.test.ts +++ b/src/libs/agent-runtime/openai/index.test.ts @@ -1,19 +1,24 @@ +// @vitest-environment node import OpenAI from 'openai'; import { Mock, afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; +// Import the module so its functions can be spied on +import { ChatStreamCallbacks } from '@/libs/agent-runtime'; + +import * as debugStreamModule from '../utils/debugStream'; import { LobeOpenAI } from './index'; // Mock the console.error to avoid polluting test output vi.spyOn(console, 'error').mockImplementation(() => {}); -describe('LobeOpenAI chat', () => { - let openaiInstance: LobeOpenAI; +describe('LobeOpenAI', () => { + let instance: LobeOpenAI; beforeEach(() => { - openaiInstance = new LobeOpenAI({ apiKey: 'test', dangerouslyAllowBrowser: true }); + instance = new LobeOpenAI({ apiKey: 'test' }); // Use vi.spyOn to mock the chat.completions.create method - vi.spyOn(openaiInstance['client'].chat.completions, 'create').mockResolvedValue( + vi.spyOn(instance['client'].chat.completions, 'create').mockResolvedValue( new ReadableStream() as any, ); }); @@ -22,16 +27,64 @@ describe('LobeOpenAI chat', () => { vi.clearAllMocks(); }); + describe('init', () => { + it('should correctly initialize with Azure options', () => { + const baseURL = 'https://abc.com'; + const modelName = 'abc'; + const client = new LobeOpenAI({ + apiKey: 'test', + useAzure: true, + baseURL, + azureOptions: { + apiVersion: '2023-08-01-preview', + model: 'abc', + }, + }); + + expect(client.baseURL).toEqual(baseURL + '/openai/deployments/' + modelName); + }); + + describe('initWithAzureOpenAI', () => { +
it('should correctly initialize with Azure options', () => { + const baseURL = 'https://abc.com'; + const modelName = 'abc'; + const client = LobeOpenAI.initWithAzureOpenAI({ + apiKey: 'test', + useAzure: true, + baseURL, + azureOptions: { + apiVersion: '2023-08-01-preview', + model: 'abc', + }, + }); + + expect(client.baseURL).toEqual(baseURL + '/openai/deployments/' + modelName); + }); + + it('should use default Azure options when not explicitly provided', () => { + const baseURL = 'https://abc.com'; + + const client = LobeOpenAI.initWithAzureOpenAI({ + apiKey: 'test', + useAzure: true, + baseURL, + }); + + expect(client.baseURL).toEqual(baseURL + '/openai/deployments/'); + }); + }); + }); + describe('chat', () => { it('should return a StreamingTextResponse on successful API call', async () => { // Arrange const mockStream = new ReadableStream(); const mockResponse = Promise.resolve(mockStream); - (openaiInstance['client'].chat.completions.create as Mock).mockResolvedValue(mockResponse); + (instance['client'].chat.completions.create as Mock).mockResolvedValue(mockResponse); // Act - const result = await openaiInstance.chat({ + const result = await instance.chat({ messages: [{ content: 'Hello', role: 'user' }], model: 'text-davinci-003', temperature: 0, @@ -41,136 +94,240 @@ describe('LobeOpenAI chat', () => { expect(result).toBeInstanceOf(Response); }); - it('should return an openai error response when OpenAI.APIError is thrown', async () => { - // Arrange - const apiError = new OpenAI.APIError( - 400, - { - status: 400, - error: { - message: 'Bad Request', + describe('Error', () => { + it('should return OpenAIBizError with an openai error response when OpenAI.APIError is thrown', async () => { + // Arrange + const apiError = new OpenAI.APIError( + 400, + { + status: 400, + error: { + message: 'Bad Request', + }, }, - }, - 'Error message', - {}, - ); + 'Error message', + {}, + ); - vi.spyOn(openaiInstance['client'].chat.completions, 'create').mockRejectedValue(apiError); + vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(apiError); - // Act - try { - await openaiInstance.chat({ - messages: [{ content: 'Hello', role: 'user' }], - model: 'text-davinci-003', - temperature: 0, - }); - } catch (e) { - expect(e).toEqual({ - endpoint: 'https://api.openai.com/v1', - error: { - error: { message: 'Bad Request' }, - status: 400, + // Act + try { + await instance.chat({ + messages: [{ content: 'Hello', role: 'user' }], + model: 'text-davinci-003', + temperature: 0, + }); + } catch (e) { + expect(e).toEqual({ + endpoint: 'https://api.openai.com/v1', + error: { + error: { message: 'Bad Request' }, + status: 400, + }, + errorType: 'OpenAIBizError', + provider: 'openai', + }); + } + }); + + it('should throw AgentRuntimeError with NoOpenAIAPIKey if no apiKey is provided', async () => { + try { + new LobeOpenAI({}); + } catch (e) { + expect(e).toEqual({ errorType: 'NoOpenAIAPIKey' }); + } + }); + + it('should return OpenAIBizError with the cause when OpenAI.APIError is thrown with cause', async () => { + // Arrange + const errorInfo = { + stack: 'abc', + cause: { + message: 'api is undefined', }, - errorType: 'OpenAIBizError', - provider: 'openai', + }; + const apiError = new OpenAI.APIError(400, errorInfo, 'module error', {}); + + vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(apiError); + + // Act + try { + await instance.chat({ + messages: [{ content: 'Hello', role: 'user' }], + model: 'text-davinci-003', + temperature: 0, + }); + } catch (e) { + 
expect(e).toEqual({ + endpoint: 'https://api.openai.com/v1', + error: { + cause: { message: 'api is undefined' }, + stack: 'abc', + }, + errorType: 'OpenAIBizError', + provider: 'openai', + }); + } + }); + + it('should return OpenAIBizError with a cause response with desensitized URL', async () => { + // Arrange + const errorInfo = { + stack: 'abc', + cause: { message: 'api is undefined' }, + }; + const apiError = new OpenAI.APIError(400, errorInfo, 'module error', {}); + + instance = new LobeOpenAI({ + apiKey: 'test', + + baseURL: 'https://api.abc.com/v1', }); - } + + vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(apiError); + + // Act + try { + await instance.chat({ + messages: [{ content: 'Hello', role: 'user' }], + model: 'gpt-3.5-turbo', + temperature: 0, + }); + } catch (e) { + expect(e).toEqual({ + endpoint: 'https://api.***.com/v1', + error: { + cause: { message: 'api is undefined' }, + stack: 'abc', + }, + errorType: 'OpenAIBizError', + provider: 'openai', + }); + } + }); + + it('should return AgentRuntimeError for non-OpenAI errors', async () => { + // Arrange + const genericError = new Error('Generic Error'); + + vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(genericError); + + // Act + try { + await instance.chat({ + messages: [{ content: 'Hello', role: 'user' }], + model: 'text-davinci-003', + temperature: 0, + }); + } catch (e) { + expect(e).toEqual({ + endpoint: 'https://api.openai.com/v1', + errorType: 'AgentRuntimeError', + provider: 'openai', + error: { + name: genericError.name, + cause: genericError.cause, + message: genericError.message, + stack: genericError.stack, + }, + }); + } + }); }); - it('should return an cause response when OpenAI.APIError is thrown with cause', async () => { - // Arrange - const errorInfo = { - stack: 'abc', - cause: { - message: 'api is undefined', - }, - }; - const apiError = new OpenAI.APIError(400, errorInfo, 'module error', {}); + describe('LobeOpenAI chat with callback and headers', () => { + it('should handle callback and headers correctly', async () => { + // Mock chat.completions.create to return a readable stream + const mockCreateMethod = vi + .spyOn(instance['client'].chat.completions, 'create') + .mockResolvedValue( + new ReadableStream({ + start(controller) { + controller.enqueue({ + id: 'chatcmpl-8xDx5AETP8mESQN7UB30GxTN2H1SO', + object: 'chat.completion.chunk', + created: 1709125675, + model: 'gpt-3.5-turbo-0125', + system_fingerprint: 'fp_86156a94a0', + choices: [ + { index: 0, delta: { content: 'hello' }, logprobs: null, finish_reason: null }, + ], + }); + controller.close(); + }, + }) as any, + ); - vi.spyOn(openaiInstance['client'].chat.completions, 'create').mockRejectedValue(apiError); + // Prepare the callback and headers + const mockCallback: ChatStreamCallbacks = { + onStart: vi.fn(), + onToken: vi.fn(), + }; + const mockHeaders = { 'Custom-Header': 'TestValue' }; - // Act - try { - await openaiInstance.chat({ - messages: [{ content: 'Hello', role: 'user' }], - model: 'text-davinci-003', - temperature: 0, - }); - } catch (e) { - expect(e).toEqual({ - endpoint: 'https://api.openai.com/v1', - error: { - cause: { message: 'api is undefined' }, - stack: 'abc', + // Run the test + const result = await instance.chat( + { + messages: [{ content: 'Hello', role: 'user' }], + model: 'text-davinci-003', + temperature: 0, }, - errorType: 'OpenAIBizError', - provider: 'openai', - }); - } - }); + { callback: mockCallback, headers: mockHeaders }, + ); - it('should return an cause response with desensitize Url', async
() => { - // Arrange - const errorInfo = { - stack: 'abc', - cause: { message: 'api is undefined' }, - }; - const apiError = new OpenAI.APIError(400, errorInfo, 'module error', {}); + // Verify the callback was invoked + await result.text(); // make sure the stream is consumed + expect(mockCallback.onStart).toHaveBeenCalled(); + expect(mockCallback.onToken).toHaveBeenCalledWith('hello'); - openaiInstance = new LobeOpenAI({ - apiKey: 'test', - dangerouslyAllowBrowser: true, - baseURL: 'https://api.abc.com/v1', - }); + // Verify the headers were passed through correctly + expect(result.headers.get('Custom-Header')).toEqual('TestValue'); - vi.spyOn(openaiInstance['client'].chat.completions, 'create').mockRejectedValue(apiError); + // Clean up + mockCreateMethod.mockRestore(); + }); + }); - // Act - try { - await openaiInstance.chat({ - messages: [{ content: 'Hello', role: 'user' }], - model: 'gpt-3.5-turbo', - temperature: 0, - }); - } catch (e) { - expect(e).toEqual({ - endpoint: 'https://api.***.com/v1', - error: { - cause: { message: 'api is undefined' }, - stack: 'abc', + describe('DEBUG', () => { + it('should call debugStream and return StreamingTextResponse when DEBUG_OPENAI_CHAT_COMPLETION is 1', async () => { + // Arrange + const mockProdStream = new ReadableStream() as any; // mocked prod stream + const mockDebugStream = new ReadableStream({ + start(controller) { + controller.enqueue('Debug stream content'); + controller.close(); }, - errorType: 'OpenAIBizError', - provider: 'openai', + }) as any; + mockDebugStream.toReadableStream = () => mockDebugStream; // attach a toReadableStream method + + // Mock the chat.completions.create return value, including a mocked tee method + (instance['client'].chat.completions.create as Mock).mockResolvedValue({ + tee: () => [mockProdStream, { toReadableStream: () => mockDebugStream }], }); - } - }); - it('should return a 500 error response for non-OpenAI errors', async () => { - // Arrange - const genericError = new Error('Generic Error'); + // Save the original environment variable value + const originalDebugValue = process.env.DEBUG_OPENAI_CHAT_COMPLETION; - vi.spyOn(openaiInstance['client'].chat.completions, 'create').mockRejectedValue(genericError); + // Stub the environment variable + process.env.DEBUG_OPENAI_CHAT_COMPLETION = '1'; + vi.spyOn(debugStreamModule, 'debugStream').mockImplementation(() => Promise.resolve()); - // Act - try { - await openaiInstance.chat({ + // Run the test + // Call the function under test and make sure it invokes debugStream when the condition is met + // Example invocation; adjust to your actual setup if needed + await instance.chat({ messages: [{ content: 'Hello', role: 'user' }], model: 'text-davinci-003', temperature: 0, }); - } catch (e) { - expect(e).toEqual({ - endpoint: 'https://api.openai.com/v1', - errorType: 'AgentRuntimeError', - provider: 'openai', - error: { - name: genericError.name, - cause: genericError.cause, - message: genericError.message, - stack: genericError.stack, - }, - }); - } + + // Verify debugStream was called + expect(debugStreamModule.debugStream).toHaveBeenCalled(); + + // Restore the original environment variable value + process.env.DEBUG_OPENAI_CHAT_COMPLETION = originalDebugValue; + }); }); }); }); diff --git a/src/libs/agent-runtime/openai/index.ts b/src/libs/agent-runtime/openai/index.ts index 0e6384782dba..f3d6137b6fa9 100644 --- a/src/libs/agent-runtime/openai/index.ts +++ b/src/libs/agent-runtime/openai/index.ts @@ -6,26 +6,26 @@ import { ChatStreamPayload } from '@/types/openai/chat'; import { LobeRuntimeAI } from '../BaseAI'; import { AgentRuntimeErrorType } from '../error'; -import { ModelProvider } from '../types'; +import { ChatCompetitionOptions, ModelProvider } from '../types'; import { AgentRuntimeError } from '../utils/createError'; import { debugStream } from '../utils/debugStream'; import { desensitizeUrl } from '../utils/desensitizeUrl'; -import { DEBUG_CHAT_COMPLETION } from '../utils/env'; import { handleOpenAIError } from '../utils/handleOpenAIError'; const DEFAULT_BASE_URL = 'https://api.openai.com/v1'; -interface AzureOpenAIOptions extends ClientOptions { +interface LobeOpenAIOptions extends ClientOptions { azureOptions?: { apiVersion?: string; model?: string; }; useAzure?: boolean; } + export class LobeOpenAI implements LobeRuntimeAI { private client: OpenAI; - constructor(options: AzureOpenAIOptions) { + constructor(options: LobeOpenAIOptions) { if (!options.apiKey) throw AgentRuntimeError.createError(AgentRuntimeErrorType.NoOpenAIAPIKey); if (options.useAzure) { @@ -39,7 +39,7 @@ export class LobeOpenAI implements LobeRuntimeAI { baseURL: string; - async chat(payload: ChatStreamPayload) { + async chat(payload: ChatStreamPayload, options?: ChatCompetitionOptions) { // ============ 1. preprocess messages ============ // const { messages, ...params } = payload; @@ -55,15 +55,15 @@ export class LobeOpenAI implements LobeRuntimeAI { { headers: { Accept: '*/*' } }, ); - const stream = OpenAIStream(response); - - const [debug, prod] = stream.tee(); + const [prod, debug] = response.tee(); - if (DEBUG_CHAT_COMPLETION) { - debugStream(debug).catch(console.error); + if (process.env.DEBUG_OPENAI_CHAT_COMPLETION === '1') { + debugStream(debug.toReadableStream()).catch(console.error); } - return new StreamingTextResponse(prod); + return new StreamingTextResponse(OpenAIStream(prod, options?.callback), { + headers: options?.headers, + }); } catch (error) { const { errorResult, RuntimeError } = handleOpenAIError(error); @@ -85,7 +85,7 @@ export class LobeOpenAI implements LobeRuntimeAI { } } - static initWithAzureOpenAI(options: AzureOpenAIOptions) { + static initWithAzureOpenAI(options: LobeOpenAIOptions) { const endpoint = options.baseURL!; const model = options.azureOptions?.model || ''; @@ -96,6 +96,7 @@ export class LobeOpenAI implements LobeRuntimeAI { const apiKey = options.apiKey!; const config: ClientOptions = { + ...options, apiKey, baseURL, defaultHeaders: { 'api-key': apiKey }, diff --git a/src/libs/agent-runtime/perplexity/index.test.ts b/src/libs/agent-runtime/perplexity/index.test.ts new file mode 100644 index 000000000000..21ed19654bd3 --- /dev/null +++ b/src/libs/agent-runtime/perplexity/index.test.ts @@ -0,0 +1,320 @@ +// @vitest-environment node +import OpenAI from 'openai'; +import { Mock, afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; + +import { ChatStreamCallbacks } from '@/libs/agent-runtime'; + +import * as debugStreamModule from '../utils/debugStream'; +import { LobePerplexityAI } from './index'; + +const provider = 'perplexity'; +const defaultBaseURL = 'https://api.perplexity.ai'; +const bizErrorType = 'PerplexityBizError'; +const invalidErrorType = 'InvalidPerplexityAPIKey'; + +// Mock the console.error to avoid polluting test output +vi.spyOn(console, 'error').mockImplementation(() => {}); + +let instance: LobePerplexityAI; + +beforeEach(() => { + instance = new LobePerplexityAI({ apiKey: 'test' }); + + // Use vi.spyOn to mock the chat.completions.create method + vi.spyOn(instance['client'].chat.completions, 'create').mockResolvedValue( + new ReadableStream() as any, + ); +}); + +afterEach(() => { + vi.clearAllMocks(); +}); + +describe('LobePerplexityAI', () => { + describe('init', () => { + it('should correctly initialize with an API key', async () => { + const instance = new LobePerplexityAI({ apiKey: 'test_api_key' }); +
expect(instance).toBeInstanceOf(LobePerplexityAI); + expect(instance.baseURL).toEqual(defaultBaseURL); + }); + }); + + describe('chat', () => { + it('should return a StreamingTextResponse on successful API call', async () => { + // Arrange + const mockStream = new ReadableStream(); + const mockResponse = Promise.resolve(mockStream); + + (instance['client'].chat.completions.create as Mock).mockResolvedValue(mockResponse); + + // Act + const result = await instance.chat({ + messages: [{ content: 'Hello', role: 'user' }], + model: 'text-davinci-003', + temperature: 0, + }); + + // Assert + expect(result).toBeInstanceOf(Response); + }); + + describe('Error', () => { + it('should return PerplexityBizError with an openai error response when OpenAI.APIError is thrown', async () => { + // Arrange + const apiError = new OpenAI.APIError( + 400, + { + status: 400, + error: { + message: 'Bad Request', + }, + }, + 'Error message', + {}, + ); + + vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(apiError); + + // Act + try { + await instance.chat({ + messages: [{ content: 'Hello', role: 'user' }], + model: 'text-davinci-003', + temperature: 0, + }); + } catch (e) { + expect(e).toEqual({ + endpoint: defaultBaseURL, + error: { + error: { message: 'Bad Request' }, + status: 400, + }, + errorType: bizErrorType, + provider, + }); + } + }); + + it('should throw AgentRuntimeError with InvalidPerplexityAPIKey if no apiKey is provided', async () => { + try { + new LobePerplexityAI({}); + } catch (e) { + expect(e).toEqual({ errorType: invalidErrorType }); + } + }); + + it('should return PerplexityBizError with the cause when OpenAI.APIError is thrown with cause', async () => { + // Arrange + const errorInfo = { + stack: 'abc', + cause: { + message: 'api is undefined', + }, + }; + const apiError = new OpenAI.APIError(400, errorInfo, 'module error', {}); + + vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(apiError); + + // Act + try { + await instance.chat({ + messages: [{ content: 'Hello', role: 'user' }], + model: 'text-davinci-003', + temperature: 0, + }); + } catch (e) { + expect(e).toEqual({ + endpoint: defaultBaseURL, + error: { + cause: { message: 'api is undefined' }, + stack: 'abc', + }, + errorType: bizErrorType, + provider, + }); + } + }); + + it('should return PerplexityBizError with a cause response with desensitized URL', async () => { + // Arrange + const errorInfo = { + stack: 'abc', + cause: { message: 'api is undefined' }, + }; + const apiError = new OpenAI.APIError(400, errorInfo, 'module error', {}); + + instance = new LobePerplexityAI({ + apiKey: 'test', + + baseURL: 'https://api.abc.com/v1', + }); + + vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(apiError); + + // Act + try { + await instance.chat({ + messages: [{ content: 'Hello', role: 'user' }], + model: 'gpt-3.5-turbo', + temperature: 0, + }); + } catch (e) { + expect(e).toEqual({ + endpoint: 'https://api.***.com/v1', + error: { + cause: { message: 'api is undefined' }, + stack: 'abc', + }, + errorType: bizErrorType, + provider, + }); + } + }); + + it('should throw an InvalidPerplexityAPIKey error type on 401 status code', async () => { + // Mock the API call to simulate a 401 error + const error = new Error('Unauthorized') as any; + error.status = 401; + vi.mocked(instance['client'].chat.completions.create).mockRejectedValue(error); + + try { + await instance.chat({ + messages: [{ content: 'Hello', role: 'user' }], + model: 'gpt-3.5-turbo', + temperature: 0, + }); + } catch (e) { + // Expect the chat method to throw an error with InvalidPerplexityAPIKey + expect(e).toEqual({ + endpoint: defaultBaseURL, + error: new Error('Unauthorized'), + errorType: invalidErrorType, + provider, + }); + } + }); + + it('should return AgentRuntimeError for non-OpenAI errors', async () => { + // Arrange + const genericError = new Error('Generic Error'); + + vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(genericError); + + // Act + try { + await instance.chat({ + messages: [{ content: 'Hello', role: 'user' }], + model: 'text-davinci-003', + temperature: 0, + }); + } catch (e) { + expect(e).toEqual({ + endpoint: defaultBaseURL, + errorType: 'AgentRuntimeError', + provider, + error: { + name: genericError.name, + cause: genericError.cause, + message: genericError.message, + stack: genericError.stack, + }, + }); + } + }); + }); + + describe('LobePerplexityAI chat with callback and headers', () => { + it('should handle callback and headers correctly', async () => { + // Mock chat.completions.create to return a readable stream + const mockCreateMethod = vi + .spyOn(instance['client'].chat.completions, 'create') + .mockResolvedValue( + new ReadableStream({ + start(controller) { + controller.enqueue({ + id: 'chatcmpl-8xDx5AETP8mESQN7UB30GxTN2H1SO', + object: 'chat.completion.chunk', + created: 1709125675, + model: 'gpt-3.5-turbo-0125', + system_fingerprint: 'fp_86156a94a0', + choices: [ + { index: 0, delta: { content: 'hello' }, logprobs: null, finish_reason: null }, + ], + }); + controller.close(); + }, + }) as any, + ); + + // Prepare the callback and headers + const mockCallback: ChatStreamCallbacks = { + onStart: vi.fn(), + onToken: vi.fn(), + }; + const mockHeaders = { 'Custom-Header': 'TestValue' }; + + // Run the test + const result = await instance.chat( + { + messages: [{ content: 'Hello', role: 'user' }], + model: 'text-davinci-003', + temperature: 0, + }, + { callback: mockCallback, headers: mockHeaders }, + ); + + // Verify the callback was invoked + await result.text(); // make sure the stream is consumed + expect(mockCallback.onStart).toHaveBeenCalled(); + expect(mockCallback.onToken).toHaveBeenCalledWith('hello'); + + // Verify the headers were passed through correctly + expect(result.headers.get('Custom-Header')).toEqual('TestValue'); + + // Clean up + mockCreateMethod.mockRestore(); + }); + }); + + describe('DEBUG', () => { + it('should call debugStream and return StreamingTextResponse when DEBUG_PERPLEXITY_CHAT_COMPLETION is 1', async () => { + // Arrange + const mockProdStream = new ReadableStream() as any; // mocked prod stream + const mockDebugStream = new ReadableStream({ + start(controller) { + controller.enqueue('Debug stream content'); + controller.close(); + }, + }) as any; + mockDebugStream.toReadableStream = () => mockDebugStream; // attach a toReadableStream method + + // Mock the chat.completions.create return value, including a mocked tee method + (instance['client'].chat.completions.create as Mock).mockResolvedValue({ + tee: () => [mockProdStream, { toReadableStream: () => mockDebugStream }], + }); + + // Save the original environment variable value + const originalDebugValue = process.env.DEBUG_PERPLEXITY_CHAT_COMPLETION; + + // Stub the environment variable + process.env.DEBUG_PERPLEXITY_CHAT_COMPLETION = '1'; + vi.spyOn(debugStreamModule, 'debugStream').mockImplementation(() => Promise.resolve()); + + // Run the test + // Call the function under test and make sure it invokes debugStream when the condition is met + // Example invocation; adjust to your actual setup if needed + await instance.chat({ + messages: [{ content: 'Hello', role: 'user' }], + model: 'text-davinci-003', + temperature: 0, + }); + + // Verify debugStream was called + expect(debugStreamModule.debugStream).toHaveBeenCalled(); + + // Restore the original environment variable value + process.env.DEBUG_PERPLEXITY_CHAT_COMPLETION =
originalDebugValue; + }); + }); + }); +}); diff --git a/src/libs/agent-runtime/perplexity/index.ts b/src/libs/agent-runtime/perplexity/index.ts index 001a2a0cc076..f3fc1dd7933b 100644 --- a/src/libs/agent-runtime/perplexity/index.ts +++ b/src/libs/agent-runtime/perplexity/index.ts @@ -1,30 +1,29 @@ import { OpenAIStream, StreamingTextResponse } from 'ai'; -import OpenAI from 'openai'; +import OpenAI, { ClientOptions } from 'openai'; import { LobeRuntimeAI } from '../BaseAI'; import { AgentRuntimeErrorType } from '../error'; -import { ChatStreamPayload, ModelProvider } from '../types'; +import { ChatCompetitionOptions, ChatStreamPayload, ModelProvider } from '../types'; import { AgentRuntimeError } from '../utils/createError'; import { debugStream } from '../utils/debugStream'; import { desensitizeUrl } from '../utils/desensitizeUrl'; -import { DEBUG_CHAT_COMPLETION } from '../utils/env'; import { handleOpenAIError } from '../utils/handleOpenAIError'; const DEFAULT_BASE_URL = 'https://api.perplexity.ai'; export class LobePerplexityAI implements LobeRuntimeAI { - private _llm: OpenAI; + private client: OpenAI; baseURL: string; - constructor(apiKey?: string, baseURL: string = DEFAULT_BASE_URL) { + constructor({ apiKey, baseURL = DEFAULT_BASE_URL, ...res }: ClientOptions) { if (!apiKey) throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidPerplexityAPIKey); - this._llm = new OpenAI({ apiKey, baseURL }); - this.baseURL = this._llm.baseURL; + this.client = new OpenAI({ apiKey, baseURL, ...res }); + this.baseURL = this.client.baseURL; } - async chat(payload: ChatStreamPayload) { + async chat(payload: ChatStreamPayload, options?: ChatCompetitionOptions) { try { // Set a default frequency penalty value greater than 0 const defaultFrequencyPenalty = 0.1; @@ -32,19 +31,18 @@ export class LobePerplexityAI implements LobeRuntimeAI { ...payload, frequency_penalty: payload.frequency_penalty || defaultFrequencyPenalty, }; - const response = await this._llm.chat.completions.create( + const response = await this.client.chat.completions.create( chatPayload as unknown as OpenAI.ChatCompletionCreateParamsStreaming, ); + const [prod, debug] = response.tee(); - const stream = OpenAIStream(response); - - const [debug, returnStream] = stream.tee(); - - if (DEBUG_CHAT_COMPLETION) { - debugStream(debug).catch(console.error); + if (process.env.DEBUG_PERPLEXITY_CHAT_COMPLETION === '1') { + debugStream(debug.toReadableStream()).catch(console.error); } - return new StreamingTextResponse(returnStream); + return new StreamingTextResponse(OpenAIStream(prod, options?.callback), { + headers: options?.headers, + }); } catch (error) { let desensitizedEndpoint = this.baseURL; diff --git a/src/libs/agent-runtime/types/chat.ts b/src/libs/agent-runtime/types/chat.ts index c8dea0d0c098..d66d9a35920d 100644 --- a/src/libs/agent-runtime/types/chat.ts +++ b/src/libs/agent-runtime/types/chat.ts @@ -1,3 +1,5 @@ +import { OpenAIStreamCallbacks } from 'ai'; + export type LLMRoleType = 'user' | 'system' | 'assistant' | 'function'; interface UserMessageContentPartText { @@ -86,6 +88,11 @@ export interface ChatStreamPayload { top_p?: number; } +export interface ChatCompetitionOptions { + callback: ChatStreamCallbacks; + headers?: Record<string, string>; +} + export interface ChatCompletionFunctions { /** * The description of what the function does. @@ -117,3 +124,5 @@ export interface ChatCompletionTool { */ type: 'function'; } + +export type ChatStreamCallbacks = OpenAIStreamCallbacks; diff --git a/src/libs/agent-runtime/utils/debugStream.ts b/src/libs/agent-runtime/utils/debugStream.ts index 43e9fbe6b2c4..75dffe462ff1 100644 --- a/src/libs/agent-runtime/utils/debugStream.ts +++ b/src/libs/agent-runtime/utils/debugStream.ts @@ -8,7 +8,7 @@ export const debugStream = async (stream: ReadableStream) => { const { value, done: _done } = await reader.read(); const chunkValue = decoder.decode(value, { stream: true }); if (!_done) { - console.log(`chunk ${chunk}:`); + console.log(`[chunk ${chunk}]`); console.log(chunkValue); } diff --git a/src/libs/agent-runtime/utils/env.ts b/src/libs/agent-runtime/utils/env.ts index 38c7b59a85e1..8bd08056d3e0 100644 --- a/src/libs/agent-runtime/utils/env.ts +++ b/src/libs/agent-runtime/utils/env.ts @@ -1 +1,3 @@ export const DEBUG_CHAT_COMPLETION = process.env.DEBUG_CHAT_COMPLETION === '1'; +export const DEBUG_OPENAI_CHAT_COMPLETION = process.env.DEBUG_OPENAI_CHAT_COMPLETION === '1'; +export const DEBUG_ZHIPU_CHAT_COMPLETION = process.env.DEBUG_ZHIPU_CHAT_COMPLETION === '1'; diff --git a/src/libs/agent-runtime/zhipu/authToken.test.ts b/src/libs/agent-runtime/zhipu/authToken.test.ts new file mode 100644 index 000000000000..406b99b5da6b --- /dev/null +++ b/src/libs/agent-runtime/zhipu/authToken.test.ts @@ -0,0 +1,18 @@ +// @vitest-environment node +import { generateApiToken } from './authToken'; + +describe('generateApiToken', () => { + it('should throw an error if no apiKey is provided', async () => { + await expect(generateApiToken()).rejects.toThrow('Invalid apiKey'); + }); + + it('should throw an error if apiKey is invalid', async () => { + await expect(generateApiToken('invalid')).rejects.toThrow('Invalid apiKey'); + }); + + it('should return a token if a valid apiKey is provided', async () => { + const apiKey = 'id.secret'; + const token = await generateApiToken(apiKey); + expect(token).toBeDefined(); + }); +}); diff --git a/src/libs/agent-runtime/zhipu/index.test.ts b/src/libs/agent-runtime/zhipu/index.test.ts new file mode 100644 index 000000000000..b385085d1e00 --- /dev/null +++ b/src/libs/agent-runtime/zhipu/index.test.ts @@ -0,0 +1,322 @@ +// @vitest-environment node +import { StreamingTextResponse } from 'ai'; +import { OpenAI } from 'openai'; +import { Mock, afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; + +import { ChatStreamCallbacks, LobeOpenAI } from '@/libs/agent-runtime'; +import * as debugStreamModule from '@/libs/agent-runtime/utils/debugStream'; + +import { AgentRuntimeErrorType } from '../error'; +import { ModelProvider } from '../types'; +import * as authTokenModule from './authToken'; +import { LobeZhipuAI } from './index'; + +// Mock related dependencies +vi.mock('./authToken'); + +describe('LobeZhipuAI', () => { + beforeEach(() => { + // Mock generateApiToken + vi.spyOn(authTokenModule, 'generateApiToken').mockResolvedValue('mocked_token'); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + describe('fromAPIKey', () => { + it('should correctly initialize with an API key', async () => { + const lobeZhipuAI = await LobeZhipuAI.fromAPIKey({ apiKey: 'test_api_key' }); + expect(lobeZhipuAI).toBeInstanceOf(LobeZhipuAI); + expect(lobeZhipuAI.baseURL).toEqual('https://open.bigmodel.cn/api/paas/v4'); + }); + + it('should throw an error if API key is invalid', async () => { + vi.spyOn(authTokenModule, 'generateApiToken').mockRejectedValue(new Error('Invalid API Key')); + try {
await LobeZhipuAI.fromAPIKey({ apiKey: 'asd' }); + } catch (e) { + expect(e).toEqual({ errorType: 'InvalidZhipuAPIKey' }); + } + }); + }); + + describe('chat', () => { + let instance: LobeZhipuAI; + + beforeEach(async () => { + instance = await LobeZhipuAI.fromAPIKey({ + apiKey: 'test_api_key', + }); + + // Mock chat.completions.create + vi.spyOn(instance['client'].chat.completions, 'create').mockResolvedValue( + new ReadableStream() as any, + ); + }); + + it('should return a StreamingTextResponse on successful API call', async () => { + const result = await instance.chat({ + messages: [{ content: 'Hello', role: 'user' }], + model: 'glm-4', + temperature: 0, + }); + expect(result).toBeInstanceOf(StreamingTextResponse); + }); + + it('should handle callback and headers correctly', async () => { + // Mock chat.completions.create to return a readable stream + const mockCreateMethod = vi + .spyOn(instance['client'].chat.completions, 'create') + .mockResolvedValue( + new ReadableStream({ + start(controller) { + controller.enqueue({ + id: 'chatcmpl-8xDx5AETP8mESQN7UB30GxTN2H1SO', + object: 'chat.completion.chunk', + created: 1709125675, + model: 'gpt-3.5-turbo-0125', + system_fingerprint: 'fp_86156a94a0', + choices: [ + { index: 0, delta: { content: 'hello' }, logprobs: null, finish_reason: null }, + ], + }); + controller.close(); + }, + }) as any, + ); + + // Prepare the callback and headers + const mockCallback: ChatStreamCallbacks = { + onStart: vi.fn(), + onToken: vi.fn(), + }; + const mockHeaders = { 'Custom-Header': 'TestValue' }; + + // Run the test + const result = await instance.chat( + { + messages: [{ content: 'Hello', role: 'user' }], + model: 'text-davinci-003', + temperature: 0, + }, + { callback: mockCallback, headers: mockHeaders }, + ); + + // Verify the callback was invoked + await result.text(); // make sure the stream is consumed + expect(mockCallback.onStart).toHaveBeenCalled(); + expect(mockCallback.onToken).toHaveBeenCalledWith('hello'); + + // Verify the headers were passed through correctly + expect(result.headers.get('Custom-Header')).toEqual('TestValue'); + + // Clean up + mockCreateMethod.mockRestore(); + }); + + it('should transform messages correctly', async () => { + const spyOn = vi.spyOn(instance['client'].chat.completions, 'create'); + + await instance.chat({ + messages: [ + { content: 'Hello', role: 'user' }, + { content: [{ type: 'text', text: 'Hello again' }], role: 'user' }, + ], + model: 'glm-4', + temperature: 0, + top_p: 1, + }); + + const calledWithParams = spyOn.mock.calls[0][0]; + + expect(calledWithParams.messages[1].content).toEqual([{ type: 'text', text: 'Hello again' }]); + expect(calledWithParams.temperature).toBeUndefined(); // temperature 0 should be undefined + expect((calledWithParams as any).do_sample).toBeTruthy(); // do_sample should be true when temperature is 0 + expect(calledWithParams.top_p).toEqual(0.99); // top_p should be transformed correctly + }); + + describe('Error', () => { + it('should return ZhipuBizError with an openai error response when OpenAI.APIError is thrown', async () => { + // Arrange + const apiError = new OpenAI.APIError( + 400, + { + status: 400, + error: { + message: 'Bad Request', + }, + }, + 'Error message', + {}, + ); + + vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(apiError); + + // Act + try { + await instance.chat({ + messages: [{ content: 'Hello', role: 'user' }], + model: 'text-davinci-003', + temperature: 0, + }); + } catch (e) { + expect(e).toEqual({ + endpoint: 'https://open.bigmodel.cn/api/paas/v4', + error: { + error: { message: 'Bad Request' }, + status: 400, + }, + errorType: 'ZhipuBizError',
+            provider: 'zhipu',
+          });
+        }
+      });
+
+      it('should throw InvalidZhipuAPIKey error if no apiKey is provided', async () => {
+        try {
+          await LobeZhipuAI.fromAPIKey({ apiKey: '' });
+        } catch (e) {
+          expect(e).toEqual({ errorType: 'InvalidZhipuAPIKey' });
+        }
+      });
+
+      it('should return ZhipuBizError with the cause when OpenAI.APIError is thrown with cause', async () => {
+        // Arrange
+        const errorInfo = {
+          stack: 'abc',
+          cause: {
+            message: 'api is undefined',
+          },
+        };
+        const apiError = new OpenAI.APIError(400, errorInfo, 'module error', {});
+
+        vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(apiError);
+
+        // Act
+        try {
+          await instance.chat({
+            messages: [{ content: 'Hello', role: 'user' }],
+            model: 'text-davinci-003',
+            temperature: 0.2,
+          });
+        } catch (e) {
+          expect(e).toEqual({
+            endpoint: 'https://open.bigmodel.cn/api/paas/v4',
+            error: {
+              cause: { message: 'api is undefined' },
+              stack: 'abc',
+            },
+            errorType: 'ZhipuBizError',
+            provider: 'zhipu',
+          });
+        }
+      });
+
+      it('should return ZhipuBizError with a desensitized URL in the cause response', async () => {
+        // Arrange
+        const errorInfo = {
+          stack: 'abc',
+          cause: { message: 'api is undefined' },
+        };
+        const apiError = new OpenAI.APIError(400, errorInfo, 'module error', {});
+
+        instance = await LobeZhipuAI.fromAPIKey({
+          apiKey: 'test',
+          baseURL: 'https://abc.com/v2',
+        });
+
+        vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(apiError);
+
+        // Act
+        try {
+          await instance.chat({
+            messages: [{ content: 'Hello', role: 'user' }],
+            model: 'gpt-3.5-turbo',
+            temperature: 0,
+          });
+        } catch (e) {
+          expect(e).toEqual({
+            endpoint: 'https://***.com/v2',
+            error: {
+              cause: { message: 'api is undefined' },
+              stack: 'abc',
+            },
+            errorType: 'ZhipuBizError',
+            provider: 'zhipu',
+          });
+        }
+      });
+
+      it('should return AgentRuntimeError for non-OpenAI errors', async () => {
+        // Arrange
+        const genericError = new Error('Generic Error');
+
+        vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(genericError);
+
+        // Act
+        try {
+          await instance.chat({
+            messages: [{ content: 'Hello', role: 'user' }],
+            model: 'text-davinci-003',
+            temperature: 0,
+          });
+        } catch (e) {
+          expect(e).toEqual({
+            endpoint: 'https://open.bigmodel.cn/api/paas/v4',
+            errorType: 'AgentRuntimeError',
+            provider: 'zhipu',
+            error: {
+              name: genericError.name,
+              cause: genericError.cause,
+              message: genericError.message,
+              stack: genericError.stack,
+            },
+          });
+        }
+      });
+    });
+
+    describe('DEBUG', () => {
+      it('should call debugStream and return StreamingTextResponse when DEBUG_ZHIPU_CHAT_COMPLETION is 1', async () => {
+        // Arrange
+        const mockProdStream = new ReadableStream() as any; // mocked prod stream
+        const mockDebugStream = new ReadableStream({
+          start(controller) {
+            controller.enqueue('Debug stream content');
+            controller.close();
+          },
+        }) as any;
+        mockDebugStream.toReadableStream = () => mockDebugStream; // attach a toReadableStream method
+
+        // Mock the chat.completions.create return value, including a mocked tee method
+        (instance['client'].chat.completions.create as Mock).mockResolvedValue({
+          tee: () => [mockProdStream, { toReadableStream: () => mockDebugStream }],
+        });
+
+        // Save the original environment variable value
+        const originalDebugValue = process.env.DEBUG_ZHIPU_CHAT_COMPLETION;
+
+        // Mock the environment variable
+        process.env.DEBUG_ZHIPU_CHAT_COMPLETION = '1';
+        vi.spyOn(debugStreamModule, 'debugStream').mockImplementation(() => Promise.resolve());
+
+        // Run the chat call; it should invoke debugStream while the flag is enabled
+        await instance.chat({
+          messages: [{ content: 'Hello', role: 'user' }],
+          model: 'text-davinci-003',
+          temperature: 0,
+        });
+
+        // Verify debugStream was called
+        expect(debugStreamModule.debugStream).toHaveBeenCalled();
+
+        // Restore the original environment variable value
+        process.env.DEBUG_ZHIPU_CHAT_COMPLETION = originalDebugValue;
+      });
+    });
+  });
+});
diff --git a/src/libs/agent-runtime/zhipu/index.ts b/src/libs/agent-runtime/zhipu/index.ts
index decd7d950245..325de0b04e0b 100644
--- a/src/libs/agent-runtime/zhipu/index.ts
+++ b/src/libs/agent-runtime/zhipu/index.ts
@@ -1,13 +1,17 @@
 import { OpenAIStream, StreamingTextResponse } from 'ai';
-import OpenAI from 'openai';
+import OpenAI, { ClientOptions } from 'openai';
 
 import { LobeRuntimeAI } from '../BaseAI';
 import { AgentRuntimeErrorType } from '../error';
-import { ChatStreamPayload, ModelProvider, OpenAIChatMessage } from '../types';
+import {
+  ChatCompetitionOptions,
+  ChatStreamPayload,
+  ModelProvider,
+  OpenAIChatMessage,
+} from '../types';
 import { AgentRuntimeError } from '../utils/createError';
 import { debugStream } from '../utils/debugStream';
 import { desensitizeUrl } from '../utils/desensitizeUrl';
-import { DEBUG_CHAT_COMPLETION } from '../utils/env';
 import { handleOpenAIError } from '../utils/handleOpenAIError';
 import { parseDataUri } from '../utils/uriParser';
 import { generateApiToken } from './authToken';
@@ -15,20 +19,22 @@ import { generateApiToken } from './authToken';
 const DEFAULT_BASE_URL = 'https://open.bigmodel.cn/api/paas/v4';
 
 export class LobeZhipuAI implements LobeRuntimeAI {
-  private _llm: OpenAI;
+  private client: OpenAI;
 
   baseURL: string;
 
   constructor(oai: OpenAI) {
-    this._llm = oai;
-    this.baseURL = this._llm.baseURL;
+    this.client = oai;
+    this.baseURL = this.client.baseURL;
   }
 
-  static async fromAPIKey(apiKey?: string, baseURL: string = DEFAULT_BASE_URL) {
+  static async fromAPIKey({ apiKey, baseURL = DEFAULT_BASE_URL, ...res }: ClientOptions) {
     const invalidZhipuAPIKey = AgentRuntimeError.createError(
       AgentRuntimeErrorType.InvalidZhipuAPIKey,
     );
+
+    if (!apiKey) throw invalidZhipuAPIKey;
+
     let token: string;
 
     try {
@@ -38,34 +44,32 @@ export class LobeZhipuAI implements LobeRuntimeAI {
     }
 
     const header = { Authorization: `Bearer ${token}` };
-
-    const llm = new OpenAI({ apiKey, baseURL, defaultHeaders: header });
+    const llm = new OpenAI({ apiKey, baseURL, defaultHeaders: header, ...res });
 
     return new LobeZhipuAI(llm);
   }
 
-  async chat(payload: ChatStreamPayload) {
+  async chat(payload: ChatStreamPayload, options?: ChatCompetitionOptions) {
     try {
       const params = this.buildCompletionsParams(payload);
 
-      const response = await this._llm.chat.completions.create(
+      const response = await this.client.chat.completions.create(
         params as unknown as OpenAI.ChatCompletionCreateParamsStreaming,
       );
 
-      const stream = OpenAIStream(response);
-
-      const [debug, returnStream] = stream.tee();
+      const [prod, debug] = response.tee();
 
-      if (DEBUG_CHAT_COMPLETION) {
-        debugStream(debug).catch(console.error);
+      if (process.env.DEBUG_ZHIPU_CHAT_COMPLETION === '1') {
+        debugStream(debug.toReadableStream()).catch(console.error);
       }
 
-      return new StreamingTextResponse(returnStream);
+      return new StreamingTextResponse(OpenAIStream(prod, options?.callback), {
+        headers: options?.headers,
+      });
     } catch (error) {
       const { errorResult, RuntimeError } = handleOpenAIError(error);
 
       const errorType = RuntimeError || AgentRuntimeErrorType.ZhipuBizError;
-
       let desensitizedEndpoint = this.baseURL;
 
       if (this.baseURL !== DEFAULT_BASE_URL) {
diff --git a/src/utils/jwt.test.ts
b/src/utils/jwt.test.ts index 1aa7a43884de..084cc193080f 100644 --- a/src/utils/jwt.test.ts +++ b/src/utils/jwt.test.ts @@ -1,3 +1,4 @@ +// @vitest-environment node import { describe, expect, it } from 'vitest'; import { NON_HTTP_PREFIX } from '@/const/auth'; diff --git a/tests/setup.ts b/tests/setup.ts index 95d9ccfbc702..35c11afcf4cf 100644 --- a/tests/setup.ts +++ b/tests/setup.ts @@ -5,7 +5,19 @@ import { theme } from 'antd'; // refs: https://github.com/dumbmatter/fakeIndexedDB#dexie-and-other-indexeddb-api-wrappers import 'fake-indexeddb/auto'; import React from 'react'; -import 'vitest-canvas-mock'; + +if (typeof window !== 'undefined') { + // test with canvas + await import('vitest-canvas-mock'); +} else { + // test with polyfill crypto + const { Crypto } = await import('@peculiar/webcrypto'); + + Object.defineProperty(global, 'crypto', { + value: new Crypto(), + writable: true, + }); +} // remove antd hash on test theme.defaultConfig.hashed = false;
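
Taken together, these changes move every runtime constructor onto a single options object and give `chat` an optional `ChatCompetitionOptions` second argument: its `callback` hooks are forwarded into `OpenAIStream`, and its `headers` are attached to the returned `StreamingTextResponse`. A minimal usage sketch of the new surface, assuming `LobeZhipuAI` is re-exported from `@/libs/agent-runtime` and the code runs in a module with top-level await; the model id, token handler, and header value are illustrative, not taken from this diff:

import { LobeZhipuAI } from '@/libs/agent-runtime';

// New constructor shape: one ClientOptions-style object instead of positional args.
const zhipu = await LobeZhipuAI.fromAPIKey({ apiKey: 'id.secret' });

const response = await zhipu.chat(
  { messages: [{ content: 'Hello', role: 'user' }], model: 'glm-4', temperature: 0.7 },
  {
    // ChatCompetitionOptions: the callback flows into OpenAIStream,
    // the headers land on the StreamingTextResponse.
    callback: { onToken: (token) => process.stdout.write(token) }, // illustrative handler
    headers: { 'X-Request-Id': 'demo' }, // illustrative header
  },
);

With DEBUG_ZHIPU_CHAT_COMPLETION=1 set, the Zhipu runtime additionally tees the raw completion stream into debugStream without disturbing the response returned to the caller.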