Skip to content

Commit

Permalink
fix(core): Fix nested stream events behavior (langchain-ai#6836)
Browse files Browse the repository at this point in the history
  • Loading branch information
jacoblee93 authored and FilipZmijewski committed Sep 27, 2024
1 parent 6bf6f48 commit 5fa710c
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 4 deletions.
4 changes: 2 additions & 2 deletions langchain-core/src/runnables/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -695,7 +695,7 @@ export abstract class Runnable<
config.callbacks = callbacks.concat([logStreamCallbackHandler]);
} else {
const copiedCallbacks = callbacks.copy();
copiedCallbacks.inheritableHandlers.push(logStreamCallbackHandler);
copiedCallbacks.addHandler(logStreamCallbackHandler, true);
// eslint-disable-next-line no-param-reassign
config.callbacks = copiedCallbacks;
}
Expand Down Expand Up @@ -896,7 +896,7 @@ export abstract class Runnable<
config.callbacks = callbacks.concat(eventStreamer);
} else {
const copiedCallbacks = callbacks.copy();
copiedCallbacks.inheritableHandlers.push(eventStreamer);
copiedCallbacks.addHandler(eventStreamer, true);
// eslint-disable-next-line no-param-reassign
config.callbacks = copiedCallbacks;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import { test, expect, afterEach } from "@jest/globals";
import { z } from "zod";
import { AsyncLocalStorage } from "node:async_hooks";
import {
RunnableLambda,
RunnableMap,
Expand All @@ -28,8 +29,9 @@ import { DynamicStructuredTool, DynamicTool, tool } from "../../tools/index.js";
import { Document } from "../../documents/document.js";
import { PromptTemplate } from "../../prompts/prompt.js";
import { GenerationChunk } from "../../outputs.js";
// Import from web to avoid side-effects from AsyncLocalStorage
// Import from web to avoid top-level side-effects from AsyncLocalStorage
import { dispatchCustomEvent } from "../../callbacks/dispatch/web.js";
import { AsyncLocalStorageProviderSingleton } from "../../singletons/index.js";

function reverse(s: string) {
// Reverse a string.
Expand Down Expand Up @@ -138,6 +140,73 @@ test("Runnable streamEvents method on a chat model", async () => {
]);
});

test("Runnable streamEvents call nested in another runnable + passed callbacks should still work", async () => {
AsyncLocalStorageProviderSingleton.initializeGlobalInstance(
new AsyncLocalStorage()
);

const model = new FakeListChatModel({
responses: ["abc"],
});

const events: any[] = [];
const container = RunnableLambda.from(async (_) => {
const eventStream = model.streamEvents("hello", { version: "v2" });
for await (const event of eventStream) {
events.push(event);
}
return events;
});

await container.invoke({}, { callbacks: [{ handleLLMStart: () => {} }] });

// used here to avoid casting every ID
const anyString = expect.any(String) as unknown as string;

expect(events).toMatchObject([
{
data: { input: "hello" },
event: "on_chat_model_start",
name: "FakeListChatModel",
metadata: expect.any(Object),
run_id: expect.any(String),
tags: [],
},
{
data: { chunk: new AIMessageChunk({ id: anyString, content: "a" }) },
event: "on_chat_model_stream",
name: "FakeListChatModel",
metadata: expect.any(Object),
run_id: expect.any(String),
tags: [],
},
{
data: { chunk: new AIMessageChunk({ id: anyString, content: "b" }) },
event: "on_chat_model_stream",
name: "FakeListChatModel",
metadata: expect.any(Object),
run_id: expect.any(String),
tags: [],
},
{
data: { chunk: new AIMessageChunk({ id: anyString, content: "c" }) },
event: "on_chat_model_stream",
name: "FakeListChatModel",
metadata: expect.any(Object),
run_id: expect.any(String),
tags: [],
},
{
data: { output: new AIMessageChunk({ id: anyString, content: "abc" }) },
event: "on_chat_model_end",
name: "FakeListChatModel",
metadata: expect.any(Object),
run_id: expect.any(String),
tags: [],
},
]);
});

test("Runnable streamEvents method with three runnables", async () => {
const r = RunnableLambda.from(reverse);

Expand Down
2 changes: 1 addition & 1 deletion langchain-core/src/tracers/event_stream.ts
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ export class EventStreamCallbackHandler extends BaseTracer {
throw new Error(`onLLMNewToken: Run ID ${run.id} not found in run map.`);
}
// Top-level streaming events are covered by tapOutputIterable
if (run.parent_run_id === undefined) {
if (this.runInfoMap.size === 1) {
return;
}
if (runInfo.runType === "chat_model") {
Expand Down

0 comments on commit 5fa710c

Please sign in to comment.