Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

google-common[minor]: Add tool choice param #6195

Merged
merged 12 commits into from
Jul 25, 2024
52 changes: 2 additions & 50 deletions libs/langchain-google-common/src/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import {
BaseLanguageModelInput,
StructuredOutputMethodOptions,
ToolDefinition,
isOpenAITool,
} from "@langchain/core/language_models/base";
import type { z } from "zod";
import {
Expand All @@ -24,7 +23,6 @@ import {
} from "@langchain/core/runnables";
import { JsonOutputKeyToolsParser } from "@langchain/core/output_parsers/openai_tools";
import { BaseLLMOutputParser } from "@langchain/core/output_parsers";
import { isStructuredTool } from "@langchain/core/utils/function_calling";
import { AsyncCaller } from "@langchain/core/utils/async_caller";
import { StructuredToolInterface } from "@langchain/core/tools";
import { concat } from "@langchain/core/utils/stream";
Expand All @@ -39,6 +37,7 @@ import {
GoogleAIBaseLanguageModelCallOptions,
} from "./types.js";
import {
convertToGeminiTools,
copyAIModelParams,
copyAndValidateModelParamsInto,
} from "./utils/common.js";
Expand All @@ -59,10 +58,7 @@ import type {
GeminiFunctionDeclaration,
GeminiFunctionSchema,
} from "./types.js";
import {
jsonSchemaToGeminiParameters,
zodToGeminiParameters,
} from "./utils/zod_to_gemini_parameters.js";
import { zodToGeminiParameters } from "./utils/zod_to_gemini_parameters.js";

class ChatConnection<AuthOptions> extends AbstractGoogleLLMConnection<
BaseMessage[],
Expand Down Expand Up @@ -160,44 +156,6 @@ export interface ChatGoogleBaseInput<AuthOptions>
GoogleAISafetyParams,
Pick<GoogleAIBaseLanguageModelCallOptions, "streamUsage"> {}

function convertToGeminiTools(
structuredTools: (
| StructuredToolInterface
| Record<string, unknown>
| ToolDefinition
| RunnableToolLike
)[]
): GeminiTool[] {
return [
{
functionDeclarations: structuredTools.map(
(structuredTool): GeminiFunctionDeclaration => {
if (isStructuredTool(structuredTool)) {
const jsonSchema = zodToGeminiParameters(structuredTool.schema);
return {
name: structuredTool.name,
description: structuredTool.description,
parameters: jsonSchema as GeminiFunctionSchema,
};
}
if (isOpenAITool(structuredTool)) {
return {
name: structuredTool.function.name,
description:
structuredTool.function.description ??
`A function available to call.`,
parameters: jsonSchemaToGeminiParameters(
structuredTool.function.parameters
),
};
}
return structuredTool as unknown as GeminiFunctionDeclaration;
}
),
},
];
}

/**
* Integration with a chat model.
*/
Expand Down Expand Up @@ -342,12 +300,6 @@ export abstract class ChatGoogleBase<AuthOptions>
* Get the parameters used to invoke the model
*/
override invocationParams(options?: this["ParsedCallOptions"]) {
if (options?.tool_choice) {
throw new Error(
`'tool_choice' call option is not supported by ${this.getName()}.`
);
}

return copyAIModelParams(this, options);
}

Expand Down
19 changes: 19 additions & 0 deletions libs/langchain-google-common/src/connection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -350,13 +350,29 @@ export abstract class AbstractGoogleLLMConnection<
}
}

formatToolConfig(
parameters: GoogleAIModelRequestParams
): GeminiRequest["toolConfig"] | undefined {
if (!parameters.tool_choice || typeof parameters.tool_choice !== "string") {
return undefined;
}

return {
functionCallingConfig: {
mode: parameters.tool_choice as "auto" | "any" | "none",
allowedFunctionNames: parameters.allowed_function_names,
},
};
}

formatData(
input: MessageType,
parameters: GoogleAIModelRequestParams
): GeminiRequest {
const contents = this.formatContents(input, parameters);
const generationConfig = this.formatGenerationConfig(input, parameters);
const tools = this.formatTools(input, parameters);
const toolConfig = this.formatToolConfig(parameters);
const safetySettings = this.formatSafetySettings(input, parameters);
const systemInstruction = this.formatSystemInstruction(input, parameters);

Expand All @@ -367,6 +383,9 @@ export abstract class AbstractGoogleLLMConnection<
if (tools && tools.length) {
ret.tools = tools;
}
if (toolConfig) {
ret.toolConfig = toolConfig;
}
if (safetySettings && safetySettings.length) {
ret.safetySettings = safetySettings;
}
Expand Down
25 changes: 25 additions & 0 deletions libs/langchain-google-common/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,25 @@ export interface GoogleAIModelParams {
*/
export interface GoogleAIModelRequestParams extends GoogleAIModelParams {
tools?: StructuredToolInterface[] | GeminiTool[];
/**
* Force the model to use tools in a specific way.
*
* | Mode | Description |
* |----------|---------------------------------------------------------------------------------------------------------------------------------------------------------|
* | "auto" | The default model behavior. The model decides whether to predict a function call or a natural language response. |
* | "any" | The model must predict only function calls. To limit the model to a subset of functions, define the allowed function names in `allowed_function_names`. |
* | "none" | The model must not predict function calls. This behavior is equivalent to a model request without any associated function declarations. |
* | string | The string value must be one of the function names. This will force the model to predict the specified function call. |
*
* The tool configuration's "any" mode ("forced function calling") is supported for Gemini 1.5 Pro models only.
*/
// eslint-disable-next-line @typescript-eslint/no-explicit-any
tool_choice?: string | "auto" | "any" | "none" | Record<string, any>;
/**
* Allowed functions to call when the mode is "any".
* If empty, any one of the provided functions are called.
*/
allowed_function_names?: string[];
}

export interface GoogleAIBaseLLMInput<AuthOptions>
Expand Down Expand Up @@ -251,6 +270,12 @@ export interface GeminiRequest {
contents?: GeminiContent[];
systemInstruction?: GeminiContent;
tools?: GeminiTool[];
toolConfig?: {
functionCallingConfig: {
mode: "auto" | "any" | "none";
allowedFunctionNames?: string[];
};
};
safetySettings?: GeminiSafetySetting[];
generationConfig?: GeminiGenerationConfig;
}
Expand Down
159 changes: 100 additions & 59 deletions libs/langchain-google-common/src/utils/common.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,24 @@
import { StructuredToolInterface } from "@langchain/core/tools";
import {
isOpenAITool,
ToolDefinition,
} from "@langchain/core/language_models/base";
import { RunnableToolLike } from "@langchain/core/runnables";
import { isStructuredTool } from "@langchain/core/utils/function_calling";
import { isModelGemini, validateGeminiParams } from "./gemini.js";
import type {
GeminiFunctionDeclaration,
GeminiFunctionSchema,
GeminiTool,
GoogleAIBaseLanguageModelCallOptions,
GoogleAIModelParams,
GoogleAIModelRequestParams,
GoogleLLMModelFamily,
} from "../types.js";
import { isModelGemini, validateGeminiParams } from "./gemini.js";
import {
jsonSchemaToGeminiParameters,
zodToGeminiParameters,
} from "./zod_to_gemini_parameters.js";

export function copyAIModelParams(
params: GoogleAIModelParams | undefined,
Expand All @@ -15,6 +27,81 @@ export function copyAIModelParams(
return copyAIModelParamsInto(params, options, {});
}

function processToolChoice(
toolChoice: GoogleAIBaseLanguageModelCallOptions["tool_choice"],
allowedFunctionNames: GoogleAIBaseLanguageModelCallOptions["allowed_function_names"]
):
| {
tool_choice: "any" | "auto" | "none";
allowed_function_names?: string[];
}
| undefined {
if (!toolChoice) {
if (allowedFunctionNames) {
// Allowed func names is passed, return 'any' so it forces the model to use a tool.
return {
tool_choice: "any",
allowed_function_names: allowedFunctionNames,
};
}
return undefined;
}

if (toolChoice === "any" || toolChoice === "auto" || toolChoice === "none") {
return {
tool_choice: toolChoice,
allowed_function_names: allowedFunctionNames,
};
}
if (typeof toolChoice === "string") {
// String representing the function name.
// Return any to force the model to predict the specified function call.
return {
tool_choice: "any",
allowed_function_names: [...(allowedFunctionNames ?? []), toolChoice],
};
}
throw new Error("Object inputs for tool_choice not supported.");
}

export function convertToGeminiTools(
structuredTools: (
| StructuredToolInterface
| Record<string, unknown>
| ToolDefinition
| RunnableToolLike
)[]
): GeminiTool[] {
return [
{
functionDeclarations: structuredTools.map(
(structuredTool): GeminiFunctionDeclaration => {
if (isStructuredTool(structuredTool)) {
const jsonSchema = zodToGeminiParameters(structuredTool.schema);
return {
name: structuredTool.name,
description: structuredTool.description,
parameters: jsonSchema as GeminiFunctionSchema,
};
}
if (isOpenAITool(structuredTool)) {
return {
name: structuredTool.function.name,
description:
structuredTool.function.description ??
`A function available to call.`,
parameters: jsonSchemaToGeminiParameters(
structuredTool.function.parameters
),
};
}
return structuredTool as unknown as GeminiFunctionDeclaration;
}
),
},
];
}

export function copyAIModelParamsInto(
params: GoogleAIModelParams | undefined,
options: GoogleAIBaseLanguageModelCallOptions | undefined,
Expand Down Expand Up @@ -46,66 +133,20 @@ export function copyAIModelParamsInto(
params?.responseMimeType ??
target?.responseMimeType;
ret.streaming = options?.streaming ?? params?.streaming ?? target?.streaming;
const toolChoice = processToolChoice(
options?.tool_choice,
options?.allowed_function_names
);
if (toolChoice) {
ret.tool_choice = toolChoice.tool_choice;
ret.allowed_function_names = toolChoice.allowed_function_names;
}

ret.tools = options?.tools;
// Ensure tools are formatted properly for Gemini
const geminiTools = options?.tools
?.map((tool) => {
if (
"function" in tool &&
// eslint-disable-next-line @typescript-eslint/no-explicit-any
"parameters" in (tool.function as Record<string, any>)
) {
// Tool is in OpenAI format. Convert to Gemini then return.

// eslint-disable-next-line @typescript-eslint/no-explicit-any
const castTool = tool.function as Record<string, any>;
const cleanedParameters = castTool.parameters;
if ("$schema" in cleanedParameters) {
delete cleanedParameters.$schema;
}
if ("additionalProperties" in cleanedParameters) {
delete cleanedParameters.additionalProperties;
}
const toolInGeminiFormat: GeminiTool = {
functionDeclarations: [
{
name: castTool.name,
description: castTool.description,
parameters: cleanedParameters,
},
],
};
return toolInGeminiFormat;
} else if ("functionDeclarations" in tool) {
return tool;
} else {
return null;
}
})
.filter((tool): tool is GeminiTool => tool !== null);

const structuredOutputTools = options?.tools
?.map((tool) => {
if ("lc_namespace" in tool) {
return tool;
} else {
return null;
}
})
.filter((tool): tool is StructuredToolInterface => tool !== null);

if (
structuredOutputTools &&
structuredOutputTools.length > 0 &&
geminiTools &&
geminiTools.length > 0
) {
throw new Error(
`Cannot mix structured tools with Gemini tools.\nReceived ${structuredOutputTools.length} structured tools and ${geminiTools.length} Gemini tools.`
);
const tools = options?.tools;
if (tools) {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
ret.tools = convertToGeminiTools(tools as Record<string, any>[]);
}
ret.tools = geminiTools ?? structuredOutputTools;

return ret;
}
Expand Down
3 changes: 2 additions & 1 deletion libs/langchain-google-vertexai/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@
"release-it": "^15.10.1",
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey there! I noticed that the recent PR added a new dependency "zod" to the package.json file. It seems like this change might impact the project's dependencies, so I'm flagging it for your review. Thanks!

"rollup": "^4.5.2",
"ts-jest": "^29.1.0",
"typescript": "<5.2.0"
"typescript": "<5.2.0",
"zod": "^3.22.3"
},
"publishConfig": {
"access": "public"
Expand Down
Loading
Loading