From ed35a1345eff3e25678fb837ab310c88efa50adf Mon Sep 17 00:00:00 2001 From: Brace Sproul Date: Thu, 25 Jul 2024 15:42:53 -0700 Subject: [PATCH] google-common[minor]: Add tool choice param (#6195) * google-common[minor]: Add tool choice param * format * drop console logs * chore: lint files * chore: lint files * cr * cr * fix tool calling issues * chore: lint files --- .../src/chat_models.ts | 52 +----- .../langchain-google-common/src/connection.ts | 19 +++ .../src/tests/chat_models.test.ts | 46 +---- libs/langchain-google-common/src/types.ts | 25 +++ .../src/utils/common.ts | 160 +++++++++++------- .../src/tests/chat_models.int.test.ts | 37 +++- 6 files changed, 191 insertions(+), 148 deletions(-) diff --git a/libs/langchain-google-common/src/chat_models.ts b/libs/langchain-google-common/src/chat_models.ts index ca63052bacbe..7f3146fb2b20 100644 --- a/libs/langchain-google-common/src/chat_models.ts +++ b/libs/langchain-google-common/src/chat_models.ts @@ -13,7 +13,6 @@ import { BaseLanguageModelInput, StructuredOutputMethodOptions, ToolDefinition, - isOpenAITool, } from "@langchain/core/language_models/base"; import type { z } from "zod"; import { @@ -24,7 +23,6 @@ import { } from "@langchain/core/runnables"; import { JsonOutputKeyToolsParser } from "@langchain/core/output_parsers/openai_tools"; import { BaseLLMOutputParser } from "@langchain/core/output_parsers"; -import { isStructuredTool } from "@langchain/core/utils/function_calling"; import { AsyncCaller } from "@langchain/core/utils/async_caller"; import { StructuredToolInterface } from "@langchain/core/tools"; import { concat } from "@langchain/core/utils/stream"; @@ -39,6 +37,7 @@ import { GoogleAIBaseLanguageModelCallOptions, } from "./types.js"; import { + convertToGeminiTools, copyAIModelParams, copyAndValidateModelParamsInto, } from "./utils/common.js"; @@ -59,10 +58,7 @@ import type { GeminiFunctionDeclaration, GeminiFunctionSchema, } from "./types.js"; -import { - jsonSchemaToGeminiParameters, - zodToGeminiParameters, -} from "./utils/zod_to_gemini_parameters.js"; +import { zodToGeminiParameters } from "./utils/zod_to_gemini_parameters.js"; class ChatConnection extends AbstractGoogleLLMConnection< BaseMessage[], @@ -160,44 +156,6 @@ export interface ChatGoogleBaseInput GoogleAISafetyParams, Pick {} -function convertToGeminiTools( - structuredTools: ( - | StructuredToolInterface - | Record - | ToolDefinition - | RunnableToolLike - )[] -): GeminiTool[] { - return [ - { - functionDeclarations: structuredTools.map( - (structuredTool): GeminiFunctionDeclaration => { - if (isStructuredTool(structuredTool)) { - const jsonSchema = zodToGeminiParameters(structuredTool.schema); - return { - name: structuredTool.name, - description: structuredTool.description, - parameters: jsonSchema as GeminiFunctionSchema, - }; - } - if (isOpenAITool(structuredTool)) { - return { - name: structuredTool.function.name, - description: - structuredTool.function.description ?? - `A function available to call.`, - parameters: jsonSchemaToGeminiParameters( - structuredTool.function.parameters - ), - }; - } - return structuredTool as unknown as GeminiFunctionDeclaration; - } - ), - }, - ]; -} - /** * Integration with a chat model. */ @@ -342,12 +300,6 @@ export abstract class ChatGoogleBase * Get the parameters used to invoke the model */ override invocationParams(options?: this["ParsedCallOptions"]) { - if (options?.tool_choice) { - throw new Error( - `'tool_choice' call option is not supported by ${this.getName()}.` - ); - } - return copyAIModelParams(this, options); } diff --git a/libs/langchain-google-common/src/connection.ts b/libs/langchain-google-common/src/connection.ts index 212bfa886b8f..cdc923922b07 100644 --- a/libs/langchain-google-common/src/connection.ts +++ b/libs/langchain-google-common/src/connection.ts @@ -350,6 +350,21 @@ export abstract class AbstractGoogleLLMConnection< } } + formatToolConfig( + parameters: GoogleAIModelRequestParams + ): GeminiRequest["toolConfig"] | undefined { + if (!parameters.tool_choice || typeof parameters.tool_choice !== "string") { + return undefined; + } + + return { + functionCallingConfig: { + mode: parameters.tool_choice as "auto" | "any" | "none", + allowedFunctionNames: parameters.allowed_function_names, + }, + }; + } + formatData( input: MessageType, parameters: GoogleAIModelRequestParams @@ -357,6 +372,7 @@ export abstract class AbstractGoogleLLMConnection< const contents = this.formatContents(input, parameters); const generationConfig = this.formatGenerationConfig(input, parameters); const tools = this.formatTools(input, parameters); + const toolConfig = this.formatToolConfig(parameters); const safetySettings = this.formatSafetySettings(input, parameters); const systemInstruction = this.formatSystemInstruction(input, parameters); @@ -367,6 +383,9 @@ export abstract class AbstractGoogleLLMConnection< if (tools && tools.length) { ret.tools = tools; } + if (toolConfig) { + ret.toolConfig = toolConfig; + } if (safetySettings && safetySettings.length) { ret.safetySettings = safetySettings; } diff --git a/libs/langchain-google-common/src/tests/chat_models.test.ts b/libs/langchain-google-common/src/tests/chat_models.test.ts index 71f3aa876ffe..dda4b68033ce 100644 --- a/libs/langchain-google-common/src/tests/chat_models.test.ts +++ b/libs/langchain-google-common/src/tests/chat_models.test.ts @@ -140,11 +140,7 @@ describe("Mock ChatGoogle", () => { new AIMessage("H"), new HumanMessage("Flip it again"), ]; - // @eslint-disable-next-line/@typescript-eslint/ban-ts-comment - // @ts-expect-error unused var - const result = await model.invoke(messages); - // console.log("record", JSON.stringify(record, null, 1)); - // console.log("result", JSON.stringify(result, null, 1)); + await model.invoke(messages); expect(record.opts).toBeDefined(); expect(record.opts.data).toBeDefined(); @@ -178,11 +174,7 @@ describe("Mock ChatGoogle", () => { new AIMessage("H"), new HumanMessage("Flip it again"), ]; - // @eslint-disable-next-line/@typescript-eslint/ban-ts-comment - // @ts-expect-error unused var - const result = await model.invoke(messages); - // console.log("record", JSON.stringify(record, null, 1)); - // console.log("result", JSON.stringify(result, null, 1)); + await model.invoke(messages); expect(record.opts).toBeDefined(); expect(record.opts.data).toBeDefined(); @@ -273,11 +265,7 @@ describe("Mock ChatGoogle", () => { new AIMessage("H"), new HumanMessage("Flip it again"), ]; - // @eslint-disable-next-line/@typescript-eslint/ban-ts-comment - // @ts-expect-error unused var - const result = await model.invoke(messages); - // console.log("record", JSON.stringify(record, null, 1)); - // console.log("result", JSON.stringify(result, null, 1)); + await model.invoke(messages); expect(record.opts).toBeDefined(); expect(record.opts.data).toBeDefined(); @@ -318,11 +306,7 @@ describe("Mock ChatGoogle", () => { new AIMessage("H"), new HumanMessage("Flip it again"), ]; - // @eslint-disable-next-line/@typescript-eslint/ban-ts-comment - // @ts-expect-error unused var - const result = await model.invoke(messages); - // console.log("record", JSON.stringify(record, null, 1)); - // console.log("result", JSON.stringify(result, null, 1)); + await model.invoke(messages); expect(record.opts).toBeDefined(); expect(record.opts.data).toBeDefined(); @@ -363,11 +347,7 @@ describe("Mock ChatGoogle", () => { new AIMessage("H"), new HumanMessage("Flip it again"), ]; - // @eslint-disable-next-line/@typescript-eslint/ban-ts-comment - // @ts-expect-error unused var - const result = await model.invoke(messages); - // console.log("record", JSON.stringify(record, null, 1)); - // console.log("result", JSON.stringify(result, null, 1)); + await model.invoke(messages); expect(record.opts).toBeDefined(); expect(record.opts.data).toBeDefined(); @@ -406,11 +386,7 @@ describe("Mock ChatGoogle", () => { new AIMessage("H"), new HumanMessage("Flip it again"), ]; - // @eslint-disable-next-line/@typescript-eslint/ban-ts-comment - // @ts-expect-error unused var - const result = await model.invoke(messages); - // console.log("record", JSON.stringify(record, null, 1)); - // console.log("result", JSON.stringify(result, null, 1)); + await model.invoke(messages); expect(record.opts).toBeDefined(); expect(record.opts.data).toBeDefined(); @@ -453,10 +429,7 @@ describe("Mock ChatGoogle", () => { let caught = false; try { - // @eslint-disable-next-line/@typescript-eslint/ban-ts-comment - // @ts-expect-error unused var - const result = await model.invoke(messages); - // console.log(result); + await model.invoke(messages); } catch (xx) { caught = true; } @@ -485,10 +458,7 @@ describe("Mock ChatGoogle", () => { let caught = false; try { - // @eslint-disable-next-line/@typescript-eslint/ban-ts-comment - // @ts-expect-error unused var - const result = await model.invoke(messages); - // console.log(result); + await model.invoke(messages); } catch (xx) { caught = true; } diff --git a/libs/langchain-google-common/src/types.ts b/libs/langchain-google-common/src/types.ts index 3d316f52ddbe..a07c32e67555 100644 --- a/libs/langchain-google-common/src/types.ts +++ b/libs/langchain-google-common/src/types.ts @@ -117,6 +117,25 @@ export interface GoogleAIModelParams { */ export interface GoogleAIModelRequestParams extends GoogleAIModelParams { tools?: StructuredToolInterface[] | GeminiTool[]; + /** + * Force the model to use tools in a specific way. + * + * | Mode | Description | + * |----------|---------------------------------------------------------------------------------------------------------------------------------------------------------| + * | "auto" | The default model behavior. The model decides whether to predict a function call or a natural language response. | + * | "any" | The model must predict only function calls. To limit the model to a subset of functions, define the allowed function names in `allowed_function_names`. | + * | "none" | The model must not predict function calls. This behavior is equivalent to a model request without any associated function declarations. | + * | string | The string value must be one of the function names. This will force the model to predict the specified function call. | + * + * The tool configuration's "any" mode ("forced function calling") is supported for Gemini 1.5 Pro models only. + */ + // eslint-disable-next-line @typescript-eslint/no-explicit-any + tool_choice?: string | "auto" | "any" | "none" | Record; + /** + * Allowed functions to call when the mode is "any". + * If empty, any one of the provided functions are called. + */ + allowed_function_names?: string[]; } export interface GoogleAIBaseLLMInput @@ -251,6 +270,12 @@ export interface GeminiRequest { contents?: GeminiContent[]; systemInstruction?: GeminiContent; tools?: GeminiTool[]; + toolConfig?: { + functionCallingConfig: { + mode: "auto" | "any" | "none"; + allowedFunctionNames?: string[]; + }; + }; safetySettings?: GeminiSafetySetting[]; generationConfig?: GeminiGenerationConfig; } diff --git a/libs/langchain-google-common/src/utils/common.ts b/libs/langchain-google-common/src/utils/common.ts index 6ea6533d8225..17f29185811c 100644 --- a/libs/langchain-google-common/src/utils/common.ts +++ b/libs/langchain-google-common/src/utils/common.ts @@ -1,12 +1,24 @@ import { StructuredToolInterface } from "@langchain/core/tools"; +import { + isOpenAITool, + ToolDefinition, +} from "@langchain/core/language_models/base"; +import { RunnableToolLike } from "@langchain/core/runnables"; +import { isStructuredTool } from "@langchain/core/utils/function_calling"; +import { isModelGemini, validateGeminiParams } from "./gemini.js"; import type { + GeminiFunctionDeclaration, + GeminiFunctionSchema, GeminiTool, GoogleAIBaseLanguageModelCallOptions, GoogleAIModelParams, GoogleAIModelRequestParams, GoogleLLMModelFamily, } from "../types.js"; -import { isModelGemini, validateGeminiParams } from "./gemini.js"; +import { + jsonSchemaToGeminiParameters, + zodToGeminiParameters, +} from "./zod_to_gemini_parameters.js"; export function copyAIModelParams( params: GoogleAIModelParams | undefined, @@ -15,6 +27,82 @@ export function copyAIModelParams( return copyAIModelParamsInto(params, options, {}); } +function processToolChoice( + toolChoice: GoogleAIBaseLanguageModelCallOptions["tool_choice"], + allowedFunctionNames: GoogleAIBaseLanguageModelCallOptions["allowed_function_names"] +): + | { + tool_choice: "any" | "auto" | "none"; + allowed_function_names?: string[]; + } + | undefined { + if (!toolChoice) { + if (allowedFunctionNames) { + // Allowed func names is passed, return 'any' so it forces the model to use a tool. + return { + tool_choice: "any", + allowed_function_names: allowedFunctionNames, + }; + } + return undefined; + } + + if (toolChoice === "any" || toolChoice === "auto" || toolChoice === "none") { + return { + tool_choice: toolChoice, + allowed_function_names: allowedFunctionNames, + }; + } + if (typeof toolChoice === "string") { + // String representing the function name. + // Return any to force the model to predict the specified function call. + return { + tool_choice: "any", + allowed_function_names: [...(allowedFunctionNames ?? []), toolChoice], + }; + } + throw new Error("Object inputs for tool_choice not supported."); +} + +export function convertToGeminiTools( + structuredTools: ( + | StructuredToolInterface + | Record + | ToolDefinition + | RunnableToolLike + )[] +): GeminiTool[] { + const tools: GeminiTool[] = [ + { + functionDeclarations: [], + }, + ]; + structuredTools.forEach((tool) => { + if ( + "functionDeclarations" in tool && + Array.isArray(tool.functionDeclarations) + ) { + const funcs: GeminiFunctionDeclaration[] = tool.functionDeclarations; + tools[0].functionDeclarations?.push(...funcs); + } else if (isStructuredTool(tool)) { + const jsonSchema = zodToGeminiParameters(tool.schema); + tools[0].functionDeclarations?.push({ + name: tool.name, + description: tool.description, + parameters: jsonSchema as GeminiFunctionSchema, + }); + } else if (isOpenAITool(tool)) { + tools[0].functionDeclarations?.push({ + name: tool.function.name, + description: + tool.function.description ?? `A function available to call.`, + parameters: jsonSchemaToGeminiParameters(tool.function.parameters), + }); + } + }); + return tools; +} + export function copyAIModelParamsInto( params: GoogleAIModelParams | undefined, options: GoogleAIBaseLanguageModelCallOptions | undefined, @@ -46,66 +134,20 @@ export function copyAIModelParamsInto( params?.responseMimeType ?? target?.responseMimeType; ret.streaming = options?.streaming ?? params?.streaming ?? target?.streaming; + const toolChoice = processToolChoice( + options?.tool_choice, + options?.allowed_function_names + ); + if (toolChoice) { + ret.tool_choice = toolChoice.tool_choice; + ret.allowed_function_names = toolChoice.allowed_function_names; + } - ret.tools = options?.tools; - // Ensure tools are formatted properly for Gemini - const geminiTools = options?.tools - ?.map((tool) => { - if ( - "function" in tool && - // eslint-disable-next-line @typescript-eslint/no-explicit-any - "parameters" in (tool.function as Record) - ) { - // Tool is in OpenAI format. Convert to Gemini then return. - - // eslint-disable-next-line @typescript-eslint/no-explicit-any - const castTool = tool.function as Record; - const cleanedParameters = castTool.parameters; - if ("$schema" in cleanedParameters) { - delete cleanedParameters.$schema; - } - if ("additionalProperties" in cleanedParameters) { - delete cleanedParameters.additionalProperties; - } - const toolInGeminiFormat: GeminiTool = { - functionDeclarations: [ - { - name: castTool.name, - description: castTool.description, - parameters: cleanedParameters, - }, - ], - }; - return toolInGeminiFormat; - } else if ("functionDeclarations" in tool) { - return tool; - } else { - return null; - } - }) - .filter((tool): tool is GeminiTool => tool !== null); - - const structuredOutputTools = options?.tools - ?.map((tool) => { - if ("lc_namespace" in tool) { - return tool; - } else { - return null; - } - }) - .filter((tool): tool is StructuredToolInterface => tool !== null); - - if ( - structuredOutputTools && - structuredOutputTools.length > 0 && - geminiTools && - geminiTools.length > 0 - ) { - throw new Error( - `Cannot mix structured tools with Gemini tools.\nReceived ${structuredOutputTools.length} structured tools and ${geminiTools.length} Gemini tools.` - ); + const tools = options?.tools; + if (tools) { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + ret.tools = convertToGeminiTools(tools as Record[]); } - ret.tools = geminiTools ?? structuredOutputTools; return ret; } diff --git a/libs/langchain-google-vertexai/src/tests/chat_models.int.test.ts b/libs/langchain-google-vertexai/src/tests/chat_models.int.test.ts index 2fa428a20924..d6c2d2b77dec 100644 --- a/libs/langchain-google-vertexai/src/tests/chat_models.int.test.ts +++ b/libs/langchain-google-vertexai/src/tests/chat_models.int.test.ts @@ -12,8 +12,8 @@ import { ToolMessage, } from "@langchain/core/messages"; import { tool } from "@langchain/core/tools"; -import { concat } from "@langchain/core/utils/stream"; import { z } from "zod"; +import { concat } from "@langchain/core/utils/stream"; import { GeminiTool } from "../types.js"; import { ChatVertexAI } from "../chat_models.js"; @@ -282,6 +282,41 @@ test("Streaming true constructor param will stream", async () => { expect(totalTokenCount).toBeGreaterThan(1); }); +test("Can force a model to invoke a tool", async () => { + const model = new ChatVertexAI({ + model: "gemini-1.5-pro", + }); + const weatherTool = tool((_) => "no-op", { + name: "get_weather", + description: + "Get the weather of a specific location and return the temperature in Celsius.", + schema: z.object({ + location: z.string().describe("The name of city to get the weather for."), + }), + }); + const calculatorTool = tool((_) => "no-op", { + name: "calculator", + description: "Calculate the result of a math expression.", + schema: z.object({ + expression: z.string().describe("The math expression to calculate."), + }), + }); + const modelWithTools = model.bind({ + tools: [calculatorTool, weatherTool], + tool_choice: "calculator", + }); + + const result = await modelWithTools.invoke( + "Whats the weather like in paris today? What's 1836 plus 7262?" + ); + + expect(result.tool_calls).toHaveLength(1); + expect(result.tool_calls?.[0]).toBeDefined(); + if (!result.tool_calls?.[0]) return; + expect(result.tool_calls?.[0].name).toBe("calculator"); + expect(result.tool_calls?.[0].args).toHaveProperty("expression"); +}); + test("ChatGoogleGenerativeAI can stream tools", async () => { const model = new ChatVertexAI({});