Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 6 additions & 36 deletions src/providers/google-vertex-ai/chatComplete.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import {
ContentType,
Message,
Params,
Tool,
ToolCall,
SYSTEM_MESSAGE_ROLES,
MESSAGE_ROLES,
Expand Down Expand Up @@ -46,27 +45,16 @@ import type {
GoogleGenerateContentResponse,
VertexLlamaChatCompleteStreamChunk,
VertexLLamaChatCompleteResponse,
GoogleSearchRetrievalTool,
} from './types';
import {
getMimeType,
googleTools,
recursivelyDeleteUnsupportedParameters,
transformGeminiToolParameters,
transformGoogleTools,
transformInputAudioPart,
transformVertexLogprobs,
} from './utils';

export const buildGoogleSearchRetrievalTool = (tool: Tool) => {
const googleSearchRetrievalTool: GoogleSearchRetrievalTool = {
googleSearchRetrieval: {},
};
if (tool.function.parameters?.dynamicRetrievalConfig) {
googleSearchRetrievalTool.googleSearchRetrieval.dynamicRetrievalConfig =
tool.function.parameters.dynamicRetrievalConfig;
}
return googleSearchRetrievalTool;
};

export const VertexGoogleChatCompleteConfig: ProviderConfig = {
// https://cloud.google.com/vertex-ai/generative-ai/docs/learn/model-versioning#gemini-model-versions
model: {
Expand Down Expand Up @@ -296,27 +284,9 @@ export const VertexGoogleChatCompleteConfig: ProviderConfig = {
// these are not supported by google
recursivelyDeleteUnsupportedParameters(tool.function?.parameters);
delete tool.function?.strict;

if (['googleSearch', 'google_search'].includes(tool.function.name)) {
const timeRangeFilter = tool.function.parameters?.timeRangeFilter;
tools.push({
googleSearch: {
// allow null
...(timeRangeFilter !== undefined && { timeRangeFilter }),
},
});
} else if (
['googleSearchRetrieval', 'google_search_retrieval'].includes(
tool.function.name
)
) {
tools.push(buildGoogleSearchRetrievalTool(tool));
if (googleTools.includes(tool.function.name)) {
tools.push(...transformGoogleTools(tool));
} else {
if (tool.function?.parameters) {
tool.function.parameters = transformGeminiToolParameters(
tool.function.parameters
);
}
functionDeclarations.push(tool.function);
}
}
Expand Down Expand Up @@ -359,11 +329,11 @@ export const VertexGoogleChatCompleteConfig: ProviderConfig = {
param: 'generationConfig',
transform: (params: Params) => transformGenerationConfig(params),
},
seed: {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are we removing seed?

modalities: {
param: 'generationConfig',
transform: (params: Params) => transformGenerationConfig(params),
},
modalities: {
seed: {
param: 'generationConfig',
transform: (params: Params) => transformGenerationConfig(params),
},
Expand Down
51 changes: 50 additions & 1 deletion src/providers/google-vertex-ai/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import {
GoogleResponseCandidate,
GoogleBatchRecord,
GoogleFinetuneRecord,
GoogleSearchRetrievalTool,
} from './types';
import { generateErrorResponse } from '../utils';
import {
Expand All @@ -13,7 +14,7 @@ import {
import { ErrorResponse, FinetuneRequest, Logprobs } from '../types';
import { Context } from 'hono';
import { env } from 'hono/adapter';
import { ContentType, JsonSchema } from '../../types/requestBody';
import { ContentType, JsonSchema, Tool } from '../../types/requestBody';

/**
* Encodes an object as a Base64 URL-encoded string.
Expand Down Expand Up @@ -729,3 +730,51 @@ export const transformInputAudioPart = (c: ContentType) => {
},
};
};

export const googleTools = [
'googleSearch',
'google_search',
'googleSearchRetrieval',
'google_search_retrieval',
'computerUse',
'computer_use',
];

export const transformGoogleTools = (tool: Tool) => {
const tools: any = [];
if (['googleSearch', 'google_search'].includes(tool.function.name)) {
const timeRangeFilter = tool.function.parameters?.timeRangeFilter;
tools.push({
googleSearch: {
// allow null
...(timeRangeFilter !== undefined && { timeRangeFilter }),
},
});
} else if (
['googleSearchRetrieval', 'google_search_retrieval'].includes(
tool.function.name
)
) {
tools.push(buildGoogleSearchRetrievalTool(tool));
} else if (['computerUse', 'computer_use'].includes(tool.function.name)) {
tools.push({
computerUse: {
environment: tool.function.parameters?.environment,
excludedPredefinedFunctions:
tool.function.parameters?.excluded_predefined_functions,
},
});
}
return tools;
};

export const buildGoogleSearchRetrievalTool = (tool: Tool) => {
const googleSearchRetrievalTool: GoogleSearchRetrievalTool = {
googleSearchRetrieval: {},
};
if (tool.function.parameters?.dynamicRetrievalConfig) {
googleSearchRetrievalTool.googleSearchRetrieval.dynamicRetrievalConfig =
tool.function.parameters.dynamicRetrievalConfig;
}
return googleSearchRetrievalTool;
};
14 changes: 4 additions & 10 deletions src/providers/google/chatComplete.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@ import {
SYSTEM_MESSAGE_ROLES,
MESSAGE_ROLES,
} from '../../types/requestBody';
import { buildGoogleSearchRetrievalTool } from '../google-vertex-ai/chatComplete';
import {
getMimeType,
googleTools,
recursivelyDeleteUnsupportedParameters,
transformGeminiToolParameters,
transformGoogleTools,
transformInputAudioPart,
transformVertexLogprobs,
} from '../google-vertex-ai/utils';
Expand Down Expand Up @@ -374,15 +375,8 @@ export const GoogleChatCompleteConfig: ProviderConfig = {
// these are not supported by google
recursivelyDeleteUnsupportedParameters(tool.function?.parameters);
delete tool.function?.strict;

if (['googleSearch', 'google_search'].includes(tool.function.name)) {
tools.push({ googleSearch: {} });
} else if (
['googleSearchRetrieval', 'google_search_retrieval'].includes(
tool.function.name
)
) {
tools.push(buildGoogleSearchRetrievalTool(tool));
if (googleTools.includes(tool.function.name)) {
tools.push(...transformGoogleTools(tool));
} else {
if (tool.function?.parameters) {
tool.function.parameters = transformGeminiToolParameters(
Expand Down