Merge pull request langchain-ai#136 from ahmad2b/ahmad2b/customizable-model-selection-v2

Fixes langchain-ai#63: Implement Customizable Model Selection (Updated Implementation)
bracesproul authored Oct 25, 2024
2 parents a3b3902 + 318b00f commit 6f4014c
Showing 25 changed files with 458 additions and 166 deletions.
2 changes: 2 additions & 0 deletions package.json
@@ -31,6 +31,7 @@
"@codemirror/lang-sql": "^6.8.0",
"@codemirror/lang-xml": "^6.1.0",
"@langchain/anthropic": "^0.3.5",
"@langchain/community": "^0.3.9",
"@langchain/core": "^0.3.14",
"@langchain/langgraph": "^0.2.18",
"@langchain/langgraph-sdk": "^0.0.17",
@@ -62,6 +63,7 @@
"dotenv": "^16.4.5",
"framer-motion": "^11.11.9",
"js-cookie": "^3.0.5",
"langchain": "^0.3.4",
"langsmith": "^0.1.61",
"lodash": "^4.17.21",
"lucide-react": "^0.441.0",
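The two new dependencies, langchain and @langchain/community, are what make the initChatModel helper from langchain/chat_models/universal available to the agent nodes changed below. A minimal sketch of how that helper is used, assuming provider inference from the model name; the model name here is only an example, not something this PR pins down:

import { initChatModel } from "langchain/chat_models/universal";

// initChatModel resolves a plain model-name string to a concrete chat model
// implementation at runtime (OpenAI, Anthropic, community providers, ...),
// which is why both packages are now direct dependencies.
async function demo() {
  const model = await initChatModel("gpt-4o-mini", { temperature: 0.5 });
  const reply = await model.invoke("Reply with a single word.");
  console.log(reply.content);
}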
34 changes: 19 additions & 15 deletions src/agent/open-canvas/nodes/customAction.ts
@@ -1,23 +1,27 @@
import { ChatOpenAI } from "@langchain/openai";
import { OpenCanvasGraphAnnotation, OpenCanvasGraphReturnType } from "../state";
import {
CUSTOM_QUICK_ACTION_ARTIFACT_CONTENT_PROMPT,
CUSTOM_QUICK_ACTION_ARTIFACT_PROMPT_PREFIX,
CUSTOM_QUICK_ACTION_CONVERSATION_CONTEXT,
REFLECTIONS_QUICK_ACTION_PROMPT,
} from "../prompts";
import { ensureStoreInConfig, formatReflections } from "../../utils";
import { BaseMessage } from "@langchain/core/messages";
import { LangGraphRunnableConfig } from "@langchain/langgraph";
import { initChatModel } from "langchain/chat_models/universal";
import { getArtifactContent } from "../../../hooks/use-graph/utils";
import { isArtifactMarkdownContent } from "../../../lib/artifact_content_types";
import {
ArtifactCodeV3,
ArtifactMarkdownV3,
ArtifactV3,
CustomQuickAction,
Reflections,
} from "../../../types";
import { LangGraphRunnableConfig } from "@langchain/langgraph";
import { BaseMessage } from "@langchain/core/messages";
import { getArtifactContent } from "../../../hooks/use-graph/utils";
import { isArtifactMarkdownContent } from "../../../lib/artifact_content_types";
import {
ensureStoreInConfig,
formatReflections,
getModelNameFromConfig,
} from "../../utils";
import {
CUSTOM_QUICK_ACTION_ARTIFACT_CONTENT_PROMPT,
CUSTOM_QUICK_ACTION_ARTIFACT_PROMPT_PREFIX,
CUSTOM_QUICK_ACTION_CONVERSATION_CONTEXT,
REFLECTIONS_QUICK_ACTION_PROMPT,
} from "../prompts";
import { OpenCanvasGraphAnnotation, OpenCanvasGraphReturnType } from "../state";

const formatMessages = (messages: BaseMessage[]): string =>
messages
@@ -35,8 +39,8 @@ export const customAction = async (
throw new Error("No custom quick action ID found.");
}

const smallModel = new ChatOpenAI({
model: "gpt-4o-mini",
const modelName = getModelNameFromConfig(config);
const smallModel = await initChatModel(modelName, {
temperature: 0.5,
});

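The same refactor repeats in generateFollowup, respondToQuery, rewriteArtifactTheme, and rewriteCodeArtifactTheme below: instead of hard-coding new ChatOpenAI({ model: "gpt-4o-mini" }), each node reads the model name from the run configuration and hands it to initChatModel. A condensed sketch of the pattern, with the node body elided; customModelName is the key read by getModelNameFromConfig in src/agent/utils.ts at the end of this diff:

import { LangGraphRunnableConfig } from "@langchain/langgraph";
import { initChatModel } from "langchain/chat_models/universal";

// Condensed version of the pattern used in each updated node.
const exampleNode = async (config: LangGraphRunnableConfig) => {
  // The UI stores the selected model under configurable.customModelName.
  const modelName = config.configurable?.customModelName as string;
  const smallModel = await initChatModel(modelName, { temperature: 0.5 });
  return smallModel.invoke([{ role: "user", content: "..." }]);
};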
18 changes: 11 additions & 7 deletions src/agent/open-canvas/nodes/generateFollowup.ts
@@ -1,11 +1,15 @@
import { ChatOpenAI } from "@langchain/openai";
import { OpenCanvasGraphAnnotation, OpenCanvasGraphReturnType } from "../state";
import { FOLLOWUP_ARTIFACT_PROMPT } from "../prompts";
import { ensureStoreInConfig, formatReflections } from "../../utils";
import { Reflections } from "../../../types";
import { LangGraphRunnableConfig } from "@langchain/langgraph";
import { initChatModel } from "langchain/chat_models/universal";
import { getArtifactContent } from "../../../hooks/use-graph/utils";
import { isArtifactMarkdownContent } from "../../../lib/artifact_content_types";
import { Reflections } from "../../../types";
import {
ensureStoreInConfig,
formatReflections,
getModelNameFromConfig,
} from "../../utils";
import { FOLLOWUP_ARTIFACT_PROMPT } from "../prompts";
import { OpenCanvasGraphAnnotation, OpenCanvasGraphReturnType } from "../state";

/**
* Generate a followup message after generating or updating an artifact.
@@ -14,8 +18,8 @@ export const generateFollowup = async (
state: typeof OpenCanvasGraphAnnotation.State,
config: LangGraphRunnableConfig
): Promise<OpenCanvasGraphReturnType> => {
const smallModel = new ChatOpenAI({
model: "gpt-4o-mini",
const modelName = getModelNameFromConfig(config);
const smallModel = await initChatModel(modelName, {
temperature: 0.5,
maxTokens: 250,
});
15 changes: 8 additions & 7 deletions src/agent/open-canvas/nodes/respondToQuery.ts
@@ -1,14 +1,15 @@
import { ChatOpenAI } from "@langchain/openai";
import { OpenCanvasGraphAnnotation, OpenCanvasGraphReturnType } from "../state";
import { LangGraphRunnableConfig } from "@langchain/langgraph";
import { initChatModel } from "langchain/chat_models/universal";
import { getArtifactContent } from "../../../hooks/use-graph/utils";
import { Reflections } from "../../../types";
import {
ensureStoreInConfig,
formatArtifactContentWithTemplate,
formatReflections,
getModelNameFromConfig,
} from "../../utils";
import { LangGraphRunnableConfig } from "@langchain/langgraph";
import { Reflections } from "../../../types";
import { CURRENT_ARTIFACT_PROMPT, NO_ARTIFACT_PROMPT } from "../prompts";
import { getArtifactContent } from "../../../hooks/use-graph/utils";
import { OpenCanvasGraphAnnotation, OpenCanvasGraphReturnType } from "../state";

/**
* Generate responses to questions. Does not generate artifacts.
@@ -17,8 +18,8 @@ export const respondToQuery = async (
state: typeof OpenCanvasGraphAnnotation.State,
config: LangGraphRunnableConfig
): Promise<OpenCanvasGraphReturnType> => {
const smallModel = new ChatOpenAI({
model: "gpt-4o-mini",
const modelName = getModelNameFromConfig(config);
const smallModel = await initChatModel(modelName, {
temperature: 0.5,
});

80 changes: 43 additions & 37 deletions src/agent/open-canvas/nodes/rewriteArtifact.ts
@@ -9,6 +9,7 @@ import {
ensureStoreInConfig,
formatArtifactContent,
formatReflections,
getModelNameFromConfig,
} from "../../utils";
import {
ArtifactCodeV3,
@@ -24,13 +25,50 @@ import {
isArtifactCodeContent,
isArtifactMarkdownContent,
} from "../../../lib/artifact_content_types";
import { initChatModel } from "langchain/chat_models/universal";

export const rewriteArtifact = async (
state: typeof OpenCanvasGraphAnnotation.State,
config: LangGraphRunnableConfig
): Promise<OpenCanvasGraphReturnType> => {
const smallModel = new ChatOpenAI({
const optionallyUpdateArtifactMetaSchema = z.object({
type: z
.enum(["text", "code"])
.describe("The type of the artifact content."),
title: z
.string()
.optional()
.describe(
"The new title to give the artifact. ONLY update this if the user is making a request which changes the subject/topic of the artifact."
),
programmingLanguage: z
.enum(
PROGRAMMING_LANGUAGES.map((lang) => lang.language) as [
string,
...string[],
]
)
.optional()
.describe(
"The programming language of the code artifact. ONLY update this if the user is making a request which changes the programming language of the code artifact, or is asking for a code artifact to be generated."
),
});
// TODO: Once Anthropic tool call streaming is supported, use the custom model here.
const toolCallingModel = new ChatOpenAI({
model: "gpt-4o-mini",
temperature: 0,
})
.bindTools([
{
name: "optionallyUpdateArtifactMeta",
schema: optionallyUpdateArtifactMetaSchema,
description: "Update the artifact meta information, if necessary.",
},
])
.withConfig({ runName: "optionally_update_artifact_meta" });

const modelName = getModelNameFromConfig(config);
const smallModel = await initChatModel(modelName, {
temperature: 0.5,
});

@@ -65,43 +103,11 @@
if (!recentHumanMessage) {
throw new Error("No recent human message found");
}
const optionallyUpdateArtifactMetaSchema = z.object({
type: z
.enum(["text", "code"])
.describe("The type of the artifact content."),
title: z
.string()
.optional()
.describe(
"The new title to give the artifact. ONLY update this if the user is making a request which changes the subject/topic of the artifact."
),
programmingLanguage: z
.enum(
PROGRAMMING_LANGUAGES.map((lang) => lang.language) as [
string,
...string[],
]
)
.optional()
.describe(
"The programming language of the code artifact. ONLY update this if the user is making a request which changes the programming language of the code artifact, or is asking for a code artifact to be generated."
),
});
const optionallyUpdateModelWithTool = smallModel
.bindTools([
{
name: "optionallyUpdateArtifactMeta",
schema: optionallyUpdateArtifactMetaSchema,
description: "Update the artifact meta information, if necessary.",
},
])
.withConfig({ runName: "optionally_update_artifact_meta" });

const optionallyUpdateArtifactResponse =
await optionallyUpdateModelWithTool.invoke([
{ role: "system", content: optionallyUpdateArtifactMetaPrompt },
recentHumanMessage,
]);
const optionallyUpdateArtifactResponse = await toolCallingModel.invoke([
{ role: "system", content: optionallyUpdateArtifactMetaPrompt },
recentHumanMessage,
]);
const artifactMetaToolCall = optionallyUpdateArtifactResponse.tool_calls?.[0];
const artifactType = artifactMetaToolCall?.args?.type;
const isNewType = artifactType !== currentArtifactContent.type;
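Note that rewriteArtifact now uses two models: the metadata update stays on a hard-coded ChatOpenAI gpt-4o-mini (per the TODO, until Anthropic tool-call streaming is supported), while the rewrite itself goes through the user-selected model via initChatModel. A self-contained sketch of the tool-calling half, reassembled from the hunks above; the invocation messages and the trimmed schema are illustrative only:

import { ChatOpenAI } from "@langchain/openai";
import { z } from "zod";

// Zod schema bound as a tool, so the model emits structured metadata updates
// instead of free-form text (trimmed to two fields for brevity).
const optionallyUpdateArtifactMetaSchema = z.object({
  type: z.enum(["text", "code"]).describe("The type of the artifact content."),
  title: z
    .string()
    .optional()
    .describe("The new title, only if the subject of the artifact changed."),
});

const toolCallingModel = new ChatOpenAI({ model: "gpt-4o-mini", temperature: 0 })
  .bindTools([
    {
      name: "optionallyUpdateArtifactMeta",
      schema: optionallyUpdateArtifactMetaSchema,
      description: "Update the artifact meta information, if necessary.",
    },
  ])
  .withConfig({ runName: "optionally_update_artifact_meta" });

async function demo() {
  const res = await toolCallingModel.invoke([
    { role: "system", content: "Decide whether the artifact metadata should change." },
    { role: "user", content: "Turn this article into a Python script." },
  ]);
  // The tool call, if any, carries the structured metadata update.
  console.log(res.tool_calls?.[0]?.args);
}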
22 changes: 13 additions & 9 deletions src/agent/open-canvas/nodes/rewriteArtifactTheme.ts
@@ -1,24 +1,28 @@
import { ChatOpenAI } from "@langchain/openai";
import { OpenCanvasGraphAnnotation, OpenCanvasGraphReturnType } from "../state";
import { LangGraphRunnableConfig } from "@langchain/langgraph";
import { initChatModel } from "langchain/chat_models/universal";
import { getArtifactContent } from "../../../hooks/use-graph/utils";
import { isArtifactMarkdownContent } from "../../../lib/artifact_content_types";
import { ArtifactV3, Reflections } from "../../../types";
import {
ensureStoreInConfig,
formatReflections,
getModelNameFromConfig,
} from "../../utils";
import {
ADD_EMOJIS_TO_ARTIFACT_PROMPT,
CHANGE_ARTIFACT_LANGUAGE_PROMPT,
CHANGE_ARTIFACT_LENGTH_PROMPT,
CHANGE_ARTIFACT_READING_LEVEL_PROMPT,
CHANGE_ARTIFACT_TO_PIRATE_PROMPT,
} from "../prompts";
import { ensureStoreInConfig, formatReflections } from "../../utils";
import { ArtifactV3, Reflections } from "../../../types";
import { LangGraphRunnableConfig } from "@langchain/langgraph";
import { getArtifactContent } from "../../../hooks/use-graph/utils";
import { isArtifactMarkdownContent } from "../../../lib/artifact_content_types";
import { OpenCanvasGraphAnnotation, OpenCanvasGraphReturnType } from "../state";

export const rewriteArtifactTheme = async (
state: typeof OpenCanvasGraphAnnotation.State,
config: LangGraphRunnableConfig
): Promise<OpenCanvasGraphReturnType> => {
const smallModel = new ChatOpenAI({
model: "gpt-4o-mini",
const modelName = getModelNameFromConfig(config);
const smallModel = await initChatModel(modelName, {
temperature: 0.5,
});

19 changes: 11 additions & 8 deletions src/agent/open-canvas/nodes/rewriteCodeArtifactTheme.ts
@@ -1,20 +1,23 @@
import { ChatOpenAI } from "@langchain/openai";
import { OpenCanvasGraphAnnotation, OpenCanvasGraphReturnType } from "../state";
import { getModelNameFromConfig } from "@/agent/utils";
import { LangGraphRunnableConfig } from "@langchain/langgraph";
import { initChatModel } from "langchain/chat_models/universal";
import { getArtifactContent } from "../../../hooks/use-graph/utils";
import { isArtifactCodeContent } from "../../../lib/artifact_content_types";
import { ArtifactCodeV3, ArtifactV3 } from "../../../types";
import {
ADD_COMMENTS_TO_CODE_ARTIFACT_PROMPT,
ADD_LOGS_TO_CODE_ARTIFACT_PROMPT,
FIX_BUGS_CODE_ARTIFACT_PROMPT,
PORT_LANGUAGE_CODE_ARTIFACT_PROMPT,
} from "../prompts";
import { ArtifactCodeV3, ArtifactV3 } from "../../../types";
import { isArtifactCodeContent } from "../../../lib/artifact_content_types";
import { getArtifactContent } from "../../../hooks/use-graph/utils";
import { OpenCanvasGraphAnnotation, OpenCanvasGraphReturnType } from "../state";

export const rewriteCodeArtifactTheme = async (
state: typeof OpenCanvasGraphAnnotation.State
state: typeof OpenCanvasGraphAnnotation.State,
config: LangGraphRunnableConfig
): Promise<OpenCanvasGraphReturnType> => {
const smallModel = new ChatOpenAI({
model: "gpt-4o-mini",
const modelName = getModelNameFromConfig(config);
const smallModel = await initChatModel(modelName, {
temperature: 0.5,
});

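rewriteCodeArtifactTheme previously ignored the run configuration, so its signature gains a second config parameter; LangGraph passes the RunnableConfig to any node that declares it. A toy sketch of that wiring, separate from the real Open Canvas graph; the state shape and node body are placeholders:

import {
  Annotation,
  END,
  LangGraphRunnableConfig,
  START,
  StateGraph,
} from "@langchain/langgraph";

// Toy state, not the Open Canvas graph annotation.
const ToyState = Annotation.Root({
  output: Annotation<string>,
});

// A node declared as (state, config) receives the run's config, which is how
// the rewritten nodes can read configurable.customModelName.
const rewriteNode = async (
  _state: typeof ToyState.State,
  config: LangGraphRunnableConfig
) => {
  const modelName = config.configurable?.customModelName ?? "(not set)";
  return { output: `would rewrite with ${modelName}` };
};

const graph = new StateGraph(ToyState)
  .addNode("rewrite", rewriteNode)
  .addEdge(START, "rewrite")
  .addEdge("rewrite", END)
  .compile();

async function demo() {
  const result = await graph.invoke(
    { output: "" },
    { configurable: { customModelName: "gpt-4o-mini" } }
  );
  console.log(result.output);
}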
12 changes: 11 additions & 1 deletion src/agent/utils.ts
@@ -1,6 +1,6 @@
import { isArtifactCodeContent } from "@/lib/artifact_content_types";
import { BaseStore, LangGraphRunnableConfig } from "@langchain/langgraph";
import { ArtifactCodeV3, ArtifactMarkdownV3, Reflections } from "../types";
import { isArtifactCodeContent } from "@/lib/artifact_content_types";

export const formatReflections = (
reflections: Reflections,
@@ -79,3 +79,13 @@ export const formatArtifactContentWithTemplate = (
formatArtifactContent(content, shortenContent)
);
};

export const getModelNameFromConfig = (
config: LangGraphRunnableConfig
): string => {
const customModelName = config.configurable?.customModelName as string;
if (!customModelName) {
throw new Error("Model name is missing in config.");
}
return customModelName;
};
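getModelNameFromConfig expects the caller to have placed the selected model under configurable.customModelName. The client side is not part of this section, but a hedged sketch of how a caller might supply that value through the LangGraph SDK could look like this; the API URL and assistant name are placeholders, not values from this PR:

import { Client } from "@langchain/langgraph-sdk";

// Placeholder deployment URL; a real app would read this from its environment.
const client = new Client({ apiUrl: "http://localhost:8123" });

async function runWithCustomModel(message: string) {
  const thread = await client.threads.create();
  // The selected model travels in config.configurable and is read back inside
  // each node via getModelNameFromConfig.
  const stream = client.runs.stream(thread.thread_id, "agent", {
    input: { messages: [{ role: "user", content: message }] },
    config: { configurable: { customModelName: "gpt-4o-mini" } },
  });
  for await (const chunk of stream) {
    console.log(chunk.event, chunk.data);
  }
}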