diff --git a/package.json b/package.json
index 866825a53460..5e790183ca0f 100644
--- a/package.json
+++ b/package.json
@@ -84,7 +84,7 @@
   "@aws-sdk/client-bedrock-runtime": "^3.525.0",
   "@azure/openai": "^1.0.0-beta.11",
   "@cfworker/json-schema": "^1",
-  "@google/generative-ai": "^0.2.0",
+  "@google/generative-ai": "^0.3.1",
   "@icons-pack/react-simple-icons": "^9",
   "@lobehub/chat-plugin-sdk": "latest",
   "@lobehub/chat-plugins-gateway": "latest",
diff --git a/src/app/api/chat/google/route.ts b/src/app/api/chat/google/route.ts
index 972bfb875eaf..6b75b94ff85b 100644
--- a/src/app/api/chat/google/route.ts
+++ b/src/app/api/chat/google/route.ts
@@ -13,19 +13,7 @@ import { POST as UniverseRoute } from '../[provider]/route';
 // so if you want to use it with a proxy, you need to comment out the code below
 export const runtime = 'edge';
 
-export const preferredRegion = [
-  'bom1',
-  'cle1',
-  'cpt1',
-  'gru1',
-  'hnd1',
-  'iad1',
-  'icn1',
-  'kix1',
-  'pdx1',
-  'sfo1',
-  'sin1',
-  'syd1',
-];
+// gemini-1.5-pro is currently only available in the US, so restrict the preferred regions to US locations
+export const preferredRegion = ['cle1', 'iad1', 'pdx1', 'sfo1'];
 
 export const POST = async (req: Request) => UniverseRoute(req, { params: { provider: 'google' } });
diff --git a/src/config/modelProviders/google.ts b/src/config/modelProviders/google.ts
index b370066b56e9..8767bd9d28b7 100644
--- a/src/config/modelProviders/google.ts
+++ b/src/config/modelProviders/google.ts
@@ -3,25 +3,86 @@ import { ModelProviderCard } from '@/types/llm';
 const Google: ModelProviderCard = {
   chatModels: [
     {
-      displayName: 'Gemini Pro',
+      description: 'A legacy text-only model optimized for chat conversations',
+      displayName: 'PaLM 2 Chat (Legacy)',
+      hidden: true,
+      id: 'chat-bison-001',
+      maxOutput: 1024,
+      tokens: 5120,
+    },
+    {
+      description: 'A legacy model that understands text and generates text as an output',
+      displayName: 'PaLM 2 (Legacy)',
+      hidden: true,
+      id: 'text-bison-001',
+      maxOutput: 1024,
+      tokens: 9220,
+    },
+    {
+      description: 'The best model for scaling across a wide range of tasks',
+      displayName: 'Gemini 1.0 Pro',
       id: 'gemini-pro',
-      tokens: 30_720,
+      maxOutput: 2048,
+      tokens: 32_768,
+    },
+    {
+      description: 'The best image understanding model to handle a broad range of applications',
+      displayName: 'Gemini 1.0 Pro Vision',
+      id: 'gemini-1.0-pro-vision-latest',
+      maxOutput: 4096,
+      tokens: 16_384,
+      vision: true,
     },
     {
-      displayName: 'Gemini Pro Vision',
+      description: 'The best image understanding model to handle a broad range of applications',
+      displayName: 'Gemini 1.0 Pro Vision',
+      hidden: true,
       id: 'gemini-pro-vision',
-      tokens: 12_288,
+      maxOutput: 4096,
+      tokens: 16_384,
       vision: true,
     },
     {
+      description: 'The best model for scaling across a wide range of tasks',
+      displayName: 'Gemini 1.0 Pro',
+      hidden: true,
+      id: 'gemini-1.0-pro',
+      maxOutput: 2048,
+      tokens: 32_768,
+    },
+    {
+      description:
+        'The best model for scaling across a wide range of tasks. This is a stable model that supports tuning.',
+      displayName: 'Gemini 1.0 Pro 001 (Tuning)',
+      hidden: true,
+      id: 'gemini-1.0-pro-001',
+      maxOutput: 2048,
+      tokens: 32_768,
+    },
+    {
+      description:
+        'The best model for scaling across a wide range of tasks. This is the latest model.',
+      displayName: 'Gemini 1.0 Pro Latest',
+      hidden: true,
+      id: 'gemini-1.0-pro-latest',
+      maxOutput: 2048,
+      tokens: 32_768,
+    },
+    {
+      description: 'Mid-size multimodal model that supports up to 1 million tokens',
       displayName: 'Gemini 1.5 Pro',
       id: 'gemini-1.5-pro-latest',
-      tokens: 1_048_576,
+      maxOutput: 8192,
+      tokens: 1_056_768,
+      vision: true,
     },
     {
-      displayName: 'Gemini Ultra',
+      description: 'The most capable model for highly complex tasks',
+      displayName: 'Gemini 1.0 Ultra',
+      hidden: true,
       id: 'gemini-ultra-latest',
-      tokens: 30_720,
+      maxOutput: 2048,
+      tokens: 32_768,
     },
   ],
   id: 'google',
diff --git a/src/libs/agent-runtime/google/index.test.ts b/src/libs/agent-runtime/google/index.test.ts
index af6fd6afcaa3..323738480730 100644
--- a/src/libs/agent-runtime/google/index.test.ts
+++ b/src/libs/agent-runtime/google/index.test.ts
@@ -1,5 +1,4 @@
 // @vitest-environment edge-runtime
-import { GenerateContentRequest, GenerateContentStreamResult, Part } from '@google/generative-ai';
 import OpenAI from 'openai';
 import { Mock, afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
 
@@ -317,17 +316,55 @@ describe('LobeGoogleAI', () => {
   });
 
   describe('buildGoogleMessages', () => {
-    it('should use default text model when no images are included in messages', () => {
+    it('get default result with gemini-pro', () => {
+      const messages: OpenAIChatMessage[] = [{ content: 'Hello', role: 'user' }];
+
+      const contents = instance['buildGoogleMessages'](messages, 'gemini-pro');
+
+      expect(contents).toHaveLength(1);
+      expect(contents).toEqual([{ parts: [{ text: 'Hello' }], role: 'user' }]);
+    });
+
+    it('messages should end with user if using gemini-pro', () => {
       const messages: OpenAIChatMessage[] = [
         { content: 'Hello', role: 'user' },
         { content: 'Hi', role: 'assistant' },
       ];
-      const model = 'text-davinci-003';
 
-      // call the buildGoogleMessages method
-      const { contents, model: usedModel } = instance['buildGoogleMessages'](messages, model);
+      const contents = instance['buildGoogleMessages'](messages, 'gemini-pro');
+
+      expect(contents).toHaveLength(3);
+      expect(contents).toEqual([
+        { parts: [{ text: 'Hello' }], role: 'user' },
+        { parts: [{ text: 'Hi' }], role: 'model' },
+        { parts: [{ text: '' }], role: 'user' },
+      ]);
+    });
+
+    it('should include system role if there is a system role prompt', () => {
+      const messages: OpenAIChatMessage[] = [
+        { content: 'you are ChatGPT', role: 'system' },
+        { content: 'Who are you', role: 'user' },
+      ];
+
+      const contents = instance['buildGoogleMessages'](messages, 'gemini-pro');
+
+      expect(contents).toHaveLength(3);
+      expect(contents).toEqual([
+        { parts: [{ text: 'you are ChatGPT' }], role: 'user' },
+        { parts: [{ text: '' }], role: 'model' },
+        { parts: [{ text: 'Who are you' }], role: 'user' },
+      ]);
+    });
+
+    it('should not modify the length if model is gemini-1.5-pro', () => {
+      const messages: OpenAIChatMessage[] = [
+        { content: 'Hello', role: 'user' },
+        { content: 'Hi', role: 'assistant' },
+      ];
+
+      const contents = instance['buildGoogleMessages'](messages, 'gemini-1.5-pro-latest');
 
-      expect(usedModel).toEqual('gemini-pro'); // assume 'gemini-pro' is the default text model
       expect(contents).toHaveLength(2);
       expect(contents).toEqual([
         { parts: [{ text: 'Hello' }], role: 'user' },
@@ -348,9 +385,8 @@ describe('LobeGoogleAI', () => {
       const model = 'gemini-pro-vision';
 
       // call the buildGoogleMessages method
-      const { contents, model: usedModel } = instance['buildGoogleMessages'](messages, model);
+      const contents = instance['buildGoogleMessages'](messages, model);
 
-      expect(usedModel).toEqual(model);
       expect(contents).toHaveLength(1);
       expect(contents).toEqual([
         {
@@ -360,5 +396,35 @@ describe('LobeGoogleAI', () => {
         ]);
       });
     });
+
+    describe('convertModel', () => {
+      it('should use default text model when no images are included in messages', () => {
+        const messages: OpenAIChatMessage[] = [
+          { content: 'Hello', role: 'user' },
+          { content: 'Hi', role: 'assistant' },
+        ];
+
+        // call the convertModel method
+        const model = instance['convertModel']('gemini-pro-vision', messages);
+
+        expect(model).toEqual('gemini-pro'); // assume 'gemini-pro' is the default text model
+      });
+
+      it('should use specified model when images are included in messages', () => {
+        const messages: OpenAIChatMessage[] = [
+          {
+            content: [
+              { type: 'text', text: 'Hello' },
+              { type: 'image_url', image_url: { url: 'data:image/png;base64,...' } },
+            ],
+            role: 'user',
+          },
+        ];
+
+        const model = instance['convertModel']('gemini-pro-vision', messages);
+
+        expect(model).toEqual('gemini-pro-vision');
+      });
+    });
   });
 });
diff --git a/src/libs/agent-runtime/google/index.ts b/src/libs/agent-runtime/google/index.ts
index 5c50ca4bf4c8..e503663705f7 100644
--- a/src/libs/agent-runtime/google/index.ts
+++ b/src/libs/agent-runtime/google/index.ts
@@ -14,17 +14,6 @@ import { AgentRuntimeError } from '../utils/createError';
 import { debugStream } from '../utils/debugStream';
 import { parseDataUri } from '../utils/uriParser';
 
-type GoogleChatErrors = GoogleChatError[];
-
-interface GoogleChatError {
-  '@type': string;
-  'domain': string;
-  'metadata': {
-    service: string;
-  };
-  'reason': string;
-}
-
 enum HarmCategory {
   HARM_CATEGORY_DANGEROUS_CONTENT = 'HARM_CATEGORY_DANGEROUS_CONTENT',
   HARM_CATEGORY_HARASSMENT = 'HARM_CATEGORY_HARASSMENT',
@@ -47,34 +36,42 @@ export class LobeGoogleAI implements LobeRuntimeAI {
 
   async chat(payload: ChatStreamPayload, options?: ChatCompetitionOptions) {
     try {
-      const { contents, model } = this.buildGoogleMessages(payload.messages, payload.model);
+      const model = this.convertModel(payload.model, payload.messages);
+
+      const contents = this.buildGoogleMessages(payload.messages, model);
+
       const geminiStream = await this.client
-        .getGenerativeModel({
-          generationConfig: {
-            maxOutputTokens: payload.max_tokens,
-            temperature: payload.temperature,
-            topP: payload.top_p,
-          },
-          model,
-          safetySettings: [
-            {
-              category: HarmCategory.HARM_CATEGORY_HATE_SPEECH,
-              threshold: HarmBlockThreshold.BLOCK_NONE,
-            },
-            {
-              category: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
-              threshold: HarmBlockThreshold.BLOCK_NONE,
-            },
-            {
-              category: HarmCategory.HARM_CATEGORY_HARASSMENT,
-              threshold: HarmBlockThreshold.BLOCK_NONE,
-            },
-            {
-              category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
-              threshold: HarmBlockThreshold.BLOCK_NONE,
-            },
-          ],
-        })
+        .getGenerativeModel(
+          {
+            generationConfig: {
+              maxOutputTokens: payload.max_tokens,
+              temperature: payload.temperature,
+              topP: payload.top_p,
+            },
+            model,
+            // relax the safety thresholds to avoid over-aggressive blocking of sensitive words
+            // refs: https://github.com/lobehub/lobe-chat/pull/1418
+            safetySettings: [
+              {
+                category: HarmCategory.HARM_CATEGORY_HATE_SPEECH,
+                threshold: HarmBlockThreshold.BLOCK_NONE,
+              },
+              {
+                category: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
+                threshold: HarmBlockThreshold.BLOCK_NONE,
+              },
+              {
+                category: HarmCategory.HARM_CATEGORY_HARASSMENT,
+                threshold: HarmBlockThreshold.BLOCK_NONE,
+              },
+              {
+                category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
+                threshold: HarmBlockThreshold.BLOCK_NONE,
+              },
+            ],
+          },
+          { apiVersion: 'v1beta' },
+        )
         .generateContentStream({ contents });
 
       // Convert the response into a friendly text-stream
@@ -127,25 +124,64 @@ export class LobeGoogleAI implements LobeRuntimeAI {
         typeof content === 'string'
           ? [{ text: content }]
           : content.map((c) => this.convertContentToGooglePart(c)),
-      role: message.role === 'user' ? 'user' : 'model',
+      role: message.role === 'assistant' ? 'model' : 'user',
     };
   };
 
   // convert messages from the Vercel AI SDK Format to the format
   // that is expected by the Google GenAI SDK
-  private buildGoogleMessages = (
-    messages: OpenAIChatMessage[],
-    model: string,
-  ): { contents: Content[]; model: string } => {
-    const contents = messages
-      .filter((message) => message.role === 'user' || message.role === 'assistant')
-      .map((msg) => this.convertOAIMessagesToGoogleMessage(msg));
-
-    // if message are all text message, use vision will return error
-    // use add an image to use models/gemini-pro-vision, or switch your model to a text model
-    const noImage = messages.every((m) => typeof m.content === 'string');
-
-    return { contents, model: noImage ? 'gemini-pro' : model };
+  private buildGoogleMessages = (messages: OpenAIChatMessage[], model: string): Content[] => {
+    // gemini-1.5-pro-latest accepts any role sequence, so it needs no special handling
+    if (model === 'gemini-1.5-pro-latest') {
+      return messages
+        .filter((message) => message.role !== 'function')
+        .map((msg) => this.convertOAIMessagesToGoogleMessage(msg));
+    }
+
+    const contents: Content[] = [];
+    let lastRole = 'model';
+
+    messages.forEach((message) => {
+      // function messages are not supported yet, so filter them out
+      if (message.role === 'function') {
+        return;
+      }
+      const googleMessage = this.convertOAIMessagesToGoogleMessage(message);
+
+      // if the current message has the same role as the previous one,
+      // insert an empty message of the opposite role to keep the roles alternating
+      if (lastRole === googleMessage.role) {
+        contents.push({ parts: [{ text: '' }], role: lastRole === 'user' ? 'model' : 'user' });
+      }
+
+      // add the current message to the contents
+      contents.push(googleMessage);
+
+      // update the last role
+      lastRole = googleMessage.role;
+    });
+
+    // Gemini requires the conversation to end with a user message, so append an empty one if needed
+    if (lastRole === 'model') {
+      contents.push({ parts: [{ text: '' }], role: 'user' });
+    }
+
+    return contents;
+  };
+
+  private convertModel = (model: string, messages: OpenAIChatMessage[]) => {
+    let finalModel: string = model;
+
+    if (model.includes('pro-vision')) {
+      // if the messages are all plain text, calling the vision model returns an error:
+      // "[400 Bad Request] Add an image to use models/gemini-pro-vision, or switch your model to a text model."
+      const noNeedVision = messages.every((m) => typeof m.content === 'string');
+
+      // so we downgrade to gemini-pro in that case
+      if (noNeedVision) finalModel = 'gemini-pro';
+    }
+
+    return finalModel;
   };
 
   private parseErrorMessage(message: string): {
@@ -191,3 +227,14 @@ export class LobeGoogleAI implements LobeRuntimeAI {
   }
 }
 export default LobeGoogleAI;
+
+type GoogleChatErrors = GoogleChatError[];
+
+interface GoogleChatError {
+  '@type': string;
+  'domain': string;
+  'metadata': {
+    service: string;
+  };
+  'reason': string;
+}
diff --git a/src/types/llm.ts b/src/types/llm.ts
index f7583f36a136..3d8ff2d19da0 100644
--- a/src/types/llm.ts
+++ b/src/types/llm.ts
@@ -20,6 +20,9 @@ export interface ChatModelCard {
    */
   legacy?: boolean;
   maxOutput?: number;
+  /**
+   * the size of the context window, in tokens
+   */
   tokens?: number;
   /**
    * whether model supports vision
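
---

Reviewer note: the subtle part of this change is the role-alternation handling in `buildGoogleMessages`. For every model except `gemini-1.5-pro-latest`, Gemini expects strictly alternating `user`/`model` turns that end with a user turn. Below is a minimal standalone sketch of that logic under stated assumptions: `padToAlternatingRoles` is a hypothetical free function operating on already-converted `Content` objects, whereas the PR implements it as a private method that also converts and filters messages.

```ts
import type { Content } from '@google/generative-ai';

const padToAlternatingRoles = (messages: Content[]): Content[] => {
  const contents: Content[] = [];
  // starting from 'model' forces a leading empty user turn
  // if the first real message already comes from the model
  let lastRole: string = 'model';

  for (const message of messages) {
    // two consecutive turns with the same role: insert an empty turn
    // of the opposite role so the sequence stays alternating
    if (lastRole === message.role) {
      contents.push({ parts: [{ text: '' }], role: lastRole === 'user' ? 'model' : 'user' });
    }
    contents.push(message);
    lastRole = message.role;
  }

  // the conversation must end with a user turn
  if (lastRole === 'model') {
    contents.push({ parts: [{ text: '' }], role: 'user' });
  }

  return contents;
};

// [user, model] becomes [user, model, user], matching the new test above
console.log(
  padToAlternatingRoles([
    { parts: [{ text: 'Hello' }], role: 'user' },
    { parts: [{ text: 'Hi' }], role: 'model' },
  ]),
);
```

This is also why the system-prompt test expects three entries: a system message is mapped to the `user` role first, and the padding then inserts an empty `model` turn between it and the next user message.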
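The vision fallback in `convertModel` is simpler but worth calling out: `gemini-pro-vision` rejects text-only requests, so the runtime silently downgrades to the text model. A self-contained sketch, with the `OpenAIChatMessage` shape reduced to the one field the check actually reads (`ChatMessageLike` is an illustrative name, not a type from the PR):

```ts
type MessageContent = string | { type: string }[];

interface ChatMessageLike {
  content: MessageContent;
  role: string;
}

const convertModel = (model: string, messages: ChatMessageLike[]): string => {
  if (model.includes('pro-vision')) {
    // text-only payloads make the vision model fail with
    // "[400 Bad Request] Add an image to use models/gemini-pro-vision, ..."
    const noNeedVision = messages.every((m) => typeof m.content === 'string');
    if (noNeedVision) return 'gemini-pro';
  }
  return model;
};

console.log(convertModel('gemini-pro-vision', [{ content: 'Hello', role: 'user' }])); // 'gemini-pro'
```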
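Finally, the second argument to `getGenerativeModel` is what actually unlocks Gemini 1.5 here: `@google/generative-ai@^0.3.1` accepts a request-options object, and `gemini-1.5-pro-latest` is only served on the `v1beta` API surface. A trimmed usage sketch outside the runtime class, assuming the same SDK version as the PR (the API key is a placeholder, and only two of the four relaxed safety categories are shown):

```ts
import { GoogleGenerativeAI, HarmBlockThreshold, HarmCategory } from '@google/generative-ai';

const main = async () => {
  const client = new GoogleGenerativeAI('YOUR_API_KEY');

  const model = client.getGenerativeModel(
    {
      model: 'gemini-1.5-pro-latest',
      // BLOCK_NONE mirrors the relaxed thresholds the runtime sets
      safetySettings: [
        {
          category: HarmCategory.HARM_CATEGORY_HATE_SPEECH,
          threshold: HarmBlockThreshold.BLOCK_NONE,
        },
        {
          category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
          threshold: HarmBlockThreshold.BLOCK_NONE,
        },
      ],
    },
    // request options: route the call to the v1beta REST surface
    { apiVersion: 'v1beta' },
  );

  const result = await model.generateContentStream({
    contents: [{ parts: [{ text: 'Hello' }], role: 'user' }],
  });

  // print the streamed chunks as they arrive
  for await (const chunk of result.stream) {
    process.stdout.write(chunk.text());
  }
};

main();
```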