diff --git a/js/src/run_trees.ts b/js/src/run_trees.ts index 5e2d982f..4aab5393 100644 --- a/js/src/run_trees.ts +++ b/js/src/run_trees.ts @@ -258,9 +258,7 @@ export class RunTree implements BaseRun { child_execution_order: child_execution_order, }); - // If a context var is set by LangChain outside of a traceable, - // it will be an object with a single property and we should copy - // context vars over into the new run tree. + // Copy context vars over into the new run tree. if (_LC_CONTEXT_VARIABLES_KEY in this) { // eslint-disable-next-line @typescript-eslint/no-explicit-any (child as any)[_LC_CONTEXT_VARIABLES_KEY] = diff --git a/js/src/tests/traceable.test.ts b/js/src/tests/traceable.test.ts index 7842260f..4d4951b9 100644 --- a/js/src/tests/traceable.test.ts +++ b/js/src/tests/traceable.test.ts @@ -1,9 +1,14 @@ import { jest } from "@jest/globals"; -import { RunTree, RunTreeConfig } from "../run_trees.js"; +import { + _LC_CONTEXT_VARIABLES_KEY, + RunTree, + RunTreeConfig, +} from "../run_trees.js"; import { ROOT, traceable, withRunTree } from "../traceable.js"; import { getAssumedTreeFromCalls } from "./utils/tree.js"; import { mockClient } from "./utils/mock_client.js"; import { Client, overrideFetchImplementation } from "../index.js"; +import { AsyncLocalStorageProviderSingleton } from "../singletons/traceable.js"; test("basic traceable implementation", async () => { const { client, callSpy } = mockClient(); @@ -103,6 +108,80 @@ test("nested traceable implementation", async () => { }); }); +test.only("nested traceable passes through LangChain context vars", (done) => { + const alsInstance = AsyncLocalStorageProviderSingleton.getInstance(); + + alsInstance.run( + { + [_LC_CONTEXT_VARIABLES_KEY]: { foo: "bar" }, + } as any, + // eslint-disable-next-line @typescript-eslint/no-misused-promises + async () => { + try { + expect( + (alsInstance.getStore() as any)?.[_LC_CONTEXT_VARIABLES_KEY]?.foo + ).toEqual("bar"); + const { client, callSpy } = mockClient(); + + const llm = traceable(async function llm(input: string) { + expect( + (alsInstance.getStore() as any)?.[_LC_CONTEXT_VARIABLES_KEY]?.foo + ).toEqual("bar"); + return input.repeat(2); + }); + + const str = traceable(async function* str(input: string) { + const response = input.split("").reverse(); + for (const char of response) { + yield char; + } + expect( + (alsInstance.getStore() as any)?.[_LC_CONTEXT_VARIABLES_KEY]?.foo + ).toEqual("bar"); + }); + + const chain = traceable( + async function chain(input: string) { + expect( + (alsInstance.getStore() as any)?.[_LC_CONTEXT_VARIABLES_KEY]?.foo + ).toEqual("bar"); + const question = await llm(input); + + let answer = ""; + for await (const char of str(question)) { + answer += char; + } + + return { question, answer }; + }, + { client, tracingEnabled: true } + ); + + const result = await chain("Hello world"); + + expect(result).toEqual({ + question: "Hello worldHello world", + answer: "dlrow olleHdlrow olleH", + }); + + expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({ + nodes: ["chain:0", "llm:1", "str:2"], + edges: [ + ["chain:0", "llm:1"], + ["chain:0", "str:2"], + ], + }); + expect( + (alsInstance.getStore() as any)?.[_LC_CONTEXT_VARIABLES_KEY]?.foo + ).toEqual("bar"); + done(); + } catch (e) { + done(e); + } + } + ); +}); + test("trace circular input and output objects", async () => { const { client, callSpy } = mockClient(); const a: Record = {};