Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

google-common[minor]: Add tool choice param #6195

Merged
merged 12 commits into from
Jul 25, 2024
52 changes: 2 additions & 50 deletions libs/langchain-google-common/src/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import {
BaseLanguageModelInput,
StructuredOutputMethodOptions,
ToolDefinition,
isOpenAITool,
} from "@langchain/core/language_models/base";
import type { z } from "zod";
import {
Expand All @@ -24,7 +23,6 @@ import {
} from "@langchain/core/runnables";
import { JsonOutputKeyToolsParser } from "@langchain/core/output_parsers/openai_tools";
import { BaseLLMOutputParser } from "@langchain/core/output_parsers";
import { isStructuredTool } from "@langchain/core/utils/function_calling";
import { AsyncCaller } from "@langchain/core/utils/async_caller";
import { StructuredToolInterface } from "@langchain/core/tools";
import { concat } from "@langchain/core/utils/stream";
Expand All @@ -39,6 +37,7 @@ import {
GoogleAIBaseLanguageModelCallOptions,
} from "./types.js";
import {
convertToGeminiTools,
copyAIModelParams,
copyAndValidateModelParamsInto,
} from "./utils/common.js";
Expand All @@ -59,10 +58,7 @@ import type {
GeminiFunctionDeclaration,
GeminiFunctionSchema,
} from "./types.js";
import {
jsonSchemaToGeminiParameters,
zodToGeminiParameters,
} from "./utils/zod_to_gemini_parameters.js";
import { zodToGeminiParameters } from "./utils/zod_to_gemini_parameters.js";

class ChatConnection<AuthOptions> extends AbstractGoogleLLMConnection<
BaseMessage[],
Expand Down Expand Up @@ -160,44 +156,6 @@ export interface ChatGoogleBaseInput<AuthOptions>
GoogleAISafetyParams,
Pick<GoogleAIBaseLanguageModelCallOptions, "streamUsage"> {}

function convertToGeminiTools(
structuredTools: (
| StructuredToolInterface
| Record<string, unknown>
| ToolDefinition
| RunnableToolLike
)[]
): GeminiTool[] {
return [
{
functionDeclarations: structuredTools.map(
(structuredTool): GeminiFunctionDeclaration => {
if (isStructuredTool(structuredTool)) {
const jsonSchema = zodToGeminiParameters(structuredTool.schema);
return {
name: structuredTool.name,
description: structuredTool.description,
parameters: jsonSchema as GeminiFunctionSchema,
};
}
if (isOpenAITool(structuredTool)) {
return {
name: structuredTool.function.name,
description:
structuredTool.function.description ??
`A function available to call.`,
parameters: jsonSchemaToGeminiParameters(
structuredTool.function.parameters
),
};
}
return structuredTool as unknown as GeminiFunctionDeclaration;
}
),
},
];
}

/**
* Integration with a chat model.
*/
Expand Down Expand Up @@ -342,12 +300,6 @@ export abstract class ChatGoogleBase<AuthOptions>
* Get the parameters used to invoke the model
*/
override invocationParams(options?: this["ParsedCallOptions"]) {
if (options?.tool_choice) {
throw new Error(
`'tool_choice' call option is not supported by ${this.getName()}.`
);
}

return copyAIModelParams(this, options);
}

Expand Down
19 changes: 19 additions & 0 deletions libs/langchain-google-common/src/connection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -350,13 +350,29 @@ export abstract class AbstractGoogleLLMConnection<
}
}

formatToolConfig(
parameters: GoogleAIModelRequestParams
): GeminiRequest["toolConfig"] | undefined {
if (!parameters.tool_choice || typeof parameters.tool_choice !== "string") {
return undefined;
}

return {
functionCallingConfig: {
mode: parameters.tool_choice as "auto" | "any" | "none",
allowedFunctionNames: parameters.allowed_function_names,
},
};
}

formatData(
input: MessageType,
parameters: GoogleAIModelRequestParams
): GeminiRequest {
const contents = this.formatContents(input, parameters);
const generationConfig = this.formatGenerationConfig(input, parameters);
const tools = this.formatTools(input, parameters);
const toolConfig = this.formatToolConfig(parameters);
const safetySettings = this.formatSafetySettings(input, parameters);
const systemInstruction = this.formatSystemInstruction(input, parameters);

Expand All @@ -367,6 +383,9 @@ export abstract class AbstractGoogleLLMConnection<
if (tools && tools.length) {
ret.tools = tools;
}
if (toolConfig) {
ret.toolConfig = toolConfig;
}
if (safetySettings && safetySettings.length) {
ret.safetySettings = safetySettings;
}
Expand Down
46 changes: 8 additions & 38 deletions libs/langchain-google-common/src/tests/chat_models.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -140,11 +140,7 @@ describe("Mock ChatGoogle", () => {
new AIMessage("H"),
new HumanMessage("Flip it again"),
];
// @eslint-disable-next-line/@typescript-eslint/ban-ts-comment
// @ts-expect-error unused var
const result = await model.invoke(messages);
// console.log("record", JSON.stringify(record, null, 1));
// console.log("result", JSON.stringify(result, null, 1));
await model.invoke(messages);

expect(record.opts).toBeDefined();
expect(record.opts.data).toBeDefined();
Expand Down Expand Up @@ -178,11 +174,7 @@ describe("Mock ChatGoogle", () => {
new AIMessage("H"),
new HumanMessage("Flip it again"),
];
// @eslint-disable-next-line/@typescript-eslint/ban-ts-comment
// @ts-expect-error unused var
const result = await model.invoke(messages);
// console.log("record", JSON.stringify(record, null, 1));
// console.log("result", JSON.stringify(result, null, 1));
await model.invoke(messages);

expect(record.opts).toBeDefined();
expect(record.opts.data).toBeDefined();
Expand Down Expand Up @@ -273,11 +265,7 @@ describe("Mock ChatGoogle", () => {
new AIMessage("H"),
new HumanMessage("Flip it again"),
];
// @eslint-disable-next-line/@typescript-eslint/ban-ts-comment
// @ts-expect-error unused var
const result = await model.invoke(messages);
// console.log("record", JSON.stringify(record, null, 1));
// console.log("result", JSON.stringify(result, null, 1));
await model.invoke(messages);

expect(record.opts).toBeDefined();
expect(record.opts.data).toBeDefined();
Expand Down Expand Up @@ -318,11 +306,7 @@ describe("Mock ChatGoogle", () => {
new AIMessage("H"),
new HumanMessage("Flip it again"),
];
// @eslint-disable-next-line/@typescript-eslint/ban-ts-comment
// @ts-expect-error unused var
const result = await model.invoke(messages);
// console.log("record", JSON.stringify(record, null, 1));
// console.log("result", JSON.stringify(result, null, 1));
await model.invoke(messages);

expect(record.opts).toBeDefined();
expect(record.opts.data).toBeDefined();
Expand Down Expand Up @@ -363,11 +347,7 @@ describe("Mock ChatGoogle", () => {
new AIMessage("H"),
new HumanMessage("Flip it again"),
];
// @eslint-disable-next-line/@typescript-eslint/ban-ts-comment
// @ts-expect-error unused var
const result = await model.invoke(messages);
// console.log("record", JSON.stringify(record, null, 1));
// console.log("result", JSON.stringify(result, null, 1));
await model.invoke(messages);

expect(record.opts).toBeDefined();
expect(record.opts.data).toBeDefined();
Expand Down Expand Up @@ -406,11 +386,7 @@ describe("Mock ChatGoogle", () => {
new AIMessage("H"),
new HumanMessage("Flip it again"),
];
// @eslint-disable-next-line/@typescript-eslint/ban-ts-comment
// @ts-expect-error unused var
const result = await model.invoke(messages);
// console.log("record", JSON.stringify(record, null, 1));
// console.log("result", JSON.stringify(result, null, 1));
await model.invoke(messages);

expect(record.opts).toBeDefined();
expect(record.opts.data).toBeDefined();
Expand Down Expand Up @@ -453,10 +429,7 @@ describe("Mock ChatGoogle", () => {

let caught = false;
try {
// @eslint-disable-next-line/@typescript-eslint/ban-ts-comment
// @ts-expect-error unused var
const result = await model.invoke(messages);
// console.log(result);
await model.invoke(messages);
} catch (xx) {
caught = true;
}
Expand Down Expand Up @@ -485,10 +458,7 @@ describe("Mock ChatGoogle", () => {

let caught = false;
try {
// @eslint-disable-next-line/@typescript-eslint/ban-ts-comment
// @ts-expect-error unused var
const result = await model.invoke(messages);
// console.log(result);
await model.invoke(messages);
} catch (xx) {
caught = true;
}
Expand Down
25 changes: 25 additions & 0 deletions libs/langchain-google-common/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,25 @@ export interface GoogleAIModelParams {
*/
export interface GoogleAIModelRequestParams extends GoogleAIModelParams {
tools?: StructuredToolInterface[] | GeminiTool[];
/**
* Force the model to use tools in a specific way.
*
* | Mode | Description |
* |----------|---------------------------------------------------------------------------------------------------------------------------------------------------------|
* | "auto" | The default model behavior. The model decides whether to predict a function call or a natural language response. |
* | "any" | The model must predict only function calls. To limit the model to a subset of functions, define the allowed function names in `allowed_function_names`. |
* | "none" | The model must not predict function calls. This behavior is equivalent to a model request without any associated function declarations. |
* | string | The string value must be one of the function names. This will force the model to predict the specified function call. |
*
* The tool configuration's "any" mode ("forced function calling") is supported for Gemini 1.5 Pro models only.
*/
// eslint-disable-next-line @typescript-eslint/no-explicit-any
tool_choice?: string | "auto" | "any" | "none" | Record<string, any>;
/**
* Allowed functions to call when the mode is "any".
* If empty, any one of the provided functions are called.
*/
allowed_function_names?: string[];
}

export interface GoogleAIBaseLLMInput<AuthOptions>
Expand Down Expand Up @@ -251,6 +270,12 @@ export interface GeminiRequest {
contents?: GeminiContent[];
systemInstruction?: GeminiContent;
tools?: GeminiTool[];
toolConfig?: {
functionCallingConfig: {
mode: "auto" | "any" | "none";
allowedFunctionNames?: string[];
};
};
safetySettings?: GeminiSafetySetting[];
generationConfig?: GeminiGenerationConfig;
}
Expand Down
Loading
Loading