|
| 1 | +import { StateGraph, MemorySaver, START, END } from "@langchain/langgraph"; |
| 2 | +import { DreamWriterStateAnnotation } from "./state"; |
| 3 | +import { intentParserNode } from "./nodes/intent-parser"; |
| 4 | +import { storyArchitectNode } from "./nodes/story-architect"; |
| 5 | +import { writerNode } from "./nodes/writer"; |
| 6 | +import { qualityGuardNode } from "./nodes/quality-guard"; |
| 7 | +import { deliveryNode } from "./nodes/delivery"; |
| 8 | +import type { DreamWriterState } from "./state"; |
| 9 | + |
| 10 | +const MAX_REVISIONS = 2; |
| 11 | + |
| 12 | +function routeAfterQualityCheck(state: DreamWriterState): string { |
| 13 | + const report = state.qualityReport; |
| 14 | + if (!report) return "delivery_node"; |
| 15 | + if (!report.passesThreshold && state.revisionCount < MAX_REVISIONS) { |
| 16 | + console.log(`[DreamWriter] Quality score ${report.overallScore}/10, revision ${state.revisionCount + 1}/${MAX_REVISIONS}`); |
| 17 | + return "writer_node"; |
| 18 | + } |
| 19 | + return "delivery_node"; |
| 20 | +} |
| 21 | + |
| 22 | +async function revisionCounterNode(state: DreamWriterState): Promise<Partial<DreamWriterState>> { |
| 23 | + return { revisionCount: state.revisionCount + 1, stage: "revising", logs: [{ message: `正在根据反馈修改第 ${state.revisionCount + 1} 版...`, done: true }] }; |
| 24 | +} |
| 25 | + |
| 26 | +const workflow = new StateGraph(DreamWriterStateAnnotation) |
| 27 | + .addNode("intent_parser_node", intentParserNode) |
| 28 | + .addNode("story_architect_node", storyArchitectNode) |
| 29 | + .addNode("writer_node", writerNode) |
| 30 | + .addNode("quality_guard_node", qualityGuardNode) |
| 31 | + .addNode("revision_counter_node", revisionCounterNode) |
| 32 | + .addNode("delivery_node", deliveryNode) |
| 33 | + .addEdge(START, "intent_parser_node") |
| 34 | + .addEdge("intent_parser_node", "story_architect_node") |
| 35 | + .addEdge("story_architect_node", "writer_node") |
| 36 | + .addEdge("writer_node", "quality_guard_node") |
| 37 | + .addConditionalEdges("quality_guard_node", routeAfterQualityCheck, { writer_node: "revision_counter_node", delivery_node: "delivery_node" }) |
| 38 | + .addEdge("revision_counter_node", "writer_node") |
| 39 | + .addEdge("delivery_node", END); |
| 40 | + |
| 41 | +const memory = new MemorySaver(); |
| 42 | +export const graph = workflow.compile({ checkpointer: memory }); |
0 commit comments