Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 19 additions & 32 deletions libs/langchain/src/agents/ReactAgent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,7 @@ import type { Runnable, RunnableConfig } from "@langchain/core/runnables";
import type { StreamEvent } from "@langchain/core/tracers/log_stream";

import { createAgentAnnotationConditional } from "./annotation.js";
import {
isClientTool,
validateLLMHasNoBoundTools,
wrapToolCall,
} from "./utils.js";
import { isClientTool, validateLLMHasNoBoundTools } from "./utils.js";

import { AgentNode } from "./nodes/AgentNode.js";
import { ToolNode } from "./nodes/ToolNode.js";
Expand All @@ -36,6 +32,7 @@ import {
initializeMiddlewareStates,
parseJumpToTarget,
} from "./nodes/utils.js";
import { StateManager } from "./state.js";

import type { WithStateGraphNodes } from "./types.js";
import type { ClientTool, ServerTool } from "./tools.js";
Expand Down Expand Up @@ -129,6 +126,8 @@ export class ReactAgent<

#agentNode: AgentNode<any, AnyAnnotationRoot>;

#stateManager = new StateManager();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the #stateManager properly persist in checkpointer?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it needs to. It is just s small helper class to help merge the state between 4 nodes into a single agent state, e.g. if you change state in beforeModel it makes sure that afterModel gets these updates.


constructor(
public options: CreateAgentParams<
StructuredResponseFormat,
Expand Down Expand Up @@ -245,21 +244,12 @@ export class ReactAgent<
throw new Error(`Middleware ${m.name} is defined multiple times`);
}

const getState = () => {
return {
...beforeAgentNode?.getState(),
...beforeModelNode?.getState(),
...afterModelNode?.getState(),
...afterAgentNode?.getState(),
...this.#agentNode.getState(),
};
};

middlewareNames.add(m.name);
if (m.beforeAgent) {
beforeAgentNode = new BeforeAgentNode(m, {
getState,
getState: () => this.#stateManager.getState(m.name),
});
this.#stateManager.addNode(m, beforeAgentNode);
const name = `${m.name}.before_agent`;
beforeAgentNodes.push({
index: i,
Expand All @@ -274,8 +264,9 @@ export class ReactAgent<
}
if (m.beforeModel) {
beforeModelNode = new BeforeModelNode(m, {
getState,
getState: () => this.#stateManager.getState(m.name),
});
this.#stateManager.addNode(m, beforeModelNode);
const name = `${m.name}.before_model`;
beforeModelNodes.push({
index: i,
Expand All @@ -290,8 +281,9 @@ export class ReactAgent<
}
if (m.afterModel) {
afterModelNode = new AfterModelNode(m, {
getState,
getState: () => this.#stateManager.getState(m.name),
});
this.#stateManager.addNode(m, afterModelNode);
const name = `${m.name}.after_model`;
afterModelNodes.push({
index: i,
Expand All @@ -306,8 +298,9 @@ export class ReactAgent<
}
if (m.afterAgent) {
afterAgentNode = new AfterAgentNode(m, {
getState,
getState: () => this.#stateManager.getState(m.name),
});
this.#stateManager.addNode(m, afterAgentNode);
const name = `${m.name}.after_agent`;
afterAgentNodes.push({
index: i,
Expand All @@ -322,32 +315,26 @@ export class ReactAgent<
}

if (m.wrapModelCall) {
wrapModelCallHookMiddleware.push([m, getState]);
wrapModelCallHookMiddleware.push([
m,
() => this.#stateManager.getState(m.name),
]);
}
}

/**
* Add Nodes
*/
allNodeWorkflows.addNode(
"model_request",
this.#agentNode,
AgentNode.nodeOptions
);

/**
* Collect and compose wrapToolCall handlers from middleware
* Wrap each handler with error handling and validation
*/
const wrapToolCallHandler = wrapToolCall(middleware);
allNodeWorkflows.addNode("model_request", this.#agentNode);

/**
* add single tool node for all tools
*/
if (toolClasses.filter(isClientTool).length > 0) {
const toolNode = new ToolNode(toolClasses.filter(isClientTool), {
signal: this.options.signal,
wrapToolCall: wrapToolCallHandler,
middleware,
stateManager: this.#stateManager,
});
allNodeWorkflows.addNode("tools", toolNode);
}
Expand Down
25 changes: 8 additions & 17 deletions libs/langchain/src/agents/nodes/AgentNode.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
/* eslint-disable no-instanceof/no-instanceof */
import { Runnable, RunnableConfig } from "@langchain/core/runnables";
import { BaseMessage, AIMessage, ToolMessage } from "@langchain/core/messages";
import { z } from "zod/v3";
import { Command, type LangGraphRunnableConfig } from "@langchain/langgraph";
import { type LanguageModelLike } from "@langchain/core/language_models/base";
import { type BaseChatModelCallOptions } from "@langchain/core/language_models/chat_models";
import {
InteropZodObject,
getSchemaDescription,
interopParse,
interopZodObjectPartial,
} from "@langchain/core/utils/types";
import type { ToolCall } from "@langchain/core/messages/tool";

Expand Down Expand Up @@ -403,6 +403,12 @@ export class AgentNode<
> = {
...request,
state: {
...(middleware.stateSchema
? interopParse(
interopZodObjectPartial(middleware.stateSchema),
state
)
: {}),
...currentGetState(),
messages: state.messages,
} as InternalAgentState<StructuredResponseFormat> &
Expand Down Expand Up @@ -510,10 +516,7 @@ export class AgentNode<
systemPrompt: this.#options.systemPrompt,
messages: state.messages,
tools: this.#options.toolClasses,
state: {
messages: state.messages,
} as InternalAgentState<StructuredResponseFormat> &
PreHookAnnotation["State"],
state,
runtime: Object.freeze({
context: lgConfig?.context,
writer: lgConfig.writer,
Expand Down Expand Up @@ -814,18 +817,6 @@ export class AgentNode<
return modelRunnable;
}

static get nodeOptions(): {
input: z.ZodObject<{
messages: z.ZodArray<z.ZodType<BaseMessage>>;
}>;
} {
return {
input: z.object({
messages: z.array(z.custom<BaseMessage>()),
}),
};
}

getState(): {
messages: BaseMessage[];
} {
Expand Down
45 changes: 31 additions & 14 deletions libs/langchain/src/agents/nodes/ToolNode.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@ import {
import { RunnableCallable } from "../RunnableCallable.js";
import { PreHookAnnotation } from "../annotation.js";
import { mergeAbortSignals } from "./utils.js";
import { wrapToolCall } from "../utils.js";
import { ToolInvocationError } from "../errors.js";
import type {
AnyAnnotationRoot,
WrapToolCallHook,
ToolCallRequest,
ToAnnotationRoot,
} from "../middleware/types.js";
import type { AgentMiddleware } from "../middleware/types.js";
import type { StateManager } from "../state.js";

export interface ToolNodeOptions {
/**
Expand Down Expand Up @@ -64,11 +66,13 @@ export interface ToolNodeOptions {
| boolean
| ((error: unknown, toolCall: ToolCall) => ToolMessage | undefined);
/**
* Optional wrapper function for tool execution.
* Allows middleware to intercept and modify tool calls before execution.
* The wrapper receives the tool call request and a handler function to execute the tool.
* The middleware to use for tool execution.
*/
wrapToolCall?: WrapToolCallHook;
middleware?: readonly AgentMiddleware[];
/**
* The state manager to use for tool execution.
*/
stateManager?: StateManager;
}

const isBaseMessageArray = (input: unknown): input is BaseMessage[] =>
Expand Down Expand Up @@ -175,13 +179,16 @@ export class ToolNode<
| ((error: unknown, toolCall: ToolCall) => ToolMessage | undefined) =
defaultHandleToolErrors;

wrapToolCall?: WrapToolCallHook;
middleware: readonly AgentMiddleware[] = [];

stateManager?: StateManager;

constructor(
tools: (StructuredToolInterface | DynamicTool | RunnableToolLike)[],
public options?: ToolNodeOptions
) {
const { name, tags, handleToolErrors, wrapToolCall } = options ?? {};
const { name, tags, handleToolErrors, middleware, stateManager, signal } =
options ?? {};
super({
name,
tags,
Expand All @@ -194,8 +201,9 @@ export class ToolNode<
});
this.tools = tools;
this.handleToolErrors = handleToolErrors ?? this.handleToolErrors;
this.wrapToolCall = wrapToolCall;
this.signal = options?.signal;
this.middleware = middleware ?? [];
this.signal = signal;
this.stateManager = stateManager;
}

/**
Expand Down Expand Up @@ -271,7 +279,7 @@ export class ToolNode<
protected async runTool(
call: ToolCall,
config: RunnableConfig,
state?: ToAnnotationRoot<StateSchema>["State"] & PreHookAnnotation["State"]
state: ToAnnotationRoot<StateSchema>["State"] & PreHookAnnotation["State"]
): Promise<ToolMessage | Command> {
/**
* Define the base handler that executes the tool.
Expand Down Expand Up @@ -343,16 +351,24 @@ export class ToolNode<
const request = {
toolCall: call,
tool,
state: state || ({} as any),
state,
runtime,
};

/**
* Collect and compose wrapToolCall handlers from middleware
* Wrap each handler with error handling and validation
*/
const wrapToolCallHandler = this.stateManager
? wrapToolCall(this.middleware, state)
: undefined;

/**
* If wrapToolCall is provided, use it to wrap the tool execution
*/
if (this.wrapToolCall && state) {
if (wrapToolCallHandler) {
try {
return await this.wrapToolCall(request, baseHandler);
return await wrapToolCallHandler(request, baseHandler);
} catch (e: unknown) {
/**
* Handle middleware errors
Expand Down Expand Up @@ -381,7 +397,8 @@ export class ToolNode<
let outputs: (ToolMessage | Command)[];

if (isSendInput(state)) {
outputs = [await this.runTool(state.lg_tool_call, config, state)];
const { lg_tool_call, jumpTo, ...newState } = state;
outputs = [await this.runTool(state.lg_tool_call, config, newState)];
} else {
let messages: BaseMessage[];
if (isBaseMessageArray(state)) {
Expand Down
Loading