From fece0e35093afe835f7452ef181b93b48c1f2a0b Mon Sep 17 00:00:00 2001 From: Mikko Haapoja Date: Tue, 9 May 2023 17:54:38 -0400 Subject: [PATCH] Make ChatConversationalAgentOutputParser be more optimistic and parse JSON and then attempt to find JSON in text --- .../src/agents/chat_convo/outputParser.ts | 62 ++++++++++++---- .../agents/tests/chat_output_parser.test.ts | 71 +++++++++++++++++++ 2 files changed, 121 insertions(+), 12 deletions(-) diff --git a/langchain/src/agents/chat_convo/outputParser.ts b/langchain/src/agents/chat_convo/outputParser.ts index a60578d02102..af40fff8ae68 100644 --- a/langchain/src/agents/chat_convo/outputParser.ts +++ b/langchain/src/agents/chat_convo/outputParser.ts @@ -3,29 +3,67 @@ import { FORMAT_INSTRUCTIONS } from "./prompt.js"; export class ChatConversationalAgentOutputParser extends AgentActionOutputParser { async parse(text: string) { - let jsonOutput = text.trim(); - if (jsonOutput.includes("```json")) { - jsonOutput = jsonOutput.split("```json")[1].trimStart(); - } else if (jsonOutput.includes("```")) { - const firstIndex = jsonOutput.indexOf("```"); - jsonOutput = jsonOutput.slice(firstIndex + 3).trimStart(); - } - const lastIndex = jsonOutput.lastIndexOf("```"); - if (lastIndex !== -1) { - jsonOutput = jsonOutput.slice(0, lastIndex).trimEnd(); + const trimmedText = text.trim(); + let action: string | undefined; + let action_input: string | undefined; + + try { + ({ action, action_input } = JSON.parse(trimmedText)); + } catch (_error) { + ({ action, action_input } = this.findActionAndInput(trimmedText)); } - const response = JSON.parse(jsonOutput); + if (!action) { + throw new Error(`\`action\` could not be found in: "${trimmedText}"`); + } - const { action, action_input } = response; + if (!action_input) { + throw new Error( + `\`action_input\` could not be found in: "${trimmedText}"` + ); + } if (action === "Final Answer") { return { returnValues: { output: action_input }, log: text }; } + return { tool: action, toolInput: action_input, log: text }; } getFormatInstructions(): string { return FORMAT_INSTRUCTIONS; } + + private findActionAndInput(text: string): { + action: string | undefined; + action_input: string | undefined; + } { + const jsonOutput = this.findJson(text); + + try { + const response = JSON.parse(jsonOutput); + + return response; + } catch (error) { + return { action: undefined, action_input: undefined }; + } + } + + private findJson(text: string) { + let jsonOutput = text; + + if (jsonOutput.includes("```json")) { + jsonOutput = jsonOutput.split("```json")[1].trimStart(); + } else if (jsonOutput.includes("```")) { + const firstIndex = jsonOutput.indexOf("```"); + jsonOutput = jsonOutput.slice(firstIndex + 3).trimStart(); + } + const lastIndex = jsonOutput.lastIndexOf("```"); + + if (lastIndex !== -1) { + jsonOutput = jsonOutput.slice(0, lastIndex).trimEnd(); + } + + return jsonOutput; + } } diff --git a/langchain/src/agents/tests/chat_output_parser.test.ts b/langchain/src/agents/tests/chat_output_parser.test.ts index cbadb353663e..445af4afe47a 100644 --- a/langchain/src/agents/tests/chat_output_parser.test.ts +++ b/langchain/src/agents/tests/chat_output_parser.test.ts @@ -42,6 +42,14 @@ test("Can parse JSON with text in front of it", async () => { toolInput: "```sql\nSELECT * FROM orders\nJOIN users ON users.id = orders.user_id\nWHERE users.email = 'bud'```", }, + { + input: + '{"action":"ToolWithJson","action_input":"The tool input ```json\\n{\\"yes\\":true}\\n```"}', + output: + '{"action":"ToolWithJson","action_input":"The tool input ```json\\n{\\"yes\\":true}\\n```"}', + tool: "ToolWithJson", + toolInput: 'The tool input ```json\n{"yes":true}\n```', + }, ]; const p = new ChatConversationalAgentOutputParser(); @@ -62,3 +70,66 @@ test("Can parse JSON with text in front of it", async () => { } } }); + +test("will throw exceptions if action or action_input are not found", async () => { + const parser = new ChatConversationalAgentOutputParser(); + + type MissingItem = "action" | "action_input"; + type TestCase = { message: string; missing: MissingItem }; + + const testCases: TestCase[] = [ + { + message: "", + missing: "action", + }, + { + message: '{"action": "Final Answer"}', + missing: "action_input", + }, + { + message: '{"action_input": "I have no action"}', + missing: "action", + }, + { + message: + 'I have a prefix ```json\n{"action_input": "I have no action"}```', + missing: "action", + }, + { + message: 'I have a prefix ```{"action_input": "I have no action"}```', + missing: "action", + }, + { + message: + 'I have a prefix ```json\n{"action_input": "I have no action"}\n```', + missing: "action", + }, + { + message: 'I have a prefix ```\n{"action_input": "I have no action"}\n```', + missing: "action", + }, + + { + message: 'I have a prefix ```json\n{"action": "ToolThing"}```', + missing: "action_input", + }, + { + message: 'I have a prefix ```{"action": "ToolThing"}```', + missing: "action_input", + }, + { + message: 'I have a prefix ```json\n{"action": "ToolThing"}\n```', + missing: "action_input", + }, + { + message: 'I have a prefix ```\n{"action": "ToolThing"}\n```', + missing: "action_input", + }, + ]; + + for (const { message, missing } of testCases) { + await expect(parser.parse(message)).rejects.toThrow( + `\`${missing}\` could not be found in: "${message}"` + ); + } +});