From 5fa710cbb520cde5163b358cd04db36421268bc6 Mon Sep 17 00:00:00 2001 From: Jacob Lee Date: Tue, 17 Sep 2024 14:14:59 -0700 Subject: [PATCH] fix(core): Fix nested stream events behavior (#6836) --- langchain-core/src/runnables/base.ts | 4 +- .../tests/runnable_stream_events_v2.test.ts | 71 ++++++++++++++++++- langchain-core/src/tracers/event_stream.ts | 2 +- 3 files changed, 73 insertions(+), 4 deletions(-) diff --git a/langchain-core/src/runnables/base.ts b/langchain-core/src/runnables/base.ts index 9caa530d699c..59e428b1cb88 100644 --- a/langchain-core/src/runnables/base.ts +++ b/langchain-core/src/runnables/base.ts @@ -695,7 +695,7 @@ export abstract class Runnable< config.callbacks = callbacks.concat([logStreamCallbackHandler]); } else { const copiedCallbacks = callbacks.copy(); - copiedCallbacks.inheritableHandlers.push(logStreamCallbackHandler); + copiedCallbacks.addHandler(logStreamCallbackHandler, true); // eslint-disable-next-line no-param-reassign config.callbacks = copiedCallbacks; } @@ -896,7 +896,7 @@ export abstract class Runnable< config.callbacks = callbacks.concat(eventStreamer); } else { const copiedCallbacks = callbacks.copy(); - copiedCallbacks.inheritableHandlers.push(eventStreamer); + copiedCallbacks.addHandler(eventStreamer, true); // eslint-disable-next-line no-param-reassign config.callbacks = copiedCallbacks; } diff --git a/langchain-core/src/runnables/tests/runnable_stream_events_v2.test.ts b/langchain-core/src/runnables/tests/runnable_stream_events_v2.test.ts index 1702d226aa4b..2807b4935657 100644 --- a/langchain-core/src/runnables/tests/runnable_stream_events_v2.test.ts +++ b/langchain-core/src/runnables/tests/runnable_stream_events_v2.test.ts @@ -4,6 +4,7 @@ import { test, expect, afterEach } from "@jest/globals"; import { z } from "zod"; +import { AsyncLocalStorage } from "node:async_hooks"; import { RunnableLambda, RunnableMap, @@ -28,8 +29,9 @@ import { DynamicStructuredTool, DynamicTool, tool } from "../../tools/index.js"; import { Document } from "../../documents/document.js"; import { PromptTemplate } from "../../prompts/prompt.js"; import { GenerationChunk } from "../../outputs.js"; -// Import from web to avoid side-effects from AsyncLocalStorage +// Import from web to avoid top-level side-effects from AsyncLocalStorage import { dispatchCustomEvent } from "../../callbacks/dispatch/web.js"; +import { AsyncLocalStorageProviderSingleton } from "../../singletons/index.js"; function reverse(s: string) { // Reverse a string. @@ -138,6 +140,73 @@ test("Runnable streamEvents method on a chat model", async () => { ]); }); +test("Runnable streamEvents call nested in another runnable + passed callbacks should still work", async () => { + AsyncLocalStorageProviderSingleton.initializeGlobalInstance( + new AsyncLocalStorage() + ); + + const model = new FakeListChatModel({ + responses: ["abc"], + }); + + const events: any[] = []; + const container = RunnableLambda.from(async (_) => { + const eventStream = model.streamEvents("hello", { version: "v2" }); + for await (const event of eventStream) { + events.push(event); + } + return events; + }); + + await container.invoke({}, { callbacks: [{ handleLLMStart: () => {} }] }); + + // used here to avoid casting every ID + const anyString = expect.any(String) as unknown as string; + + expect(events).toMatchObject([ + { + data: { input: "hello" }, + event: "on_chat_model_start", + name: "FakeListChatModel", + metadata: expect.any(Object), + run_id: expect.any(String), + tags: [], + }, + { + data: { chunk: new AIMessageChunk({ id: anyString, content: "a" }) }, + event: "on_chat_model_stream", + name: "FakeListChatModel", + metadata: expect.any(Object), + run_id: expect.any(String), + tags: [], + }, + { + data: { chunk: new AIMessageChunk({ id: anyString, content: "b" }) }, + event: "on_chat_model_stream", + name: "FakeListChatModel", + metadata: expect.any(Object), + run_id: expect.any(String), + tags: [], + }, + { + data: { chunk: new AIMessageChunk({ id: anyString, content: "c" }) }, + event: "on_chat_model_stream", + name: "FakeListChatModel", + metadata: expect.any(Object), + run_id: expect.any(String), + tags: [], + }, + { + data: { output: new AIMessageChunk({ id: anyString, content: "abc" }) }, + event: "on_chat_model_end", + name: "FakeListChatModel", + metadata: expect.any(Object), + run_id: expect.any(String), + tags: [], + }, + ]); +}); + test("Runnable streamEvents method with three runnables", async () => { const r = RunnableLambda.from(reverse); diff --git a/langchain-core/src/tracers/event_stream.ts b/langchain-core/src/tracers/event_stream.ts index 015faa20e5c8..3972e7ce9b4b 100644 --- a/langchain-core/src/tracers/event_stream.ts +++ b/langchain-core/src/tracers/event_stream.ts @@ -364,7 +364,7 @@ export class EventStreamCallbackHandler extends BaseTracer { throw new Error(`onLLMNewToken: Run ID ${run.id} not found in run map.`); } // Top-level streaming events are covered by tapOutputIterable - if (run.parent_run_id === undefined) { + if (this.runInfoMap.size === 1) { return; } if (runInfo.runType === "chat_model") {