♻️ refactor: refactor the core chatStream #1426

Merged
merged 6 commits on Feb 29, 2024
Changes from 1 commit
✅ test: add google tests
arvinxx committed Feb 29, 2024
commit 4b44fba20bf5a4687b9db0118a3b1fd85c68092d
4 changes: 2 additions & 2 deletions src/app/api/chat/google/route.ts
@@ -8,9 +8,9 @@ import { POST as UniverseRoute } from '../[provider]/route';
//
// setGlobalDispatcher(new ProxyAgent({ uri: process.env.HTTP_PROXY_URL }));
// }
// undici can only be used in Node.js
// export const runtime = 'nodejs';

// but undici can only be used in Node.js,
// so if you want to use a proxy you need to comment out the line below
export const runtime = 'edge';

export const preferredRegion = [
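For context, the block commented out at the top of this file hints at the proxy setup; below is a minimal sketch of what that setup might look like if the route were switched back to the Node.js runtime (assuming the HTTP_PROXY_URL environment variable referenced in the comment; this is illustrative, not the shipped configuration):

// Sketch only: undici's ProxyAgent is unavailable on the Edge runtime, so this
// replaces `export const runtime = 'edge'` rather than being combined with it.
import { ProxyAgent, setGlobalDispatcher } from 'undici';

if (process.env.HTTP_PROXY_URL) {
  // Route the outgoing fetch calls made by the runtime through the proxy.
  setGlobalDispatcher(new ProxyAgent({ uri: process.env.HTTP_PROXY_URL }));
}

export const runtime = 'nodejs';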
218 changes: 216 additions & 2 deletions src/libs/agent-runtime/google/index.test.ts
@@ -1,10 +1,9 @@
// @vitest-environment edge-runtime
import { GenerateContentRequest, GenerateContentStreamResult, Part } from '@google/generative-ai';
import Dexie from 'dexie';
import OpenAI from 'openai';
import { Mock, afterEach, beforeEach, describe, expect, it, vi } from 'vitest';

import { ChatStreamCallbacks } from '@/libs/agent-runtime';
import { ChatStreamCallbacks, OpenAIChatMessage } from '@/libs/agent-runtime';

import * as debugStreamModule from '../utils/debugStream';
import { LobeGoogleAI } from './index';
@@ -53,8 +52,162 @@ describe('LobeGoogleAI', () => {
// Assert
expect(result).toBeInstanceOf(Response);
});
it('should handle text messages correctly', async () => {
// Mock the Google AI SDK generateContentStream method to return a successful response stream
const mockStream = new ReadableStream({
start(controller) {
controller.enqueue('Hello, world!');
controller.close();
},
});
vi.spyOn(instance['client'], 'getGenerativeModel').mockReturnValue({
generateContentStream: vi.fn().mockResolvedValueOnce(mockStream),
} as any);

const result = await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'text-davinci-003',
temperature: 0,
});

expect(result).toBeInstanceOf(Response);
// Additional assertions could be added, e.g. verifying the content of the returned stream
});

it('should call debugStream in DEBUG mode', async () => {
// Set the environment variable to enable DEBUG mode
process.env.DEBUG_GOOGLE_CHAT_COMPLETION = '1';

// Mock the Google AI SDK generateContentStream method to return a successful response stream
const mockStream = new ReadableStream({
start(controller) {
controller.enqueue('Debug mode test');
controller.close();
},
});
vi.spyOn(instance['client'], 'getGenerativeModel').mockReturnValue({
generateContentStream: vi.fn().mockResolvedValueOnce(mockStream),
} as any);
const debugStreamSpy = vi
.spyOn(debugStreamModule, 'debugStream')
.mockImplementation(() => Promise.resolve());

await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'text-davinci-003',
temperature: 0,
});

expect(debugStreamSpy).toHaveBeenCalled();

// Clean up the environment variable
delete process.env.DEBUG_GOOGLE_CHAT_COMPLETION;
});

describe('Error', () => {
it('should throw InvalidGoogleAPIKey error on API_KEY_INVALID error', async () => {
// Mock the Google AI SDK throwing an error
const message = `[GoogleGenerativeAI Error]: Error fetching from https://generativelanguage.googleapis.com/v1/models/gemini-pro:streamGenerateContent?alt=sse: [400 Bad Request] API key not valid. Please pass a valid API key. [{"@type":"type.googleapis.com/google.rpc.ErrorInfo","reason":"API_KEY_INVALID","domain":"googleapis.com","metadata":{"service":"generativelanguage.googleapis.com"}}]`;

const apiError = new Error(message);

vi.spyOn(instance['client'], 'getGenerativeModel').mockReturnValue({
generateContentStream: vi.fn().mockRejectedValue(apiError),
} as any);

try {
await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'text-davinci-003',
temperature: 0,
});
} catch (e) {
expect(e).toEqual({ errorType: invalidErrorType, error: { message }, provider });
}
});

it('should throw LocationNotSupportError error on location not support error', async () => {
// Mock the Google AI SDK throwing an error
const message = `[GoogleGenerativeAI Error]: Error fetching from https://generativelanguage.googleapis.com/v1/models/gemini-pro:streamGenerateContent?alt=sse: [400 Bad Request] User location is not supported for the API use.`;

const apiError = new Error(message);

vi.spyOn(instance['client'], 'getGenerativeModel').mockReturnValue({
generateContentStream: vi.fn().mockRejectedValue(apiError),
} as any);

try {
await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'text-davinci-003',
temperature: 0,
});
} catch (e) {
expect(e).toEqual({ errorType: 'LocationNotSupportError', error: { message }, provider });
}
});

it('should throw BizError error', async () => {
// Mock the Google AI SDK throwing an error
const message = `[GoogleGenerativeAI Error]: Error fetching from https://generativelanguage.googleapis.com/v1/models/gemini-pro:streamGenerateContent?alt=sse: [400 Bad Request] API key not valid. Please pass a valid API key. [{"@type":"type.googleapis.com/google.rpc.ErrorInfo","reason":"Error","domain":"googleapis.com","metadata":{"service":"generativelanguage.googleapis.com"}}]`;

const apiError = new Error(message);

vi.spyOn(instance['client'], 'getGenerativeModel').mockReturnValue({
generateContentStream: vi.fn().mockRejectedValue(apiError),
} as any);

try {
await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'text-davinci-003',
temperature: 0,
});
} catch (e) {
expect(e).toEqual({
errorType: bizErrorType,
error: [
{
'@type': 'type.googleapis.com/google.rpc.ErrorInfo',
'domain': 'googleapis.com',
'metadata': {
service: 'generativelanguage.googleapis.com',
},
'reason': 'Error',
},
],
provider,
});
}
});

it('should throw DefaultError error', async () => {
// Mock the Google AI SDK throwing an error
const message = `[GoogleGenerativeAI Error]: Error fetching from https://generativelanguage.googleapis.com/v1/models/gemini-pro:streamGenerateContent?alt=sse: [400 Bad Request] API key not valid. Please pass a valid API key. [{"@type":"type.googleapis.com/google.rpc.ErrorInfo","reason":"Error","domain":"googleapis.com","metadata":{"service":"generativelanguage.googleapis.com}}]`;

const apiError = new Error(message);

vi.spyOn(instance['client'], 'getGenerativeModel').mockReturnValue({
generateContentStream: vi.fn().mockRejectedValue(apiError),
} as any);

try {
await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'text-davinci-003',
temperature: 0,
});
} catch (e) {
expect(e).toEqual({
errorType: bizErrorType,
error: {
message: `[GoogleGenerativeAI Error]: Error fetching from https://generativelanguage.googleapis.com/v1/models/gemini-pro:streamGenerateContent?alt=sse: [400 Bad Request] API key not valid. Please pass a valid API key. [{"@type":"type.googleapis.com/google.rpc.ErrorInfo","reason":"Error","domain":"googleapis.com","metadata":{"service":"generativelanguage.googleapis.com}}]`,
},
provider,
});
}
});

it('should return GoogleBizError with an openai error response when APIError is thrown', async () => {
// Arrange
const apiError = new Error('Error message');
@@ -147,4 +300,65 @@ describe('LobeGoogleAI', () => {
});
});
});

describe('private method', () => {
describe('convertContentToGooglePart', () => {
it('should throw TypeError when image URL does not contain base64 data', () => {
// Provide an image URL that does not contain base64 data
const invalidImageUrl = 'http://example.com/image.png';

expect(() =>
instance['convertContentToGooglePart']({
type: 'image_url',
image_url: { url: invalidImageUrl },
}),
).toThrow(TypeError);
});
});

describe('buildGoogleMessages', () => {
it('should use default text model when no images are included in messages', () => {
const messages: OpenAIChatMessage[] = [
{ content: 'Hello', role: 'user' },
{ content: 'Hi', role: 'assistant' },
];
const model = 'text-davinci-003';

// Call the buildGoogleMessages method
const { contents, model: usedModel } = instance['buildGoogleMessages'](messages, model);

expect(usedModel).toEqual('gemini-pro'); // assuming 'gemini-pro' is the default text model
expect(contents).toHaveLength(2);
expect(contents).toEqual([
{ parts: [{ text: 'Hello' }], role: 'user' },
{ parts: [{ text: 'Hi' }], role: 'model' },
]);
});

it('should use specified model when images are included in messages', () => {
const messages: OpenAIChatMessage[] = [
{
content: [
{ type: 'text', text: 'Hello' },
{ type: 'image_url', image_url: { url: 'data:image/png;base64,...' } },
],
role: 'user',
},
];
const model = 'gemini-pro-vision';

// Call the buildGoogleMessages method
const { contents, model: usedModel } = instance['buildGoogleMessages'](messages, model);

expect(usedModel).toEqual(model);
expect(contents).toHaveLength(1);
expect(contents).toEqual([
{
parts: [{ text: 'Hello' }, { inlineData: { data: '...', mimeType: 'image/png' } }],
role: 'user',
},
]);
});
});
});
});
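The private-method tests above pin down the conversion contract between OpenAI-style messages and Google's Content format; a rough sketch of that conversion follows (helper names and shapes are assumptions for illustration, not the actual implementation in index.ts, and the default-model switch asserted by the first test is omitted):

import { Content, Part } from '@google/generative-ai';

// Assumed message shape, trimmed to what the tests exercise.
type SketchMessage = { role: 'user' | 'assistant'; content: string | any[] };

const toGooglePart = (item: any): Part => {
  if (item.type === 'text') return { text: item.text };

  // Only base64 data URLs are accepted; a plain http(s) URL throws,
  // which is what the convertContentToGooglePart test expects.
  const match = /^data:(.+?);base64,(.*)$/.exec(item.image_url.url);
  if (!match) throw new TypeError('Image URL must contain base64 data');

  return { inlineData: { data: match[2], mimeType: match[1] } };
};

const toGoogleContents = (messages: SketchMessage[]): Content[] =>
  messages.map((message) => ({
    parts:
      typeof message.content === 'string'
        ? [{ text: message.content }]
        : message.content.map(toGooglePart),
    // The assistant role maps to 'model', as the buildGoogleMessages tests assert.
    role: message.role === 'assistant' ? 'model' : 'user',
  }));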
3 changes: 2 additions & 1 deletion src/libs/agent-runtime/google/index.ts
@@ -151,9 +151,10 @@ export class LobeGoogleAI implements LobeRuntimeAI {
error: { message },
errorType: AgentRuntimeErrorType.GoogleBizError,
};
console.log(message);

if (message.includes('location is not supported'))
return { error: message, errorType: AgentRuntimeErrorType.LocationNotSupportError };
return { error: { message }, errorType: AgentRuntimeErrorType.LocationNotSupportError };

try {
const startIndex = message.lastIndexOf('[');
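The diff above is truncated right at the parsing branch; below is a rough sketch of what the error tests earlier in this commit imply it does (an assumed helper, not the verbatim implementation):

// Assumed sketch: extract the JSON array of google.rpc.ErrorInfo objects that the
// SDK appends to its error message, e.g. `... [400 Bad Request] ... [{"@type": ...}]`.
const parseGoogleApiError = (message: string): object[] | undefined => {
  const startIndex = message.lastIndexOf('[');
  if (startIndex === -1) return undefined;

  try {
    // When this parses, the runtime can report GoogleBizError with the structured
    // array (the bizErrorType test); when it does not (as in the DefaultError test,
    // whose payload has an unterminated string), the raw message is kept instead.
    return JSON.parse(message.slice(startIndex));
  } catch {
    return undefined;
  }
};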