Skip to content

Commit a2e86de

Browse files
authored
feat(langgraph): populate task results (#1297)
1 parent 6271906 commit a2e86de

File tree

6 files changed

+278
-173
lines changed

6 files changed

+278
-173
lines changed

libs/langgraph/src/pregel/debug.test.ts

Lines changed: 87 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ import { wrap, tasksWithWrites, _readChannels } from "./debug.js";
33
import { BaseChannel } from "../channels/base.js";
44
import { LastValue } from "../channels/last_value.js";
55
import { EmptyChannelError } from "../errors.js";
6+
import { ERROR, INTERRUPT, PULL } from "../constants.js";
67

78
describe("wrap", () => {
89
it("should wrap text with color codes", () => {
@@ -107,24 +108,27 @@ describe("tasksWithWrites", () => {
107108
{
108109
id: "task1",
109110
name: "Task 1",
110-
path: ["PULL", "Task 1"] as ["PULL", string],
111+
path: [PULL, "Task 1"] as [typeof PULL, string],
111112
interrupts: [],
112113
},
113114
{
114115
id: "task2",
115116
name: "Task 2",
116-
path: ["PULL", "Task 2"] as ["PULL", string],
117+
path: [PULL, "Task 2"] as [typeof PULL, string],
117118
interrupts: [],
118119
},
119120
];
120121

121122
const pendingWrites: Array<[string, string, unknown]> = [];
122123

123-
const result = tasksWithWrites(tasks, pendingWrites);
124+
const result = tasksWithWrites(tasks, pendingWrites, undefined, [
125+
"Task 1",
126+
"Task 2",
127+
]);
124128

125129
expect(result).toEqual([
126-
{ id: "task1", name: "Task 1", path: ["PULL", "Task 1"], interrupts: [] },
127-
{ id: "task2", name: "Task 2", path: ["PULL", "Task 2"], interrupts: [] },
130+
{ id: "task1", name: "Task 1", path: [PULL, "Task 1"], interrupts: [] },
131+
{ id: "task2", name: "Task 2", path: [PULL, "Task 2"], interrupts: [] },
128132
]);
129133
});
130134

@@ -133,32 +137,35 @@ describe("tasksWithWrites", () => {
133137
{
134138
id: "task1",
135139
name: "Task 1",
136-
path: ["PULL", "Task 1"] as ["PULL", string],
140+
path: [PULL, "Task 1"] as [typeof PULL, string],
137141
interrupts: [],
138142
},
139143
{
140144
id: "task2",
141145
name: "Task 2",
142-
path: ["PULL", "Task 2"] as ["PULL", string],
146+
path: [PULL, "Task 2"] as [typeof PULL, string],
143147
interrupts: [],
144148
},
145149
];
146150

147151
const pendingWrites: Array<[string, string, unknown]> = [
148-
["task1", "__error__", { message: "Test error" }],
152+
["task1", ERROR, { message: "Test error" }],
149153
];
150154

151-
const result = tasksWithWrites(tasks, pendingWrites);
155+
const result = tasksWithWrites(tasks, pendingWrites, undefined, [
156+
"Task 1",
157+
"Task 2",
158+
]);
152159

153160
expect(result).toEqual([
154161
{
155162
id: "task1",
156163
name: "Task 1",
157-
path: ["PULL", "Task 1"],
164+
path: [PULL, "Task 1"],
158165
error: { message: "Test error" },
159166
interrupts: [],
160167
},
161-
{ id: "task2", name: "Task 2", path: ["PULL", "Task 2"], interrupts: [] },
168+
{ id: "task2", name: "Task 2", path: [PULL, "Task 2"], interrupts: [] },
162169
]);
163170
});
164171

@@ -167,13 +174,13 @@ describe("tasksWithWrites", () => {
167174
{
168175
id: "task1",
169176
name: "Task 1",
170-
path: ["PULL", "Task 1"] as ["PULL", string],
177+
path: [PULL, "Task 1"] as [typeof PULL, string],
171178
interrupts: [],
172179
},
173180
{
174181
id: "task2",
175182
name: "Task 2",
176-
path: ["PULL", "Task 2"] as ["PULL", string],
183+
path: [PULL, "Task 2"] as [typeof PULL, string],
177184
interrupts: [],
178185
},
179186
];
@@ -184,17 +191,20 @@ describe("tasksWithWrites", () => {
184191
task1: { configurable: { key: "value" } },
185192
};
186193

187-
const result = tasksWithWrites(tasks, pendingWrites, states);
194+
const result = tasksWithWrites(tasks, pendingWrites, states, [
195+
"Task 1",
196+
"Task 2",
197+
]);
188198

189199
expect(result).toEqual([
190200
{
191201
id: "task1",
192202
name: "Task 1",
193-
path: ["PULL", "Task 1"],
203+
path: [PULL, "Task 1"],
194204
interrupts: [],
195205
state: { configurable: { key: "value" } },
196206
},
197-
{ id: "task2", name: "Task 2", path: ["PULL", "Task 2"], interrupts: [] },
207+
{ id: "task2", name: "Task 2", path: [PULL, "Task 2"], interrupts: [] },
198208
]);
199209
});
200210

@@ -203,24 +213,81 @@ describe("tasksWithWrites", () => {
203213
{
204214
id: "task1",
205215
name: "Task 1",
206-
path: ["PULL", "Task 1"] as ["PULL", string],
216+
path: [PULL, "Task 1"] as [typeof PULL, string],
207217
interrupts: [],
208218
},
209219
];
210220

211221
const pendingWrites: Array<[string, string, unknown]> = [
212-
["task1", "__interrupt__", { value: "Interrupted", when: "during" }],
222+
["task1", INTERRUPT, { value: "Interrupted", when: "during" }],
213223
];
214224

215-
const result = tasksWithWrites(tasks, pendingWrites);
225+
const result = tasksWithWrites(tasks, pendingWrites, undefined, ["task1"]);
216226

217227
expect(result).toEqual([
218228
{
219229
id: "task1",
220230
name: "Task 1",
221-
path: ["PULL", "Task 1"],
231+
path: [PULL, "Task 1"],
222232
interrupts: [{ value: "Interrupted", when: "during" }],
223233
},
224234
]);
225235
});
236+
237+
it("should include results", () => {
238+
const tasks = [
239+
{
240+
id: "task1",
241+
name: "Task 1",
242+
path: [PULL, "Task 1"] as [typeof PULL, string],
243+
interrupts: [],
244+
},
245+
{
246+
id: "task2",
247+
name: "Task 2",
248+
path: [PULL, "Task 2"] as [typeof PULL, string],
249+
interrupts: [],
250+
},
251+
{
252+
id: "task3",
253+
name: "Task 3",
254+
path: [PULL, "Task 3"] as [typeof PULL, string],
255+
interrupts: [],
256+
},
257+
];
258+
259+
const pendingWrites: Array<[string, string, unknown]> = [
260+
["task1", "Task 1", "Result"],
261+
["task2", "Task 2", "Result 2"],
262+
];
263+
264+
const result = tasksWithWrites(tasks, pendingWrites, undefined, [
265+
"Task 1",
266+
"Task 2",
267+
]);
268+
269+
expect(result).toEqual([
270+
{
271+
id: "task1",
272+
name: "Task 1",
273+
path: [PULL, "Task 1"],
274+
interrupts: [],
275+
result: { "Task 1": "Result" },
276+
},
277+
{
278+
id: "task2",
279+
name: "Task 2",
280+
path: [PULL, "Task 2"],
281+
interrupts: [],
282+
result: { "Task 2": "Result 2" },
283+
},
284+
{
285+
id: "task3",
286+
name: "Task 3",
287+
path: [PULL, "Task 3"],
288+
interrupts: [],
289+
result: undefined,
290+
},
291+
]);
292+
});
226293
});

libs/langgraph/src/pregel/debug.ts

Lines changed: 47 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,13 @@ import {
55
PendingWrite,
66
} from "@langchain/langgraph-checkpoint";
77
import { BaseChannel } from "../channels/base.js";
8-
import { ERROR, Interrupt, INTERRUPT, TAG_HIDDEN } from "../constants.js";
8+
import {
9+
ERROR,
10+
Interrupt,
11+
INTERRUPT,
12+
RETURN,
13+
TAG_HIDDEN,
14+
} from "../constants.js";
915
import { EmptyChannelError } from "../errors.js";
1016
import {
1117
PregelExecutableTask,
@@ -140,6 +146,8 @@ export function* mapDebugTaskResults<
140146
}
141147
}
142148

149+
type ChannelKey = string | number | symbol;
150+
143151
export function* mapDebugCheckpoint<
144152
N extends PropertyKey,
145153
C extends PropertyKey
@@ -151,7 +159,8 @@ export function* mapDebugCheckpoint<
151159
metadata: CheckpointMetadata,
152160
tasks: readonly PregelExecutableTask<N, C>[],
153161
pendingWrites: CheckpointPendingWrite[],
154-
parentConfig: RunnableConfig | undefined
162+
parentConfig: RunnableConfig | undefined,
163+
outputKeys: ChannelKey | ChannelKey[]
155164
) {
156165
function formatConfig(config: RunnableConfig) {
157166
// https://stackoverflow.com/a/78298178
@@ -214,7 +223,7 @@ export function* mapDebugCheckpoint<
214223
values: readChannels(channels, streamChannels),
215224
metadata,
216225
next: tasks.map((task) => task.name),
217-
tasks: tasksWithWrites(tasks, pendingWrites, taskStates),
226+
tasks: tasksWithWrites(tasks, pendingWrites, taskStates, outputKeys),
218227
parentConfig: parentConfig ? formatConfig(parentConfig) : undefined,
219228
},
220229
};
@@ -223,36 +232,64 @@ export function* mapDebugCheckpoint<
223232
export function tasksWithWrites<N extends PropertyKey, C extends PropertyKey>(
224233
tasks: PregelTaskDescription[] | readonly PregelExecutableTask<N, C>[],
225234
pendingWrites: CheckpointPendingWrite[],
226-
states?: Record<string, RunnableConfig | StateSnapshot>
235+
states: Record<string, RunnableConfig | StateSnapshot> | undefined,
236+
outputKeys: ChannelKey[] | ChannelKey
227237
): PregelTaskDescription[] {
228238
return tasks.map((task): PregelTaskDescription => {
229239
const error = pendingWrites.find(
230240
([id, n]) => id === task.id && n === ERROR
231241
)?.[2];
232242

233243
const interrupts = pendingWrites
234-
.filter(([id, n]) => {
235-
return id === task.id && n === INTERRUPT;
236-
})
237-
.map(([, , v]) => {
238-
return v;
239-
}) as Interrupt[];
244+
.filter(([id, n]) => id === task.id && n === INTERRUPT)
245+
.map(([, , v]) => v) as Interrupt[];
246+
247+
const result = (() => {
248+
if (error || interrupts.length || !pendingWrites.length) return undefined;
249+
250+
const idx = pendingWrites.findIndex(
251+
([tid, n]) => tid === task.id && n === RETURN
252+
);
253+
254+
if (idx >= 0) return pendingWrites[idx][2];
255+
256+
if (typeof outputKeys === "string") {
257+
return pendingWrites.find(
258+
([tid, n]) => tid === task.id && n === outputKeys
259+
)?.[2];
260+
}
261+
262+
if (Array.isArray(outputKeys)) {
263+
const results = pendingWrites
264+
.filter(([tid, n]) => tid === task.id && outputKeys.includes(n))
265+
.map(([, n, v]) => [n, v]);
266+
267+
if (!results.length) return undefined;
268+
return Object.fromEntries(results);
269+
}
270+
271+
return undefined;
272+
})();
273+
240274
if (error) {
241275
return {
242276
id: task.id,
243277
name: task.name as string,
244278
path: task.path,
245279
error,
246280
interrupts,
281+
result,
247282
};
248283
}
284+
249285
const taskState = states?.[task.id];
250286
return {
251287
id: task.id,
252288
name: task.name as string,
253289
path: task.path,
254290
interrupts,
255291
...(taskState !== undefined ? { state: taskState } : {}),
292+
result,
256293
};
257294
});
258295
}

libs/langgraph/src/pregel/index.ts

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -885,7 +885,12 @@ export class Pregel<
885885
this.streamChannelsAsIs as string | string[]
886886
),
887887
next: nextList,
888-
tasks: tasksWithWrites(nextTasks, saved?.pendingWrites ?? [], taskStates),
888+
tasks: tasksWithWrites(
889+
nextTasks,
890+
saved?.pendingWrites ?? [],
891+
taskStates,
892+
this.streamChannelsAsIs
893+
),
889894
metadata,
890895
config: patchCheckpointMap(saved.config, saved.metadata),
891896
createdAt: saved.checkpoint.ts,

libs/langgraph/src/pregel/loop.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -764,7 +764,8 @@ export class PregelLoop {
764764
this.checkpointMetadata,
765765
Object.values(this.tasks),
766766
this.checkpointPendingWrites,
767-
this.prevCheckpointConfig
767+
this.prevCheckpointConfig,
768+
this.outputKeys
768769
),
769770
"debug"
770771
)

libs/langgraph/src/pregel/types.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,7 @@ export interface PregelTaskDescription {
403403
readonly interrupts: Interrupt[];
404404
readonly state?: LangGraphRunnableConfig | StateSnapshot;
405405
readonly path?: TaskPath;
406+
readonly result?: unknown;
406407
}
407408

408409
interface CacheKey {

0 commit comments

Comments
 (0)