Skip to content

Commit

Permalink
google-common[minor]: Add tool choice param (#6195)
Browse files Browse the repository at this point in the history
* google-common[minor]: Add tool choice param

* format

* drop console logs

* chore: lint files

* chore: lint files

* cr

* cr

* fix tool calling issues

* chore: lint files
  • Loading branch information
bracesproul authored Jul 25, 2024
1 parent 44701ae commit ed35a13
Show file tree
Hide file tree
Showing 6 changed files with 191 additions and 148 deletions.
52 changes: 2 additions & 50 deletions libs/langchain-google-common/src/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import {
BaseLanguageModelInput,
StructuredOutputMethodOptions,
ToolDefinition,
isOpenAITool,
} from "@langchain/core/language_models/base";
import type { z } from "zod";
import {
Expand All @@ -24,7 +23,6 @@ import {
} from "@langchain/core/runnables";
import { JsonOutputKeyToolsParser } from "@langchain/core/output_parsers/openai_tools";
import { BaseLLMOutputParser } from "@langchain/core/output_parsers";
import { isStructuredTool } from "@langchain/core/utils/function_calling";
import { AsyncCaller } from "@langchain/core/utils/async_caller";
import { StructuredToolInterface } from "@langchain/core/tools";
import { concat } from "@langchain/core/utils/stream";
Expand All @@ -39,6 +37,7 @@ import {
GoogleAIBaseLanguageModelCallOptions,
} from "./types.js";
import {
convertToGeminiTools,
copyAIModelParams,
copyAndValidateModelParamsInto,
} from "./utils/common.js";
Expand All @@ -59,10 +58,7 @@ import type {
GeminiFunctionDeclaration,
GeminiFunctionSchema,
} from "./types.js";
import {
jsonSchemaToGeminiParameters,
zodToGeminiParameters,
} from "./utils/zod_to_gemini_parameters.js";
import { zodToGeminiParameters } from "./utils/zod_to_gemini_parameters.js";

class ChatConnection<AuthOptions> extends AbstractGoogleLLMConnection<
BaseMessage[],
Expand Down Expand Up @@ -160,44 +156,6 @@ export interface ChatGoogleBaseInput<AuthOptions>
GoogleAISafetyParams,
Pick<GoogleAIBaseLanguageModelCallOptions, "streamUsage"> {}

function convertToGeminiTools(
structuredTools: (
| StructuredToolInterface
| Record<string, unknown>
| ToolDefinition
| RunnableToolLike
)[]
): GeminiTool[] {
return [
{
functionDeclarations: structuredTools.map(
(structuredTool): GeminiFunctionDeclaration => {
if (isStructuredTool(structuredTool)) {
const jsonSchema = zodToGeminiParameters(structuredTool.schema);
return {
name: structuredTool.name,
description: structuredTool.description,
parameters: jsonSchema as GeminiFunctionSchema,
};
}
if (isOpenAITool(structuredTool)) {
return {
name: structuredTool.function.name,
description:
structuredTool.function.description ??
`A function available to call.`,
parameters: jsonSchemaToGeminiParameters(
structuredTool.function.parameters
),
};
}
return structuredTool as unknown as GeminiFunctionDeclaration;
}
),
},
];
}

/**
* Integration with a chat model.
*/
Expand Down Expand Up @@ -342,12 +300,6 @@ export abstract class ChatGoogleBase<AuthOptions>
* Get the parameters used to invoke the model
*/
override invocationParams(options?: this["ParsedCallOptions"]) {
if (options?.tool_choice) {
throw new Error(
`'tool_choice' call option is not supported by ${this.getName()}.`
);
}

return copyAIModelParams(this, options);
}

Expand Down
19 changes: 19 additions & 0 deletions libs/langchain-google-common/src/connection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -350,13 +350,29 @@ export abstract class AbstractGoogleLLMConnection<
}
}

formatToolConfig(
parameters: GoogleAIModelRequestParams
): GeminiRequest["toolConfig"] | undefined {
if (!parameters.tool_choice || typeof parameters.tool_choice !== "string") {
return undefined;
}

return {
functionCallingConfig: {
mode: parameters.tool_choice as "auto" | "any" | "none",
allowedFunctionNames: parameters.allowed_function_names,
},
};
}

formatData(
input: MessageType,
parameters: GoogleAIModelRequestParams
): GeminiRequest {
const contents = this.formatContents(input, parameters);
const generationConfig = this.formatGenerationConfig(input, parameters);
const tools = this.formatTools(input, parameters);
const toolConfig = this.formatToolConfig(parameters);
const safetySettings = this.formatSafetySettings(input, parameters);
const systemInstruction = this.formatSystemInstruction(input, parameters);

Expand All @@ -367,6 +383,9 @@ export abstract class AbstractGoogleLLMConnection<
if (tools && tools.length) {
ret.tools = tools;
}
if (toolConfig) {
ret.toolConfig = toolConfig;
}
if (safetySettings && safetySettings.length) {
ret.safetySettings = safetySettings;
}
Expand Down
46 changes: 8 additions & 38 deletions libs/langchain-google-common/src/tests/chat_models.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -140,11 +140,7 @@ describe("Mock ChatGoogle", () => {
new AIMessage("H"),
new HumanMessage("Flip it again"),
];
// @eslint-disable-next-line/@typescript-eslint/ban-ts-comment
// @ts-expect-error unused var
const result = await model.invoke(messages);
// console.log("record", JSON.stringify(record, null, 1));
// console.log("result", JSON.stringify(result, null, 1));
await model.invoke(messages);

expect(record.opts).toBeDefined();
expect(record.opts.data).toBeDefined();
Expand Down Expand Up @@ -178,11 +174,7 @@ describe("Mock ChatGoogle", () => {
new AIMessage("H"),
new HumanMessage("Flip it again"),
];
// @eslint-disable-next-line/@typescript-eslint/ban-ts-comment
// @ts-expect-error unused var
const result = await model.invoke(messages);
// console.log("record", JSON.stringify(record, null, 1));
// console.log("result", JSON.stringify(result, null, 1));
await model.invoke(messages);

expect(record.opts).toBeDefined();
expect(record.opts.data).toBeDefined();
Expand Down Expand Up @@ -273,11 +265,7 @@ describe("Mock ChatGoogle", () => {
new AIMessage("H"),
new HumanMessage("Flip it again"),
];
// @eslint-disable-next-line/@typescript-eslint/ban-ts-comment
// @ts-expect-error unused var
const result = await model.invoke(messages);
// console.log("record", JSON.stringify(record, null, 1));
// console.log("result", JSON.stringify(result, null, 1));
await model.invoke(messages);

expect(record.opts).toBeDefined();
expect(record.opts.data).toBeDefined();
Expand Down Expand Up @@ -318,11 +306,7 @@ describe("Mock ChatGoogle", () => {
new AIMessage("H"),
new HumanMessage("Flip it again"),
];
// @eslint-disable-next-line/@typescript-eslint/ban-ts-comment
// @ts-expect-error unused var
const result = await model.invoke(messages);
// console.log("record", JSON.stringify(record, null, 1));
// console.log("result", JSON.stringify(result, null, 1));
await model.invoke(messages);

expect(record.opts).toBeDefined();
expect(record.opts.data).toBeDefined();
Expand Down Expand Up @@ -363,11 +347,7 @@ describe("Mock ChatGoogle", () => {
new AIMessage("H"),
new HumanMessage("Flip it again"),
];
// @eslint-disable-next-line/@typescript-eslint/ban-ts-comment
// @ts-expect-error unused var
const result = await model.invoke(messages);
// console.log("record", JSON.stringify(record, null, 1));
// console.log("result", JSON.stringify(result, null, 1));
await model.invoke(messages);

expect(record.opts).toBeDefined();
expect(record.opts.data).toBeDefined();
Expand Down Expand Up @@ -406,11 +386,7 @@ describe("Mock ChatGoogle", () => {
new AIMessage("H"),
new HumanMessage("Flip it again"),
];
// @eslint-disable-next-line/@typescript-eslint/ban-ts-comment
// @ts-expect-error unused var
const result = await model.invoke(messages);
// console.log("record", JSON.stringify(record, null, 1));
// console.log("result", JSON.stringify(result, null, 1));
await model.invoke(messages);

expect(record.opts).toBeDefined();
expect(record.opts.data).toBeDefined();
Expand Down Expand Up @@ -453,10 +429,7 @@ describe("Mock ChatGoogle", () => {

let caught = false;
try {
// @eslint-disable-next-line/@typescript-eslint/ban-ts-comment
// @ts-expect-error unused var
const result = await model.invoke(messages);
// console.log(result);
await model.invoke(messages);
} catch (xx) {
caught = true;
}
Expand Down Expand Up @@ -485,10 +458,7 @@ describe("Mock ChatGoogle", () => {

let caught = false;
try {
// @eslint-disable-next-line/@typescript-eslint/ban-ts-comment
// @ts-expect-error unused var
const result = await model.invoke(messages);
// console.log(result);
await model.invoke(messages);
} catch (xx) {
caught = true;
}
Expand Down
25 changes: 25 additions & 0 deletions libs/langchain-google-common/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,25 @@ export interface GoogleAIModelParams {
*/
export interface GoogleAIModelRequestParams extends GoogleAIModelParams {
tools?: StructuredToolInterface[] | GeminiTool[];
/**
* Force the model to use tools in a specific way.
*
* | Mode | Description |
* |----------|---------------------------------------------------------------------------------------------------------------------------------------------------------|
* | "auto" | The default model behavior. The model decides whether to predict a function call or a natural language response. |
* | "any" | The model must predict only function calls. To limit the model to a subset of functions, define the allowed function names in `allowed_function_names`. |
* | "none" | The model must not predict function calls. This behavior is equivalent to a model request without any associated function declarations. |
* | string | The string value must be one of the function names. This will force the model to predict the specified function call. |
*
* The tool configuration's "any" mode ("forced function calling") is supported for Gemini 1.5 Pro models only.
*/
// eslint-disable-next-line @typescript-eslint/no-explicit-any
tool_choice?: string | "auto" | "any" | "none" | Record<string, any>;
/**
* Allowed functions to call when the mode is "any".
* If empty, any one of the provided functions are called.
*/
allowed_function_names?: string[];
}

export interface GoogleAIBaseLLMInput<AuthOptions>
Expand Down Expand Up @@ -251,6 +270,12 @@ export interface GeminiRequest {
contents?: GeminiContent[];
systemInstruction?: GeminiContent;
tools?: GeminiTool[];
toolConfig?: {
functionCallingConfig: {
mode: "auto" | "any" | "none";
allowedFunctionNames?: string[];
};
};
safetySettings?: GeminiSafetySetting[];
generationConfig?: GeminiGenerationConfig;
}
Expand Down
Loading

0 comments on commit ed35a13

Please sign in to comment.