diff --git a/package.json b/package.json index 0cf00ad62f36..1c5261b22b69 100644 --- a/package.json +++ b/package.json @@ -145,6 +145,7 @@ "devDependencies": { "@commitlint/cli": "^18", "@ducanh2912/next-pwa": "^10", + "@edge-runtime/vm": "^3.2.0", "@lobehub/i18n-cli": "latest", "@lobehub/lint": "latest", "@next/bundle-analyzer": "^14", diff --git a/src/app/api/chat/[provider]/agentRuntime.test.ts b/src/app/api/chat/[provider]/agentRuntime.test.ts new file mode 100644 index 000000000000..5b9ba43ae365 --- /dev/null +++ b/src/app/api/chat/[provider]/agentRuntime.test.ts @@ -0,0 +1,268 @@ +// @vitest-environment edge-runtime +import { describe, expect, it, vi } from 'vitest'; + +import { JWTPayload } from '@/const/auth'; +import { + LobeAzureOpenAI, + LobeBedrockAI, + LobeGoogleAI, + LobeMoonshotAI, + LobeOllamaAI, + LobeOpenAI, + LobePerplexityAI, + LobeZhipuAI, + ModelProvider, +} from '@/libs/agent-runtime'; + +import AgentRuntime from './agentRuntime'; + +// 模拟依赖项 +vi.mock('@/config/server', () => ({ + getServerConfig: vi.fn(() => ({ + // 确保为每个provider提供必要的配置信息 + OPENAI_API_KEY: 'test-openai-key', + GOOGLE_API_KEY: 'test-google-key', + + AZURE_API_KEY: 'test-azure-key', + AZURE_ENDPOINT: 'endpoint', + + ZHIPU_API_KEY: 'test.zhipu-key', + MOONSHOT_API_KEY: 'test-moonshot-key', + AWS_SECRET_ACCESS_KEY: 'test-aws-secret', + AWS_ACCESS_KEY_ID: 'test-aws-id', + AWS_REGION: 'test-aws-region', + OLLAMA_PROXY_URL: 'test-ollama-url', + PERPLEXITY_API_KEY: 'test-perplexity-key', + })), +})); + +describe('AgentRuntime', () => { + describe('should initialize with various providers', () => { + describe('OpenAI provider', () => { + it('should initialize correctly', async () => { + const jwtPayload: JWTPayload = { apiKey: 'user-openai-key', endpoint: 'user-endpoint' }; + const runtime = await AgentRuntime.initializeWithUserPayload( + ModelProvider.OpenAI, + jwtPayload, + ); + + expect(runtime).toBeInstanceOf(AgentRuntime); + expect(runtime['_runtime']).toBeInstanceOf(LobeOpenAI); + expect(runtime['_runtime'].baseURL).toBe('user-endpoint'); + }); + }); + + describe('Azure OpenAI provider', () => { + it('should initialize correctly', async () => { + const jwtPayload: JWTPayload = { + apiKey: 'user-azure-key', + endpoint: 'user-azure-endpoint', + useAzure: true, + }; + const runtime = await AgentRuntime.initializeWithUserPayload( + ModelProvider.OpenAI, + jwtPayload, + ); + + expect(runtime).toBeInstanceOf(AgentRuntime); + expect(runtime['_runtime']).toBeInstanceOf(LobeOpenAI); + expect(runtime['_runtime'].baseURL).toBe('user-azure-endpoint'); + }); + it('should initialize with azureOpenAIParams correctly', async () => { + const jwtPayload: JWTPayload = { apiKey: 'user-openai-key', endpoint: 'user-endpoint' }; + const azureOpenAIParams = { + apiVersion: 'custom-version', + model: 'custom-model', + useAzure: true, + }; + const runtime = await AgentRuntime.initializeWithUserPayload( + ModelProvider.OpenAI, + jwtPayload, + azureOpenAIParams, + ); + + expect(runtime).toBeInstanceOf(AgentRuntime); + const openAIRuntime = runtime['_runtime'] as LobeOpenAI; + expect(openAIRuntime).toBeInstanceOf(LobeOpenAI); + }); + + it('should initialize with AzureAI correctly', async () => { + const jwtPayload: JWTPayload = { + apiKey: 'user-azure-key', + endpoint: 'user-azure-endpoint', + }; + const runtime = await AgentRuntime.initializeWithUserPayload( + ModelProvider.Azure, + jwtPayload, + ); + + expect(runtime['_runtime']).toBeInstanceOf(LobeAzureOpenAI); + }); + it('should initialize AzureAI correctly without apiKey', async () => { + const jwtPayload: JWTPayload = {}; + const runtime = await AgentRuntime.initializeWithUserPayload( + ModelProvider.Azure, + jwtPayload, + ); + + expect(runtime['_runtime']).toBeInstanceOf(LobeAzureOpenAI); + }); + }); + + describe('ZhiPu AI provider', () => { + it('should initialize correctly', async () => { + const jwtPayload: JWTPayload = { apiKey: 'zhipu.user-key' }; + const runtime = await AgentRuntime.initializeWithUserPayload( + ModelProvider.ZhiPu, + jwtPayload, + ); + + // 假设 LobeZhipuAI 是 ZhiPu 提供者的实现类 + expect(runtime['_runtime']).toBeInstanceOf(LobeZhipuAI); + }); + it('should initialize correctly without apiKey', async () => { + const jwtPayload: JWTPayload = {}; + const runtime = await AgentRuntime.initializeWithUserPayload( + ModelProvider.ZhiPu, + jwtPayload, + ); + + // 假设 LobeZhipuAI 是 ZhiPu 提供者的实现类 + expect(runtime['_runtime']).toBeInstanceOf(LobeZhipuAI); + }); + }); + + describe('Google provider', () => { + it('should initialize correctly', async () => { + const jwtPayload: JWTPayload = { apiKey: 'user-google-key' }; + const runtime = await AgentRuntime.initializeWithUserPayload( + ModelProvider.Google, + jwtPayload, + ); + + // 假设 LobeGoogleAI 是 Google 提供者的实现类 + expect(runtime['_runtime']).toBeInstanceOf(LobeGoogleAI); + }); + + it('should initialize correctly without apiKey', async () => { + const jwtPayload: JWTPayload = {}; + const runtime = await AgentRuntime.initializeWithUserPayload( + ModelProvider.Google, + jwtPayload, + ); + + // 假设 LobeGoogleAI 是 Google 提供者的实现类 + expect(runtime['_runtime']).toBeInstanceOf(LobeGoogleAI); + }); + }); + + describe('Moonshot AI provider', () => { + it('should initialize correctly', async () => { + const jwtPayload: JWTPayload = { apiKey: 'user-moonshot-key' }; + const runtime = await AgentRuntime.initializeWithUserPayload( + ModelProvider.Moonshot, + jwtPayload, + ); + + // 假设 LobeMoonshotAI 是 Moonshot 提供者的实现类 + expect(runtime['_runtime']).toBeInstanceOf(LobeMoonshotAI); + }); + it('should initialize correctly without apiKey', async () => { + const jwtPayload: JWTPayload = {}; + const runtime = await AgentRuntime.initializeWithUserPayload( + ModelProvider.Moonshot, + jwtPayload, + ); + + // 假设 LobeMoonshotAI 是 Moonshot 提供者的实现类 + expect(runtime['_runtime']).toBeInstanceOf(LobeMoonshotAI); + }); + }); + + describe('Bedrock AI provider', () => { + it('should initialize correctly with payload apiKey', async () => { + const jwtPayload: JWTPayload = { + apiKey: 'user-bedrock-key', + awsAccessKeyId: 'user-aws-id', + awsSecretAccessKey: 'user-aws-secret', + awsRegion: 'user-aws-region', + }; + const runtime = await AgentRuntime.initializeWithUserPayload( + ModelProvider.Bedrock, + jwtPayload, + ); + + // 假设 LobeBedrockAI 是 Bedrock 提供者的实现类 + expect(runtime['_runtime']).toBeInstanceOf(LobeBedrockAI); + }); + + it('should initialize correctly without apiKey', async () => { + const jwtPayload: JWTPayload = {}; + const runtime = await AgentRuntime.initializeWithUserPayload( + ModelProvider.Bedrock, + jwtPayload, + ); + + // 假设 LobeBedrockAI 是 Bedrock 提供者的实现类 + expect(runtime['_runtime']).toBeInstanceOf(LobeBedrockAI); + }); + }); + + describe('Ollama provider', () => { + it('should initialize correctly', async () => { + const jwtPayload: JWTPayload = { endpoint: 'user-ollama-url' }; + const runtime = await AgentRuntime.initializeWithUserPayload( + ModelProvider.Ollama, + jwtPayload, + ); + + // 假设 LobeOllamaAI 是 Ollama 提供者的实现类 + expect(runtime['_runtime']).toBeInstanceOf(LobeOllamaAI); + }); + + it('should initialize correctly without endpoint', async () => { + const jwtPayload: JWTPayload = {}; + const runtime = await AgentRuntime.initializeWithUserPayload( + ModelProvider.Ollama, + jwtPayload, + ); + + // 假设 LobeOllamaAI 是 Ollama 提供者的实现类 + expect(runtime['_runtime']).toBeInstanceOf(LobeOllamaAI); + }); + }); + + describe('Perplexity AI provider', () => { + it('should initialize correctly', async () => { + const jwtPayload: JWTPayload = { apiKey: 'user-perplexity-key' }; + const runtime = await AgentRuntime.initializeWithUserPayload( + ModelProvider.Perplexity, + jwtPayload, + ); + + // 假设 LobePerplexityAI 是 Perplexity 提供者的实现类 + expect(runtime['_runtime']).toBeInstanceOf(LobePerplexityAI); + }); + + it('should initialize correctly without apiKey', async () => { + const jwtPayload: JWTPayload = {}; + const runtime = await AgentRuntime.initializeWithUserPayload( + ModelProvider.Perplexity, + jwtPayload, + ); + + // 假设 LobePerplexityAI 是 Perplexity 提供者的实现类 + expect(runtime['_runtime']).toBeInstanceOf(LobePerplexityAI); + }); + }); + + it('should handle unknown provider gracefully', async () => { + const jwtPayload: JWTPayload = {}; + const runtime = await AgentRuntime.initializeWithUserPayload('unknown', jwtPayload); + + // 根据实际实现,你可能需要检查是否返回了默认的 runtime 实例,或者是否抛出了异常 + // 例如,如果默认使用 OpenAI: + expect(runtime['_runtime']).toBeInstanceOf(LobeOpenAI); + }); + }); +}); diff --git a/src/app/api/chat/[provider]/route.test.ts b/src/app/api/chat/[provider]/route.test.ts new file mode 100644 index 000000000000..048d0281dca7 --- /dev/null +++ b/src/app/api/chat/[provider]/route.test.ts @@ -0,0 +1,146 @@ +// @vitest-environment edge-runtime +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; + +import { LOBE_CHAT_AUTH_HEADER, OAUTH_AUTHORIZED } from '@/const/auth'; +import { LobeRuntimeAI } from '@/libs/agent-runtime'; +import { ChatErrorType } from '@/types/fetch'; + +import { getJWTPayload } from '../auth'; +import AgentRuntime from './agentRuntime'; +import { POST } from './route'; + +vi.mock('../auth', () => ({ + getJWTPayload: vi.fn(), + checkAuthMethod: vi.fn(), +})); + +// 模拟请求和响应 +let request: Request; +beforeEach(() => { + request = new Request(new URL('https://test.com'), { + headers: { + [LOBE_CHAT_AUTH_HEADER]: 'Bearer some-valid-token', + [OAUTH_AUTHORIZED]: 'true', + }, + method: 'POST', + body: JSON.stringify({ model: 'test-model' }), + }); +}); + +afterEach(() => { + // 清除模拟调用历史 + vi.clearAllMocks(); +}); + +describe('POST handler', () => { + describe(' init chat model', () => { + it('should initialize AgentRuntime correctly with valid authorization', async () => { + const mockParams = { provider: 'test-provider' }; + + // 设置 getJWTPayload 和 initializeWithUserPayload 的模拟返回值 + vi.mocked(getJWTPayload).mockResolvedValue({ + accessCode: 'test-access-code', + apiKey: 'test-api-key', + azureApiVersion: 'v1', + useAzure: true, + }); + + const mockRuntime: LobeRuntimeAI = { baseURL: 'abc', chat: vi.fn() }; + + const spy = vi + .spyOn(AgentRuntime, 'initializeWithUserPayload') + .mockResolvedValue(new AgentRuntime(mockRuntime)); + + // 调用 POST 函数 + await POST(request as unknown as Request, { params: mockParams }); + + // 验证是否正确调用了模拟函数 + expect(getJWTPayload).toHaveBeenCalledWith('Bearer some-valid-token'); + expect(spy).toHaveBeenCalledWith('test-provider', expect.anything(), { + apiVersion: 'v1', + model: 'test-model', + useAzure: true, + }); + }); + + it('should return Unauthorized error when LOBE_CHAT_AUTH_HEADER is missing', async () => { + const mockParams = { provider: 'test-provider' }; + const requestWithoutAuthHeader = new Request(new URL('https://test.com'), { + method: 'POST', + body: JSON.stringify({ model: 'test-model' }), + }); + + const response = await POST(requestWithoutAuthHeader, { params: mockParams }); + + expect(response.status).toBe(401); + expect(await response.json()).toEqual({ + body: { + error: { errorType: 401 }, + provider: 'test-provider', + }, + errorType: 401, + }); + }); + it('should return InternalServerError error when throw a unknown error', async () => { + const mockParams = { provider: 'test-provider' }; + vi.mocked(getJWTPayload).mockRejectedValueOnce(new Error('unknown error')); + + const response = await POST(request, { params: mockParams }); + + expect(response.status).toBe(500); + expect(await response.json()).toEqual({ + body: { + error: {}, + provider: 'test-provider', + }, + errorType: 500, + }); + }); + }); + + describe('chat', () => { + it('should correctly handle chat completion with valid payload', async () => { + const mockParams = { provider: 'test-provider' }; + const mockChatPayload = { message: 'Hello, world!' }; + request = new Request(new URL('https://test.com'), { + headers: { [LOBE_CHAT_AUTH_HEADER]: 'Bearer some-valid-token' }, + method: 'POST', + body: JSON.stringify(mockChatPayload), + }); + + const mockChatResponse: any = { success: true, message: 'Reply from agent' }; + + vi.spyOn(AgentRuntime.prototype, 'chat').mockResolvedValue(mockChatResponse); + + const response = await POST(request as unknown as Request, { params: mockParams }); + + expect(response).toEqual(mockChatResponse); + expect(AgentRuntime.prototype.chat).toHaveBeenCalledWith(mockChatPayload); + }); + + it('should return an error response when chat completion fails', async () => { + const mockParams = { provider: 'test-provider' }; + const mockChatPayload = { message: 'Hello, world!' }; + request = new Request(new URL('https://test.com'), { + headers: { [LOBE_CHAT_AUTH_HEADER]: 'Bearer some-valid-token' }, + method: 'POST', + body: JSON.stringify(mockChatPayload), + }); + + const mockErrorResponse = { + errorType: ChatErrorType.InternalServerError, + errorMessage: 'Something went wrong', + }; + + vi.spyOn(AgentRuntime.prototype, 'chat').mockRejectedValue(mockErrorResponse); + + const response = await POST(request, { params: mockParams }); + + expect(response.status).toBe(500); + expect(await response.json()).toEqual({ + body: { errorMessage: 'Something went wrong' }, + errorType: 500, + }); + }); + }); +}); diff --git a/src/app/api/chat/google/route.test.ts b/src/app/api/chat/google/route.test.ts new file mode 100644 index 000000000000..c9927bb32ad8 --- /dev/null +++ b/src/app/api/chat/google/route.test.ts @@ -0,0 +1,28 @@ +// @vitest-environment edge-runtime +import { describe, expect, it, vi } from 'vitest'; + +import { POST as UniverseRoute } from '../[provider]/route'; +import { POST, preferredRegion, runtime } from './route'; + +// 模拟 '../[provider]/route' +vi.mock('../[provider]/route', () => ({ + POST: vi.fn().mockResolvedValue('mocked response'), +})); + +describe('Configuration tests', () => { + it('should have runtime set to "edge"', () => { + expect(runtime).toBe('edge'); + }); + + it('should contain specific regions in preferredRegion', () => { + expect(preferredRegion).not.contain(['hk1']); + }); +}); + +describe('Google POST function tests', () => { + it('should call UniverseRoute with correct parameters', async () => { + const mockRequest = new Request('https://example.com', { method: 'POST' }); + await POST(mockRequest); + expect(UniverseRoute).toHaveBeenCalledWith(mockRequest, { params: { provider: 'google' } }); + }); +}); diff --git a/src/app/api/chat/google/route.ts b/src/app/api/chat/google/route.ts index 379dcab904c2..f81eb3fa02e7 100644 --- a/src/app/api/chat/google/route.ts +++ b/src/app/api/chat/google/route.ts @@ -1,18 +1,4 @@ -import { createErrorResponse } from '@/app/api/errorResponse'; -import { getServerConfig } from '@/config/server'; -import { LOBE_CHAT_AUTH_HEADER, OAUTH_AUTHORIZED } from '@/const/auth'; -import { - AgentInitErrorPayload, - AgentRuntimeError, - ChatCompletionErrorPayload, - ILobeAgentRuntimeErrorType, - LobeGoogleAI, -} from '@/libs/agent-runtime'; -import { ChatErrorType } from '@/types/fetch'; -import { ChatStreamPayload } from '@/types/openai/chat'; - -import apiKeyManager from '../apiKeyManager'; -import { checkAuthMethod, getJWTPayload } from '../auth'; +import { POST as UniverseRoute } from '../[provider]/route'; // due to the Chinese region does not support accessing Google // we need to use proxy to access it @@ -42,48 +28,4 @@ export const preferredRegion = [ 'syd1', ]; -export const POST = async (req: Request) => { - let agentRuntime: LobeGoogleAI; - - // ============ 1. init chat model ============ // - - try { - // get Authorization from header - const authorization = req.headers.get(LOBE_CHAT_AUTH_HEADER); - const oauthAuthorized = !!req.headers.get(OAUTH_AUTHORIZED); - - if (!authorization) throw AgentRuntimeError.createError(ChatErrorType.Unauthorized); - - // check the Auth With payload - const payload = await getJWTPayload(authorization); - checkAuthMethod(payload.accessCode, payload.apiKey, oauthAuthorized); - - const { GOOGLE_API_KEY } = getServerConfig(); - const apiKey = apiKeyManager.pick(payload?.apiKey || GOOGLE_API_KEY); - - agentRuntime = new LobeGoogleAI(apiKey); - } catch (e) { - // if catch the error, just return it - const err = e as AgentInitErrorPayload; - - return createErrorResponse(err.errorType as ILobeAgentRuntimeErrorType, { - error: err.error, - provider: 'google', - }); - } - - // ============ 2. create chat completion ============ // - - try { - const payload = (await req.json()) as ChatStreamPayload; - - return await agentRuntime.chat(payload); - } catch (e) { - const { errorType, provider, error: errorContent, ...res } = e as ChatCompletionErrorPayload; - - // track the error at server side - console.error(`Route: [${provider}] ${errorType}:`, errorContent); - - return createErrorResponse(errorType, { error: errorContent, provider, ...res }); - } -}; +export const POST = async (req: Request) => UniverseRoute(req, { params: { provider: 'google' } }); diff --git a/src/app/api/errorResponse.test.ts b/src/app/api/errorResponse.test.ts index 2155341659cd..9e930f30a20b 100644 --- a/src/app/api/errorResponse.test.ts +++ b/src/app/api/errorResponse.test.ts @@ -1,5 +1,6 @@ import { describe, expect, it } from 'vitest'; +import { AgentRuntimeErrorType } from '@/libs/agent-runtime'; import { ChatErrorType } from '@/types/fetch'; import { createErrorResponse } from './errorResponse'; @@ -18,10 +19,97 @@ describe('createErrorResponse', () => { expect(response.status).toBe(401); }); - it('returns a 471 status for OpenAIBizError error type', () => { - const errorType = ChatErrorType.OpenAIBizError; + // 测试包含Invalid的错误类型 + it('returns a 401 status for Invalid error type', () => { + const errorType = 'InvalidTestError'; + const response = createErrorResponse(errorType as any); + expect(response.status).toBe(401); + }); + + it('returns a 403 status for LocationNotSupportError error type', () => { + const errorType = AgentRuntimeErrorType.LocationNotSupportError; const response = createErrorResponse(errorType); - expect(response.status).toBe(471); + expect(response.status).toBe(403); + }); + + describe('Provider Biz Error', () => { + it('returns a 471 status for OpenAIBizError error type', () => { + const errorType = ChatErrorType.OpenAIBizError; + const response = createErrorResponse(errorType); + expect(response.status).toBe(471); + }); + + it('returns a 470 status for AgentRuntimeError error type', () => { + const errorType = AgentRuntimeErrorType.AgentRuntimeError; + const response = createErrorResponse(errorType); + expect(response.status).toBe(470); + }); + + it('returns a 471 status for OpenAIBizError error type', () => { + const errorType = AgentRuntimeErrorType.OpenAIBizError; + const response = createErrorResponse(errorType as any); + expect(response.status).toBe(471); + }); + + // 测试 AzureBizError 错误类型返回472状态码 + it('returns a 472 status for AzureBizError error type', () => { + const errorType = AgentRuntimeErrorType.AzureBizError; + const response = createErrorResponse(errorType); + expect(response.status).toBe(472); + }); + + // 测试 ZhipuBizError 错误类型返回473状态码 + it('returns a 473 status for ZhipuBizError error type', () => { + const errorType = AgentRuntimeErrorType.ZhipuBizError; + const response = createErrorResponse(errorType); + expect(response.status).toBe(473); + }); + + // 测试 BedrockBizError 错误类型返回474状态码 + it('returns a 474 status for BedrockBizError error type', () => { + const errorType = AgentRuntimeErrorType.BedrockBizError; + const response = createErrorResponse(errorType); + expect(response.status).toBe(474); + }); + + // 测试 GoogleBizError 错误类型返回475状态码 + it('returns a 475 status for GoogleBizError error type', () => { + const errorType = AgentRuntimeErrorType.GoogleBizError; + const response = createErrorResponse(errorType); + expect(response.status).toBe(475); + }); + + // 测试 MoonshotBizError 错误类型返回476状态码 + it('returns a 476 status for MoonshotBizError error type', () => { + const errorType = AgentRuntimeErrorType.MoonshotBizError; + const response = createErrorResponse(errorType); + expect(response.status).toBe(476); + }); + + // 测试 OllamaBizError 错误类型返回478状态码 + it('returns a 478 status for OllamaBizError error type', () => { + const errorType = AgentRuntimeErrorType.OllamaBizError; + const response = createErrorResponse(errorType); + expect(response.status).toBe(478); + }); + + // 测试 PerplexityBizError 错误类型返回479状态码 + it('returns a 479 status for PerplexityBizError error type', () => { + const errorType = AgentRuntimeErrorType.PerplexityBizError; + const response = createErrorResponse(errorType); + expect(response.status).toBe(479); + }); + }); + + // 测试状态码不在200-599范围内的情况 + it('logs an error when the status code is not a number or not in the range of 200-599', () => { + const errorType = 'Unknown Error'; + const consoleSpy = vi.spyOn(console, 'error'); + try { + createErrorResponse(errorType as any); + } catch (e) {} + expect(consoleSpy).toHaveBeenCalled(); + consoleSpy.mockRestore(); }); // 测试默认情况 diff --git a/tests/setup.ts b/tests/setup.ts index 35c11afcf4cf..d9be8c711051 100644 --- a/tests/setup.ts +++ b/tests/setup.ts @@ -6,10 +6,19 @@ import { theme } from 'antd'; import 'fake-indexeddb/auto'; import React from 'react'; -if (typeof window !== 'undefined') { +// only inject in the dom environment +if ( + // not node runtime + typeof window !== 'undefined' && + // not edge runtime + typeof (globalThis as any).EdgeRuntime !== 'string' +) { // test with canvas await import('vitest-canvas-mock'); -} else { +} + +// node runtime +if (typeof window === 'undefined') { // test with polyfill crypto const { Crypto } = await import('@peculiar/webcrypto');