Merge pull request langchain-ai#1 from langchain-ai/brace/quick-actions
feat: Add state fields for quick actions
bracesproul authored Oct 5, 2024
2 parents f43a2cc + 81cee2e commit e88354e
Showing 10 changed files with 466 additions and 72 deletions.
276 changes: 227 additions & 49 deletions src/agent/index.ts
@@ -9,13 +9,23 @@ import {
import { ChatOpenAI } from "@langchain/openai";
import { z } from "zod";
import {
ADD_EMOJIS_TO_ARTIFACT_PROMPT,
CHANGE_ARTIFACT_LANGUAGE_PROMPT,
CHANGE_ARTIFACT_LENGTH_PROMPT,
CHANGE_ARTIFACT_READING_LEVEL_PROMPT,
CHANGE_ARTIFACT_TO_PIRATE_PROMPT,
FOLLOWUP_ARTIFACT_PROMPT,
NEW_ARTIFACT_PROMPT,
ROUTE_QUERY_PROMPT,
UPDATE_ENTIRE_ARTIFACT_PROMPT,
UPDATE_HIGHLIGHTED_ARTIFACT_PROMPT,
} from "./prompts";
-import { Artifact } from "../types";
import {
Artifact,
ArtifactLengthOptions,
LanguageOptions,
ReadingLevelOptions,
} from "../types";
import { v4 as uuidv4 } from "uuid";

interface Highlight {
@@ -51,6 +61,26 @@ const GraphAnnotation = Annotation.Root({
reducer: (_state, update) => update,
default: () => [],
}),
/**
* The next node to route to. Only used for the first routing node/conditional edge.
*/
next: Annotation<string | undefined>,
/**
* The language to translate the artifact to.
*/
language: Annotation<LanguageOptions | undefined>,
/**
* The length of the artifact to regenerate to.
*/
artifactLength: Annotation<ArtifactLengthOptions | undefined>,
/**
* Whether or not to regenerate with emojis.
*/
regenerateWithEmojis: Annotation<boolean | undefined>,
/**
* The reading level to adjust the artifact to.
*/
readingLevel: Annotation<ReadingLevelOptions | undefined>,
});

type GraphReturnType = Partial<typeof GraphAnnotation.State>;
@@ -101,8 +131,6 @@ The user has generated artifacts in the past. Use the following artifacts as con

/**
* Update an existing artifact based on the user's query.
- *
- * TODO: break into two nodes, one for updating and one for regenerating.
*/
const updateArtifact = async (
state: typeof GraphAnnotation.State
@@ -120,38 +148,9 @@ const updateArtifact = async (
}

if (!state.highlighted) {
-    // No highlighted text is present, so we need to update the entire artifact.
-    const formattedPrompt = UPDATE_ENTIRE_ARTIFACT_PROMPT.replace(
-      "{artifactContent}",
-      selectedArtifact.content
-    );
-
-    const recentHumanMessage = state.messages.findLast(
-      (message) => message._getType() === "human"
throw new Error(
"Can not partially regenerate an artifact without a highlight"
);
-    if (!recentHumanMessage) {
-      throw new Error("No recent human message found");
-    }
-    const newArtifact = await smallModel.invoke([
-      { role: "system", content: formattedPrompt },
-      recentHumanMessage,
-    ]);
-
-    // Remove the original artifact message from the history.
-    const newArtifacts: Artifact[] = [
-      ...state.artifacts.filter(
-        (artifact) => artifact.id !== selectedArtifact.id
-      ),
-      {
-        ...selectedArtifact,
-        content: newArtifact.content as string,
-      },
-    ];
-
-    return {
-      artifacts: newArtifacts,
-      highlighted: undefined,
-    };
}

// Highlighted text is present, so we need to update the highlighted text.
@@ -215,6 +214,160 @@ const updateArtifact = async (
};
};

const rewriteArtifact = async (
state: typeof GraphAnnotation.State
): Promise<GraphReturnType> => {
const smallModel = new ChatOpenAI({
model: "gpt-4o",
temperature: 0.5,
});

const selectedArtifact = state.artifacts.find(
(artifact) => artifact.id === state.selectedArtifactId
);
if (!selectedArtifact) {
throw new Error("No artifact found with the selected ID");
}

const formattedPrompt = UPDATE_ENTIRE_ARTIFACT_PROMPT.replace(
"{artifactContent}",
selectedArtifact.content
);

const recentHumanMessage = state.messages.findLast(
(message) => message._getType() === "human"
);
if (!recentHumanMessage) {
throw new Error("No recent human message found");
}
const newArtifact = await smallModel.invoke([
{ role: "system", content: formattedPrompt },
recentHumanMessage,
]);

// Remove the original artifact message from the history.
const newArtifacts: Artifact[] = [
...state.artifacts.filter(
(artifact) => artifact.id !== selectedArtifact.id
),
{
...selectedArtifact,
content: newArtifact.content as string,
},
];

return {
artifacts: newArtifacts,
selectedArtifactId: undefined,
highlighted: undefined,
language: undefined,
artifactLength: undefined,
regenerateWithEmojis: undefined,
readingLevel: undefined,
};
};

const rewriteArtifactTheme = async (
state: typeof GraphAnnotation.State
): Promise<GraphReturnType> => {
const smallModel = new ChatOpenAI({
model: "gpt-4o",
temperature: 0.5,
});

const selectedArtifact = state.artifacts.find(
(artifact) => artifact.id === state.selectedArtifactId
);
if (!selectedArtifact) {
throw new Error("No artifact found with the selected ID");
}

let formattedPrompt = "";
if (state.language) {
formattedPrompt = CHANGE_ARTIFACT_LANGUAGE_PROMPT.replace(
"{newLanguage}",
state.language
).replace("{artifactContent}", selectedArtifact.content);
} else if (state.readingLevel && state.readingLevel !== "pirate") {
let newReadingLevel = "";
switch (state.readingLevel) {
case "child":
newReadingLevel = "elementary school student";
break;
case "teenager":
newReadingLevel = "high school student";
break;
case "college":
newReadingLevel = "college student";
break;
case "phd":
newReadingLevel = "PhD student";
break;
}
    formattedPrompt = CHANGE_ARTIFACT_READING_LEVEL_PROMPT.replace(
      "{newReadingLevel}",
      newReadingLevel
    ).replace("{artifactContent}", selectedArtifact.content);
} else if (state.readingLevel && state.readingLevel === "pirate") {
formattedPrompt = CHANGE_ARTIFACT_TO_PIRATE_PROMPT.replace(
"{artifactContent}",
selectedArtifact.content
);
} else if (state.artifactLength) {
let newLength = "";
switch (state.artifactLength) {
case "shortest":
newLength = "much shorter than it currently is";
break;
case "short":
newLength = "slightly shorter than it currently is";
break;
case "long":
newLength = "slightly longer than it currently is";
break;
case "longest":
newLength = "much longer than it currently is";
break;
}
formattedPrompt = CHANGE_ARTIFACT_LENGTH_PROMPT.replace(
"{newLength}",
newLength
).replace("{artifactContent}", selectedArtifact.content);
} else if (state.regenerateWithEmojis) {
formattedPrompt = ADD_EMOJIS_TO_ARTIFACT_PROMPT.replace(
"{artifactContent}",
selectedArtifact.content
);
} else {
throw new Error("No theme selected");
}

const newArtifact = await smallModel.invoke([
{ role: "user", content: formattedPrompt },
]);

// Remove the original artifact message from the history.
const newArtifacts: Artifact[] = [
...state.artifacts.filter(
(artifact) => artifact.id !== selectedArtifact.id
),
{
...selectedArtifact,
content: newArtifact.content as string,
},
];

return {
artifacts: newArtifacts,
selectedArtifactId: undefined,
highlighted: undefined,
language: undefined,
artifactLength: undefined,
regenerateWithEmojis: undefined,
readingLevel: undefined,
};
};

/**
* Generate a new artifact based on the user's query.
*/
@@ -289,12 +442,23 @@ const generateFollowup = async (
/**
* Routes to the proper node in the graph based on the user's query.
*/
-const routeQuery = async (state: typeof GraphAnnotation.State) => {
const generatePath = async (state: typeof GraphAnnotation.State) => {
if (state.highlighted) {
-    return new Send("updateArtifact", {
-      ...state,
return {
next: "updateArtifact",
selectedArtifactId: state.highlighted.id,
-    });
};
}

if (
state.language ||
state.artifactLength ||
state.regenerateWithEmojis ||
state.readingLevel
) {
return {
next: "rewriteArtifactTheme",
};
}

  // Call model and decide if we need to respond to a user's query, or generate a new artifact
@@ -312,8 +476,6 @@ const routeQuery = async (state: typeof GraphAnnotation.State) => {
}).withStructuredOutput(
z.object({
route: z.enum(["updateArtifact", "respondToQuery", "generateArtifact"]),
-      // TODO: HOW TO PASS THIS THROUGH TO NEXT NODE.
-      // maybe `send`?
artifactId: z
.string()
.optional()
@@ -332,27 +494,43 @@ const routeQuery = async (state: typeof GraphAnnotation.State) => {
]);

if (result.route === "updateArtifact") {
-    return new Send("updateArtifact", {
-      ...state,
return {
// Only route to the `updateArtifact` node if highlighted text is present.
// Otherwise we need to rewrite the entire artifact.
next: "rewriteArtifact",
selectedArtifactId: result.artifactId,
-    });
};
} else {
-    return result.route;
return {
next: result.route,
};
}
};

const routeNode = (state: typeof GraphAnnotation.State) => {
if (!state.next) {
throw new Error("'next' state field not set.");
}

return new Send(state.next, {
...state,
});
};

const builder = new StateGraph(GraphAnnotation)
.addNode("generatePath", generatePath)
.addEdge(START, "generatePath")
.addConditionalEdges("generatePath", routeNode)
.addNode("respondToQuery", respondToQuery)
.addNode("rewriteArtifact", rewriteArtifact)
.addNode("rewriteArtifactTheme", rewriteArtifactTheme)
.addNode("updateArtifact", updateArtifact)
.addNode("generateArtifact", generateArtifact)
.addNode("generateFollowup", generateFollowup)
-  .addConditionalEdges(START, routeQuery, [
-    "updateArtifact",
-    "respondToQuery",
-    "generateArtifact",
-  ])
.addEdge("generateArtifact", "generateFollowup")
.addEdge("updateArtifact", "generateFollowup")
.addEdge("rewriteArtifact", "generateFollowup")
.addEdge("rewriteArtifactTheme", "generateFollowup")
.addEdge("respondToQuery", END)
.addEdge("generateFollowup", END);

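For orientation, here is a minimal usage sketch of the graph wired above. It is not part of this commit: it assumes the builder is compiled and exported elsewhere in the file (for example as `graph`), that an OpenAI API key is available, and it simplifies the `Artifact` shape to just the fields used here.

// Hypothetical usage sketch (not part of this commit). Assumes
// `export const graph = builder.compile();` exists in src/agent/index.ts
// and that the Artifact type accepts the fields shown below.
import { HumanMessage } from "@langchain/core/messages";
import { graph } from "./agent";

const runQuickAction = async () => {
  const result = await graph.invoke({
    messages: [new HumanMessage("Make this much shorter.")],
    artifacts: [{ id: "artifact-1", content: "An artifact generated earlier." }],
    selectedArtifactId: "artifact-1",
    // One of the quick-action state fields added in this commit. generatePath
    // sees it, routeNode Sends the state to rewriteArtifactTheme, and
    // generateFollowup runs afterwards.
    artifactLength: "shortest",
  });

  console.log(result.artifacts);
};

runQuickAction().catch(console.error);

Any of the other new fields (language, readingLevel, regenerateWithEmojis) would route the same way through generatePath and routeNode.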