Skip to content

Commit

Permalink
use reflections throughout graph
Browse files Browse the repository at this point in the history
  • Loading branch information
bracesproul committed Oct 10, 2024
1 parent b60252a commit 2f079d5
Show file tree
Hide file tree
Showing 13 changed files with 267 additions and 44 deletions.
6 changes: 5 additions & 1 deletion src/agent/open-canvas/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import { rewriteArtifactTheme } from "./nodes/rewriteArtifactTheme";
import { updateArtifact } from "./nodes/updateArtifact";
import { respondToQuery } from "./nodes/respondToQuery";
import { rewriteCodeArtifactTheme } from "./nodes/rewriteCodeArtifactTheme";
import { reflect } from "../reflection";

const defaultInputs: Omit<
typeof OpenCanvasGraphAnnotation.State,
Expand Down Expand Up @@ -57,6 +58,7 @@ const builder = new StateGraph(OpenCanvasGraphAnnotation)
.addNode("generateArtifact", generateArtifact)
.addNode("generateFollowup", generateFollowup)
.addNode("cleanState", cleanState)
.addNode("reflect", reflect)
// Edges
.addEdge("generateArtifact", "generateFollowup")
.addEdge("updateArtifact", "generateFollowup")
Expand All @@ -65,7 +67,9 @@ const builder = new StateGraph(OpenCanvasGraphAnnotation)
.addEdge("rewriteCodeArtifactTheme", "generateFollowup")
// End edges
.addEdge("respondToQuery", "cleanState")
.addEdge("generateFollowup", "cleanState")
// Only reflect if an artifact was generated/updated.
.addEdge("generateFollowup", "reflect")
.addEdge("reflect", "cleanState")
.addEdge("cleanState", END);

export const graph = builder.compile();
29 changes: 26 additions & 3 deletions src/agent/open-canvas/nodes/generateArtifact.ts
Original file line number Diff line number Diff line change
@@ -1,21 +1,36 @@
import { ChatOpenAI } from "@langchain/openai";
import { OpenCanvasGraphAnnotation, OpenCanvasGraphReturnType } from "../state";
import { NEW_ARTIFACT_PROMPT } from "../prompts";
import { Artifact } from "../../../types";
import { Artifact, Reflections } from "../../../types";
import { z } from "zod";
import { v4 as uuidv4 } from "uuid";
import { ensureStoreInConfig, formatReflections } from "@/agent/utils";
import { LangGraphRunnableConfig } from "@langchain/langgraph";

/**
* Generate a new artifact based on the user's query.
*/
export const generateArtifact = async (
state: typeof OpenCanvasGraphAnnotation.State
state: typeof OpenCanvasGraphAnnotation.State,
config: LangGraphRunnableConfig
): Promise<OpenCanvasGraphReturnType> => {
const smallModel = new ChatOpenAI({
model: "gpt-4o",
temperature: 0.5,
});

const store = ensureStoreInConfig(config);
const assistantId = config.configurable?.assistant_id;
if (!assistantId) {
throw new Error("`assistant_id` not found in configurable");
}
const memoryNamespace = ["memories", assistantId];
const memoryKey = "reflection";
const memories = await store.get(memoryNamespace, memoryKey);
const memoriesAsString = memories?.value
? formatReflections(memories.value as Reflections)
: "No reflections found.";

const modelWithArtifactTool = smallModel.bindTools(
[
{
Expand Down Expand Up @@ -46,8 +61,16 @@ export const generateArtifact = async (
{ tool_choice: "generate_artifact" }
);

const formattedNewArtifactPrompt = NEW_ARTIFACT_PROMPT.replace(
"{reflections}",
memoriesAsString
);

const response = await modelWithArtifactTool.invoke(
[{ role: "system", content: NEW_ARTIFACT_PROMPT }, ...state.messages],
[
{ role: "system", content: formattedNewArtifactPrompt },
...state.messages,
],
{ runName: "generate_artifact" }
);
const newArtifact: Artifact = {
Expand Down
21 changes: 19 additions & 2 deletions src/agent/open-canvas/nodes/generateFollowup.ts
Original file line number Diff line number Diff line change
@@ -1,24 +1,41 @@
import { ChatOpenAI } from "@langchain/openai";
import { OpenCanvasGraphAnnotation, OpenCanvasGraphReturnType } from "../state";
import { FOLLOWUP_ARTIFACT_PROMPT } from "../prompts";
import { ensureStoreInConfig, formatReflections } from "@/agent/utils";
import { Reflections } from "../../../types";
import { LangGraphRunnableConfig } from "@langchain/langgraph";

/**
* Generate a followup message after generating or updating an artifact.
*/
export const generateFollowup = async (
state: typeof OpenCanvasGraphAnnotation.State
state: typeof OpenCanvasGraphAnnotation.State,
config: LangGraphRunnableConfig
): Promise<OpenCanvasGraphReturnType> => {
const smallModel = new ChatOpenAI({
model: "gpt-4o-mini",
temperature: 0.5,
maxTokens: 250,
});

const store = ensureStoreInConfig(config);
const assistantId = config.configurable?.assistant_id;
if (!assistantId) {
throw new Error("`assistant_id` not found in configurable");
}
const memoryNamespace = ["memories", assistantId];
const memoryKey = "reflection";
const memories = await store.get(memoryNamespace, memoryKey);
const memoriesAsString = memories?.value
? formatReflections(memories.value as Reflections)
: "No reflections found.";

const recentArtifact = state.artifacts[state.artifacts.length - 1];
const formattedPrompt = FOLLOWUP_ARTIFACT_PROMPT.replace(
"{artifactContent}",
recentArtifact.content
);
).replace("{reflections}", memoriesAsString);

const response = await smallModel.invoke([
{ role: "user", content: formattedPrompt },
]);
Expand Down
35 changes: 35 additions & 0 deletions src/agent/open-canvas/nodes/reflect.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import { Client } from "@langchain/langgraph-sdk";
import { OpenCanvasGraphAnnotation } from "../state";
import { LangGraphRunnableConfig } from "@langchain/langgraph";

/**
 * Kick off a background "reflection" run on a separate thread.
 *
 * Builds an input from the current conversation messages and the selected
 * (or most recent) artifact, then schedules a run of the "reflection" graph
 * via the LangGraph SDK. The run is intentionally NOT awaited to completion:
 * `runs.create` only enqueues it, so this node returns immediately.
 *
 * @param state  Current Open Canvas graph state (messages + artifacts).
 * @param config Runnable config; `configurable.assistant_id` is required, as
 *               the reflection graph uses it to fetch & store memories.
 * @returns An empty state update — this node has no effect on graph state.
 * @throws If `assistant_id` is absent from `config.configurable`.
 */
export const reflect = async (
  state: typeof OpenCanvasGraphAnnotation.State,
  config: LangGraphRunnableConfig
) => {
  const langGraphClient = new Client();

  // Match the other nodes in this graph (generateFollowup, respondToQuery,
  // etc.): fail fast when the assistant ID is missing rather than silently
  // passing `undefined` into the background run, since the reflection graph
  // presumably keys its stored memories off this ID.
  const assistantId = config.configurable?.assistant_id;
  if (!assistantId) {
    throw new Error("`assistant_id` not found in configurable");
  }

  // Prefer the explicitly selected artifact; fall back to the most recent
  // one. NOTE(review): if `selectedArtifactId` matches nothing, `find`
  // yields `undefined` and the run receives no artifact — confirm intended.
  const selectedArtifact = state.selectedArtifactId
    ? state.artifacts.find((art) => art.id === state.selectedArtifactId)
    : state.artifacts[state.artifacts.length - 1];
  const reflectionInput = {
    messages: state.messages,
    artifact: selectedArtifact,
  };
  const reflectionConfig = {
    configurable: {
      // Ensure we pass in the current graph's assistant ID as this is
      // how we fetch & store the memories.
      assistant_id: assistantId,
    },
  };

  const newThread = await langGraphClient.threads.create();
  // Create a new reflection run, but do not `wait` for it to finish.
  // Intended to be a background run.
  await langGraphClient.runs.create(newThread.thread_id, "reflection", {
    input: reflectionInput,
    config: reflectionConfig,
  });

  return {};
};
30 changes: 25 additions & 5 deletions src/agent/open-canvas/nodes/respondToQuery.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import { ChatOpenAI } from "@langchain/openai";
import { OpenCanvasGraphAnnotation, OpenCanvasGraphReturnType } from "../state";
import { formatArtifacts } from "../utils";
import { ensureStoreInConfig, formatReflections } from "@/agent/utils";
import { LangGraphRunnableConfig } from "@langchain/langgraph";
import { Reflections } from "../../../types";

/**
* Generate responses to questions. Does not generate artifacts.
*/
export const respondToQuery = async (
state: typeof OpenCanvasGraphAnnotation.State
state: typeof OpenCanvasGraphAnnotation.State,
config: LangGraphRunnableConfig
): Promise<OpenCanvasGraphReturnType> => {
const smallModel = new ChatOpenAI({
model: "gpt-4o",
Expand All @@ -17,14 +21,30 @@ export const respondToQuery = async (
The user has generated artifacts in the past. Use the following artifacts as context when responding to the users question.
You also have the following reflections on style guidelines and general memories/facts about the user to use when generating your response.
<reflections>
{reflections}
</reflections>
<artifacts>
{artifacts}
</artifacts>`;

const formattedPrompt = prompt.replace(
"{artifacts}",
formatArtifacts(state.artifacts)
);
const store = ensureStoreInConfig(config);
const assistantId = config.configurable?.assistant_id;
if (!assistantId) {
throw new Error("`assistant_id` not found in configurable");
}
const memoryNamespace = ["memories", assistantId];
const memoryKey = "reflection";
const memories = await store.get(memoryNamespace, memoryKey);
const memoriesAsString = memories?.value
? formatReflections(memories.value as Reflections)
: "No reflections found.";

const formattedPrompt = prompt
.replace("{artifacts}", formatArtifacts(state.artifacts))
.replace("{reflections}", memoriesAsString);

const response = await smallModel.invoke([
{ role: "system", content: formattedPrompt },
Expand Down
20 changes: 18 additions & 2 deletions src/agent/open-canvas/nodes/rewriteArtifact.ts
Original file line number Diff line number Diff line change
@@ -1,15 +1,31 @@
import { ChatOpenAI } from "@langchain/openai";
import { OpenCanvasGraphAnnotation, OpenCanvasGraphReturnType } from "../state";
import { UPDATE_ENTIRE_ARTIFACT_PROMPT } from "../prompts";
import { ensureStoreInConfig, formatReflections } from "@/agent/utils";
import { Reflections } from "../../../types";
import { LangGraphRunnableConfig } from "@langchain/langgraph";

export const rewriteArtifact = async (
state: typeof OpenCanvasGraphAnnotation.State
state: typeof OpenCanvasGraphAnnotation.State,
config: LangGraphRunnableConfig
): Promise<OpenCanvasGraphReturnType> => {
const smallModel = new ChatOpenAI({
model: "gpt-4o",
temperature: 0.5,
});

const store = ensureStoreInConfig(config);
const assistantId = config.configurable?.assistant_id;
if (!assistantId) {
throw new Error("`assistant_id` not found in configurable");
}
const memoryNamespace = ["memories", assistantId];
const memoryKey = "reflection";
const memories = await store.get(memoryNamespace, memoryKey);
const memoriesAsString = memories?.value
? formatReflections(memories.value as Reflections)
: "No reflections found.";

const selectedArtifact = state.artifacts.find(
(artifact) => artifact.id === state.selectedArtifactId
);
Expand All @@ -20,7 +36,7 @@ export const rewriteArtifact = async (
const formattedPrompt = UPDATE_ENTIRE_ARTIFACT_PROMPT.replace(
"{artifactContent}",
selectedArtifact.content
);
).replace("{reflections}", memoriesAsString);

const recentHumanMessage = state.messages.findLast(
(message) => message._getType() === "human"
Expand Down
20 changes: 19 additions & 1 deletion src/agent/open-canvas/nodes/rewriteArtifactTheme.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,31 @@ import {
CHANGE_ARTIFACT_READING_LEVEL_PROMPT,
CHANGE_ARTIFACT_TO_PIRATE_PROMPT,
} from "../prompts";
import { ensureStoreInConfig, formatReflections } from "@/agent/utils";
import { Reflections } from "../../../types";
import { LangGraphRunnableConfig } from "@langchain/langgraph";

export const rewriteArtifactTheme = async (
state: typeof OpenCanvasGraphAnnotation.State
state: typeof OpenCanvasGraphAnnotation.State,
config: LangGraphRunnableConfig
): Promise<OpenCanvasGraphReturnType> => {
const smallModel = new ChatOpenAI({
model: "gpt-4o",
temperature: 0.5,
});

const store = ensureStoreInConfig(config);
const assistantId = config.configurable?.assistant_id;
if (!assistantId) {
throw new Error("`assistant_id` not found in configurable");
}
const memoryNamespace = ["memories", assistantId];
const memoryKey = "reflection";
const memories = await store.get(memoryNamespace, memoryKey);
const memoriesAsString = memories?.value
? formatReflections(memories.value as Reflections)
: "No reflections found.";

const selectedArtifact = state.artifacts.find(
(artifact) => artifact.id === state.selectedArtifactId
);
Expand Down Expand Up @@ -83,6 +99,8 @@ export const rewriteArtifactTheme = async (
throw new Error("No theme selected");
}

formattedPrompt = formattedPrompt.replace("{reflections}", memoriesAsString);

const newArtifactValues = await smallModel.invoke([
{ role: "user", content: formattedPrompt },
]);
Expand Down
21 changes: 19 additions & 2 deletions src/agent/open-canvas/nodes/updateArtifact.ts
Original file line number Diff line number Diff line change
@@ -1,18 +1,34 @@
import { ChatOpenAI } from "@langchain/openai";
import { OpenCanvasGraphAnnotation, OpenCanvasGraphReturnType } from "../state";
import { UPDATE_HIGHLIGHTED_ARTIFACT_PROMPT } from "../prompts";
import { ensureStoreInConfig, formatReflections } from "@/agent/utils";
import { Reflections } from "../../../types";
import { LangGraphRunnableConfig } from "@langchain/langgraph";

/**
* Update an existing artifact based on the user's query.
*/
export const updateArtifact = async (
state: typeof OpenCanvasGraphAnnotation.State
state: typeof OpenCanvasGraphAnnotation.State,
config: LangGraphRunnableConfig
): Promise<OpenCanvasGraphReturnType> => {
const smallModel = new ChatOpenAI({
model: "gpt-4o",
temperature: 0.5,
});

const store = ensureStoreInConfig(config);
const assistantId = config.configurable?.assistant_id;
if (!assistantId) {
throw new Error("`assistant_id` not found in configurable");
}
const memoryNamespace = ["memories", assistantId];
const memoryKey = "reflection";
const memories = await store.get(memoryNamespace, memoryKey);
const memoriesAsString = memories?.value
? formatReflections(memories.value as Reflections)
: "No reflections found.";

const selectedArtifact = state.artifacts.find(
(artifact) => artifact.id === state.selectedArtifactId
);
Expand Down Expand Up @@ -51,7 +67,8 @@ export const updateArtifact = async (
highlightedText
)
.replace("{beforeHighlight}", beforeHighlight)
.replace("{afterHighlight}", afterHighlight);
.replace("{afterHighlight}", afterHighlight)
.replace("{reflections}", memoriesAsString);

const recentHumanMessage = state.messages.findLast(
(message) => message._getType() === "human"
Expand Down
Loading

0 comments on commit 2f079d5

Please sign in to comment.