Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

anthropic[patch]: Fix passing streamed tool calls back to anthropic #6199

Merged
merged 3 commits into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 92 additions & 8 deletions libs/langchain-anthropic/src/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -160,10 +160,18 @@ function _makeMessageChunkFromAnthropicEvent(
streamUsage: boolean;
coerceContentToString: boolean;
usageData: { input_tokens: number; output_tokens: number };
toolUse?: {
id: string;
name: string;
};
}
): {
chunk: AIMessageChunk;
usageData: { input_tokens: number; output_tokens: number };
toolUse?: {
id: string;
name: string;
};
} | null {
let usageDataCopy = { ...fields.usageData };

Expand Down Expand Up @@ -233,6 +241,10 @@ function _makeMessageChunkFromAnthropicEvent(
additional_kwargs: {},
}),
usageData: usageDataCopy,
toolUse: {
id: data.content_block.id,
name: data.content_block.name,
},
};
} else if (
data.type === "content_block_delta" &&
Expand Down Expand Up @@ -274,6 +286,25 @@ function _makeMessageChunkFromAnthropicEvent(
}),
usageData: usageDataCopy,
};
} else if (data.type === "content_block_stop" && fields.toolUse) {
// Only yield the ID & name when the tool_use block is complete.
// This is so the names & IDs do not get concatenated.
return {
chunk: new AIMessageChunk({
content: fields.coerceContentToString
? ""
: [
{
id: fields.toolUse.id,
name: fields.toolUse.name,
index: data.index,
type: "input_json_delta",
},
],
additional_kwargs: {},
}),
usageData: usageDataCopy,
};
}

return null;
Expand Down Expand Up @@ -424,6 +455,9 @@ export function _convertLangChainToolCallToAnthropic(
}

function _formatContent(content: MessageContent) {
const toolTypes = ["tool_use", "tool_result", "input_json_delta"];
const textTypes = ["text", "text_delta"];

if (typeof content === "string") {
return content;
} else {
Expand All @@ -439,19 +473,40 @@ function _formatContent(content: MessageContent) {
type: "image" as const, // Explicitly setting the type as "image"
source,
};
} else if (contentPart.type === "text") {
} else if (
textTypes.find((t) => t === contentPart.type) &&
"text" in contentPart
) {
// Assuming contentPart is of type MessageContentText here
return {
type: "text" as const, // Explicitly setting the type as "text"
text: contentPart.text,
};
} else if (
contentPart.type === "tool_use" ||
contentPart.type === "tool_result"
) {
} else if (toolTypes.find((t) => t === contentPart.type)) {
const contentPartCopy = { ...contentPart };
if ("index" in contentPartCopy) {
// Anthropic does not support passing the index field here, so we remove it.
delete contentPartCopy.index;
}

if (contentPartCopy.type === "input_json_delta") {
// `input_json_delta` type only represents yielding partial tool inputs
// and is not a valid type for Anthropic messages.
contentPartCopy.type = "tool_use";
}

if ("input" in contentPartCopy) {
// Anthropic tool use inputs should be valid objects, when applicable.
try {
contentPartCopy.input = JSON.parse(contentPartCopy.input);
} catch {
// no-op
}
}

// TODO: Fix when SDK types are fixed
return {
...contentPart,
...contentPartCopy,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} as any;
} else {
Expand Down Expand Up @@ -519,7 +574,9 @@ function _formatMessagesForAnthropic(messages: BaseMessage[]): {
const hasMismatchedToolCalls = !message.tool_calls.every((toolCall) =>
content.find(
(contentPart) =>
contentPart.type === "tool_use" && contentPart.id === toolCall.id
(contentPart.type === "tool_use" ||
contentPart.type === "input_json_delta") &&
contentPart.id === toolCall.id
)
);
if (hasMismatchedToolCalls) {
Expand Down Expand Up @@ -581,12 +638,16 @@ function extractToolCallChunk(
) {
if (typeof inputJsonDeltaChunks.input === "string") {
newToolCallChunk = {
id: inputJsonDeltaChunks.id,
name: inputJsonDeltaChunks.name,
args: inputJsonDeltaChunks.input,
index: inputJsonDeltaChunks.index,
type: "tool_call_chunk",
};
} else {
newToolCallChunk = {
id: inputJsonDeltaChunks.id,
name: inputJsonDeltaChunks.name,
args: JSON.stringify(inputJsonDeltaChunks.input, null, 2),
index: inputJsonDeltaChunks.index,
type: "tool_call_chunk",
Expand Down Expand Up @@ -919,6 +980,14 @@ export class ChatAnthropicMessages<
let usageData = { input_tokens: 0, output_tokens: 0 };

let concatenatedChunks: AIMessageChunk | undefined;
// Anthropic only yields the tool name and id once, so we need to save those
// so we can yield them with the rest of the tool_use content.
let toolUse:
| {
id: string;
name: string;
}
| undefined;

for await (const data of stream) {
if (options.signal?.aborted) {
Expand All @@ -930,12 +999,27 @@ export class ChatAnthropicMessages<
streamUsage: !!(this.streamUsage || options.streamUsage),
coerceContentToString,
usageData,
toolUse: toolUse
? {
id: toolUse.id,
name: toolUse.name,
}
: undefined,
});
if (!result) continue;

const { chunk, usageData: updatedUsageData } = result;
const {
chunk,
usageData: updatedUsageData,
toolUse: updatedToolUse,
} = result;

usageData = updatedUsageData;

if (updatedToolUse) {
toolUse = updatedToolUse;
}

const newToolCallChunk = extractToolCallChunk(chunk);
// Maintain concatenatedChunks for accessing the complete `tool_use` content block.
concatenatedChunks = concatenatedChunks
Expand Down
170 changes: 169 additions & 1 deletion libs/langchain-standard-tests/src/integration_tests/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@
getBufferString,
} from "@langchain/core/messages";
import { z } from "zod";
import { StructuredTool } from "@langchain/core/tools";
import { StructuredTool, tool } from "@langchain/core/tools";
import { zodToJsonSchema } from "zod-to-json-schema";
import { ChatPromptTemplate } from "@langchain/core/prompts";
import { RunnableLambda } from "@langchain/core/runnables";
import { concat } from "@langchain/core/utils/stream";
import {
BaseChatModelsTests,
BaseChatModelsTestsFields,
Expand Down Expand Up @@ -496,7 +497,7 @@
],
});
const prompt = getBufferString([humanMessage]);
const llmKey = model._getSerializedCacheKeyParametersForCall({} as any);

Check warning on line 500 in libs/langchain-standard-tests/src/integration_tests/chat_models.ts

View workflow job for this annotation

GitHub Actions / Check linting

Unexpected any. Specify a different type

// Invoke the model to trigger a cache update.
await model.invoke([humanMessage]);
Expand All @@ -522,6 +523,159 @@
expect(cacheValue2).toEqual(cacheValue);
}

/**
* This test verifies models can invoke a tool, and use the AIMessage
* with the tool call in a followup request. This is useful when building
* agents, or other pipelines that invoke tools.
*/
async testModelCanUseToolUseAIMessage() {
if (!this.chatModelHasToolCalling) {
console.log("Test requires tool calling. Skipping...");
return;
}

const model = new this.Cls(this.constructorArgs);
if (!model.bindTools) {
throw new Error(
"bindTools undefined. Cannot test OpenAI formatted tool calls."
);
}

const weatherSchema = z.object({
location: z.string().describe("The location to get the weather for."),
});

// Define the tool
const weatherTool = tool(
(_) => "The weather in San Francisco is 70 degrees and sunny.",
{
name: "get_current_weather",
schema: weatherSchema,
description: "Get the current weather for a location.",
}
);

const modelWithTools = model.bindTools([weatherTool]);

// List of messages to initially invoke the model with, and to hold
// followup messages to invoke the model with.
const messages = [
new HumanMessage(
"What's the weather like in San Francisco right now? Use the 'get_current_weather' tool to find the answer."
),
];

const result: AIMessage = await modelWithTools.invoke(messages);

expect(result.tool_calls?.[0]).toBeDefined();
if (!result.tool_calls?.[0]) {
throw new Error("result.tool_calls is undefined");
}
const { tool_calls } = result;
expect(tool_calls[0].name).toBe("get_current_weather");

// Push the result of the tool call into the messages array so we can
// confirm in the followup request the model can use the tool call.
messages.push(result);

// Create a dummy ToolMessage representing the output of the tool call.
const toolMessage = new ToolMessage({
tool_call_id: tool_calls[0].id ?? "",
name: tool_calls[0].name,
content: await weatherTool.invoke(
tool_calls[0].args as z.infer<typeof weatherSchema>
),
});
messages.push(toolMessage);

const finalResult = await modelWithTools.invoke(messages);

expect(finalResult.content).not.toBe("");
}

/**
* Same as the above test, but streaming both model invocations.
*/
async testModelCanUseToolUseAIMessageWithStreaming() {
if (!this.chatModelHasToolCalling) {
console.log("Test requires tool calling. Skipping...");
return;
}

const model = new this.Cls(this.constructorArgs);
if (!model.bindTools) {
throw new Error(
"bindTools undefined. Cannot test OpenAI formatted tool calls."
);
}

const weatherSchema = z.object({
location: z.string().describe("The location to get the weather for."),
});

// Define the tool
const weatherTool = tool(
(_) => "The weather in San Francisco is 70 degrees and sunny.",
{
name: "get_current_weather",
schema: weatherSchema,
description: "Get the current weather for a location.",
}
);

const modelWithTools = model.bindTools([weatherTool]);

// List of messages to initially invoke the model with, and to hold
// followup messages to invoke the model with.
const messages = [
new HumanMessage(
"What's the weather like in San Francisco right now? Use the 'get_current_weather' tool to find the answer."
),
];

const stream = await modelWithTools.stream(messages);
let result: AIMessageChunk | undefined;
for await (const chunk of stream) {
result = !result ? chunk : concat(result, chunk);
}

expect(result).toBeDefined();
if (!result) return;

expect(result.tool_calls?.[0]).toBeDefined();
if (!result.tool_calls?.[0]) {
throw new Error("result.tool_calls is undefined");
}

const { tool_calls } = result;
expect(tool_calls[0].name).toBe("get_current_weather");

// Push the result of the tool call into the messages array so we can
// confirm in the followup request the model can use the tool call.
messages.push(result);

// Create a dummy ToolMessage representing the output of the tool call.
const toolMessage = new ToolMessage({
tool_call_id: tool_calls[0].id ?? "",
name: tool_calls[0].name,
content: await weatherTool.invoke(
tool_calls[0].args as z.infer<typeof weatherSchema>
),
});
messages.push(toolMessage);

const finalStream = await modelWithTools.stream(messages);
let finalResult: AIMessageChunk | undefined;
for await (const chunk of finalStream) {
finalResult = !finalResult ? chunk : concat(finalResult, chunk);
}

expect(finalResult).toBeDefined();
if (!finalResult) return;

expect(finalResult.content).not.toBe("");
}

/**
* Run all unit tests for the chat model.
* Each test is wrapped in a try/catch block to prevent the entire test suite from failing.
Expand All @@ -533,42 +687,42 @@

try {
await this.testInvoke();
} catch (e: any) {

Check warning on line 690 in libs/langchain-standard-tests/src/integration_tests/chat_models.ts

View workflow job for this annotation

GitHub Actions / Check linting

Unexpected any. Specify a different type
allTestsPassed = false;
console.error("testInvoke failed", e);
}

try {
await this.testStream();
} catch (e: any) {

Check warning on line 697 in libs/langchain-standard-tests/src/integration_tests/chat_models.ts

View workflow job for this annotation

GitHub Actions / Check linting

Unexpected any. Specify a different type
allTestsPassed = false;
console.error("testStream failed", e);
}

try {
await this.testBatch();
} catch (e: any) {

Check warning on line 704 in libs/langchain-standard-tests/src/integration_tests/chat_models.ts

View workflow job for this annotation

GitHub Actions / Check linting

Unexpected any. Specify a different type
allTestsPassed = false;
console.error("testBatch failed", e);
}

try {
await this.testConversation();
} catch (e: any) {

Check warning on line 711 in libs/langchain-standard-tests/src/integration_tests/chat_models.ts

View workflow job for this annotation

GitHub Actions / Check linting

Unexpected any. Specify a different type
allTestsPassed = false;
console.error("testConversation failed", e);
}

try {
await this.testUsageMetadata();
} catch (e: any) {

Check warning on line 718 in libs/langchain-standard-tests/src/integration_tests/chat_models.ts

View workflow job for this annotation

GitHub Actions / Check linting

Unexpected any. Specify a different type
allTestsPassed = false;
console.error("testUsageMetadata failed", e);
}

try {
await this.testUsageMetadataStreaming();
} catch (e: any) {

Check warning on line 725 in libs/langchain-standard-tests/src/integration_tests/chat_models.ts

View workflow job for this annotation

GitHub Actions / Check linting

Unexpected any. Specify a different type
allTestsPassed = false;
console.error("testUsageMetadataStreaming failed", e);
}
Expand Down Expand Up @@ -629,6 +783,20 @@
console.error("testCacheComplexMessageTypes failed", e);
}

try {
await this.testModelCanUseToolUseAIMessage();
} catch (e: any) {
allTestsPassed = false;
console.error("testModelCanUseToolUseAIMessage failed", e);
}

try {
await this.testModelCanUseToolUseAIMessageWithStreaming();
} catch (e: any) {
allTestsPassed = false;
console.error("testModelCanUseToolUseAIMessageWithStreaming failed", e);
}

return allTestsPassed;
}
}
Loading