Skip to content

Commit d519b00

Browse files
committed
feat: address PR feedback - use initChatModel and implement configurable model selection
1 parent 2221239 commit d519b00

15 files changed

+332
-102
lines changed

package.json

+1
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
"dotenv": "^16.4.5",
5858
"framer-motion": "^11.11.9",
5959
"js-cookie": "^3.0.5",
60+
"langchain": "^0.3.4",
6061
"langsmith": "^0.1.61",
6162
"lodash": "^4.17.21",
6263
"lucide-react": "^0.441.0",

src/agent/open-canvas/nodes/customAction.ts

+19-15
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,27 @@
1-
import { ChatOpenAI } from "@langchain/openai";
2-
import { OpenCanvasGraphAnnotation, OpenCanvasGraphReturnType } from "../state";
3-
import {
4-
CUSTOM_QUICK_ACTION_ARTIFACT_CONTENT_PROMPT,
5-
CUSTOM_QUICK_ACTION_ARTIFACT_PROMPT_PREFIX,
6-
CUSTOM_QUICK_ACTION_CONVERSATION_CONTEXT,
7-
REFLECTIONS_QUICK_ACTION_PROMPT,
8-
} from "../prompts";
9-
import { ensureStoreInConfig, formatReflections } from "../../utils";
1+
import { BaseMessage } from "@langchain/core/messages";
2+
import { LangGraphRunnableConfig } from "@langchain/langgraph";
3+
import { initChatModel } from "langchain/chat_models/universal";
4+
import { getArtifactContent } from "../../../hooks/use-graph/utils";
5+
import { isArtifactMarkdownContent } from "../../../lib/artifact_content_types";
106
import {
117
ArtifactCodeV3,
128
ArtifactMarkdownV3,
139
ArtifactV3,
1410
CustomQuickAction,
1511
Reflections,
1612
} from "../../../types";
17-
import { LangGraphRunnableConfig } from "@langchain/langgraph";
18-
import { BaseMessage } from "@langchain/core/messages";
19-
import { getArtifactContent } from "../../../hooks/use-graph/utils";
20-
import { isArtifactMarkdownContent } from "../../../lib/artifact_content_types";
13+
import {
14+
ensureStoreInConfig,
15+
formatReflections,
16+
getModelNameFromConfig,
17+
} from "../../utils";
18+
import {
19+
CUSTOM_QUICK_ACTION_ARTIFACT_CONTENT_PROMPT,
20+
CUSTOM_QUICK_ACTION_ARTIFACT_PROMPT_PREFIX,
21+
CUSTOM_QUICK_ACTION_CONVERSATION_CONTEXT,
22+
REFLECTIONS_QUICK_ACTION_PROMPT,
23+
} from "../prompts";
24+
import { OpenCanvasGraphAnnotation, OpenCanvasGraphReturnType } from "../state";
2125

2226
const formatMessages = (messages: BaseMessage[]): string =>
2327
messages
@@ -35,8 +39,8 @@ export const customAction = async (
3539
throw new Error("No custom quick action ID found.");
3640
}
3741

38-
const smallModel = new ChatOpenAI({
39-
model: "gpt-4o-mini",
42+
const modelName = getModelNameFromConfig(config);
43+
const smallModel = await initChatModel(modelName, {
4044
temperature: 0.5,
4145
});
4246

src/agent/open-canvas/nodes/generateFollowup.ts

+11-7
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
1-
import { ChatOpenAI } from "@langchain/openai";
2-
import { OpenCanvasGraphAnnotation, OpenCanvasGraphReturnType } from "../state";
3-
import { FOLLOWUP_ARTIFACT_PROMPT } from "../prompts";
4-
import { ensureStoreInConfig, formatReflections } from "../../utils";
5-
import { Reflections } from "../../../types";
61
import { LangGraphRunnableConfig } from "@langchain/langgraph";
2+
import { initChatModel } from "langchain/chat_models/universal";
73
import { getArtifactContent } from "../../../hooks/use-graph/utils";
84
import { isArtifactMarkdownContent } from "../../../lib/artifact_content_types";
5+
import { Reflections } from "../../../types";
6+
import {
7+
ensureStoreInConfig,
8+
formatReflections,
9+
getModelNameFromConfig,
10+
} from "../../utils";
11+
import { FOLLOWUP_ARTIFACT_PROMPT } from "../prompts";
12+
import { OpenCanvasGraphAnnotation, OpenCanvasGraphReturnType } from "../state";
913

1014
/**
1115
* Generate a followup message after generating or updating an artifact.
@@ -14,8 +18,8 @@ export const generateFollowup = async (
1418
state: typeof OpenCanvasGraphAnnotation.State,
1519
config: LangGraphRunnableConfig
1620
): Promise<OpenCanvasGraphReturnType> => {
17-
const smallModel = new ChatOpenAI({
18-
model: "gpt-4o-mini",
21+
const modelName = getModelNameFromConfig(config);
22+
const smallModel = await initChatModel(modelName, {
1923
temperature: 0.5,
2024
maxTokens: 250,
2125
});

src/agent/open-canvas/nodes/respondToQuery.ts

+8-7
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
1-
import { ChatOpenAI } from "@langchain/openai";
2-
import { OpenCanvasGraphAnnotation, OpenCanvasGraphReturnType } from "../state";
1+
import { LangGraphRunnableConfig } from "@langchain/langgraph";
2+
import { initChatModel } from "langchain/chat_models/universal";
3+
import { getArtifactContent } from "../../../hooks/use-graph/utils";
4+
import { Reflections } from "../../../types";
35
import {
46
ensureStoreInConfig,
57
formatArtifactContentWithTemplate,
68
formatReflections,
9+
getModelNameFromConfig,
710
} from "../../utils";
8-
import { LangGraphRunnableConfig } from "@langchain/langgraph";
9-
import { Reflections } from "../../../types";
1011
import { CURRENT_ARTIFACT_PROMPT, NO_ARTIFACT_PROMPT } from "../prompts";
11-
import { getArtifactContent } from "../../../hooks/use-graph/utils";
12+
import { OpenCanvasGraphAnnotation, OpenCanvasGraphReturnType } from "../state";
1213

1314
/**
1415
* Generate responses to questions. Does not generate artifacts.
@@ -17,8 +18,8 @@ export const respondToQuery = async (
1718
state: typeof OpenCanvasGraphAnnotation.State,
1819
config: LangGraphRunnableConfig
1920
): Promise<OpenCanvasGraphReturnType> => {
20-
const smallModel = new ChatOpenAI({
21-
model: "gpt-4o-mini",
21+
const modelName = getModelNameFromConfig(config);
22+
const smallModel = await initChatModel(modelName, {
2223
temperature: 0.5,
2324
});
2425

src/agent/open-canvas/nodes/rewriteArtifactTheme.ts

+13-9
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,28 @@
1-
import { ChatOpenAI } from "@langchain/openai";
2-
import { OpenCanvasGraphAnnotation, OpenCanvasGraphReturnType } from "../state";
1+
import { LangGraphRunnableConfig } from "@langchain/langgraph";
2+
import { initChatModel } from "langchain/chat_models/universal";
3+
import { getArtifactContent } from "../../../hooks/use-graph/utils";
4+
import { isArtifactMarkdownContent } from "../../../lib/artifact_content_types";
5+
import { ArtifactV3, Reflections } from "../../../types";
6+
import {
7+
ensureStoreInConfig,
8+
formatReflections,
9+
getModelNameFromConfig,
10+
} from "../../utils";
311
import {
412
ADD_EMOJIS_TO_ARTIFACT_PROMPT,
513
CHANGE_ARTIFACT_LANGUAGE_PROMPT,
614
CHANGE_ARTIFACT_LENGTH_PROMPT,
715
CHANGE_ARTIFACT_READING_LEVEL_PROMPT,
816
CHANGE_ARTIFACT_TO_PIRATE_PROMPT,
917
} from "../prompts";
10-
import { ensureStoreInConfig, formatReflections } from "../../utils";
11-
import { ArtifactV3, Reflections } from "../../../types";
12-
import { LangGraphRunnableConfig } from "@langchain/langgraph";
13-
import { getArtifactContent } from "../../../hooks/use-graph/utils";
14-
import { isArtifactMarkdownContent } from "../../../lib/artifact_content_types";
18+
import { OpenCanvasGraphAnnotation, OpenCanvasGraphReturnType } from "../state";
1519

1620
export const rewriteArtifactTheme = async (
1721
state: typeof OpenCanvasGraphAnnotation.State,
1822
config: LangGraphRunnableConfig
1923
): Promise<OpenCanvasGraphReturnType> => {
20-
const smallModel = new ChatOpenAI({
21-
model: "gpt-4o-mini",
24+
const modelName = getModelNameFromConfig(config);
25+
const smallModel = await initChatModel(modelName, {
2226
temperature: 0.5,
2327
});
2428

src/agent/open-canvas/nodes/rewriteCodeArtifactTheme.ts

+11-8
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,23 @@
1-
import { ChatOpenAI } from "@langchain/openai";
2-
import { OpenCanvasGraphAnnotation, OpenCanvasGraphReturnType } from "../state";
1+
import { getModelNameFromConfig } from "@/agent/utils";
2+
import { LangGraphRunnableConfig } from "@langchain/langgraph";
3+
import { initChatModel } from "langchain/chat_models/universal";
4+
import { getArtifactContent } from "../../../hooks/use-graph/utils";
5+
import { isArtifactCodeContent } from "../../../lib/artifact_content_types";
6+
import { ArtifactCodeV3, ArtifactV3 } from "../../../types";
37
import {
48
ADD_COMMENTS_TO_CODE_ARTIFACT_PROMPT,
59
ADD_LOGS_TO_CODE_ARTIFACT_PROMPT,
610
FIX_BUGS_CODE_ARTIFACT_PROMPT,
711
PORT_LANGUAGE_CODE_ARTIFACT_PROMPT,
812
} from "../prompts";
9-
import { ArtifactCodeV3, ArtifactV3 } from "../../../types";
10-
import { isArtifactCodeContent } from "../../../lib/artifact_content_types";
11-
import { getArtifactContent } from "../../../hooks/use-graph/utils";
13+
import { OpenCanvasGraphAnnotation, OpenCanvasGraphReturnType } from "../state";
1214

1315
export const rewriteCodeArtifactTheme = async (
14-
state: typeof OpenCanvasGraphAnnotation.State
16+
state: typeof OpenCanvasGraphAnnotation.State,
17+
config: LangGraphRunnableConfig
1518
): Promise<OpenCanvasGraphReturnType> => {
16-
const smallModel = new ChatOpenAI({
17-
model: "gpt-4o-mini",
19+
const modelName = getModelNameFromConfig(config);
20+
const smallModel = await initChatModel(modelName, {
1821
temperature: 0.5,
1922
});
2023

src/agent/utils.ts

+11-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1+
import { isArtifactCodeContent } from "@/lib/artifact_content_types";
12
import { BaseStore, LangGraphRunnableConfig } from "@langchain/langgraph";
23
import { ArtifactCodeV3, ArtifactMarkdownV3, Reflections } from "../types";
3-
import { isArtifactCodeContent } from "@/lib/artifact_content_types";
44

55
export const formatReflections = (
66
reflections: Reflections,
@@ -79,3 +79,13 @@ export const formatArtifactContentWithTemplate = (
7979
formatArtifactContent(content, shortenContent)
8080
);
8181
};
82+
83+
export const getModelNameFromConfig = (
84+
config: LangGraphRunnableConfig
85+
): string => {
86+
const customModelName = config.metadata?.customModelName as string;
87+
if (!customModelName) {
88+
throw new Error("Model name is missing in config.");
89+
}
90+
return customModelName;
91+
};

src/components/Canvas.tsx

+14-3
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22

33
import { ArtifactRenderer } from "@/components/artifacts/ArtifactRenderer";
44
import { ContentComposerChatInterface } from "@/components/ContentComposer";
5-
import { useToast } from "@/hooks/use-toast";
5+
import { ALL_MODEL_NAMES } from "@/constants";
66
import { useGraph } from "@/hooks/use-graph/useGraph";
7+
import { useToast } from "@/hooks/use-toast";
78
import { useStore } from "@/hooks/useStore";
89
import { useThread } from "@/hooks/useThread";
910
import { getLanguageTemplate } from "@/lib/get_language_template";
@@ -35,6 +36,8 @@ export function Canvas(props: CanvasProps) {
3536
setThreadId,
3637
getOrCreateAssistant,
3738
clearThreadsWithNoValues,
39+
modelName,
40+
setModelName,
3841
} = useThread(props.user.id);
3942
const [chatStarted, setChatStarted] = useState(false);
4043
const [isEditing, setIsEditing] = useState(false);
@@ -59,6 +62,7 @@ export function Canvas(props: CanvasProps) {
5962
userId: props.user.id,
6063
threadId,
6164
assistantId,
65+
modelName,
6266
});
6367
const {
6468
reflections,
@@ -106,10 +110,12 @@ export function Canvas(props: CanvasProps) {
106110
getReflections();
107111
}, [assistantId]);
108112

109-
const createThreadWithChatStarted = async () => {
113+
const createThreadWithChatStarted = async (
114+
customModelName: ALL_MODEL_NAMES
115+
) => {
110116
setChatStarted(false);
111117
clearState();
112-
return createThread(props.user.id);
118+
return createThread(props.user.id, customModelName);
113119
};
114120

115121
const handleQuickStart = (
@@ -174,6 +180,9 @@ export function Canvas(props: CanvasProps) {
174180
// Chat should only be "started" if there are messages present
175181
if ((thread.values as Record<string, any>)?.messages?.length) {
176182
setChatStarted(true);
183+
setModelName(
184+
thread?.metadata?.customModelName as ALL_MODEL_NAMES
185+
);
177186
} else {
178187
setChatStarted(false);
179188
}
@@ -190,6 +199,8 @@ export function Canvas(props: CanvasProps) {
190199
setChatStarted={setChatStarted}
191200
showNewThreadButton={chatStarted}
192201
handleQuickStart={handleQuickStart}
202+
modelName={modelName}
203+
setModelName={setModelName}
193204
/>
194205
</div>
195206
{chatStarted && (

src/components/ContentComposer.tsx

+17-12
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,31 @@
11
"use client";
22

3-
import React, { useState } from "react";
3+
import { ALL_MODEL_NAMES } from "@/constants";
4+
import { GraphInput } from "@/hooks/use-graph/useGraph";
5+
import { useToast } from "@/hooks/use-toast";
6+
import {
7+
convertLangchainMessages,
8+
convertToOpenAIFormat,
9+
} from "@/lib/convert_messages";
10+
import { ProgrammingLanguageOptions, Reflections } from "@/types";
411
import {
512
AppendMessage,
613
AssistantRuntimeProvider,
14+
useExternalMessageConverter,
715
useExternalStoreRuntime,
816
} from "@assistant-ui/react";
17+
import { BaseMessage, HumanMessage } from "@langchain/core/messages";
18+
import { Thread as ThreadType } from "@langchain/langgraph-sdk";
19+
import React, { useState } from "react";
920
import { v4 as uuidv4 } from "uuid";
1021
import { Thread } from "./Primitives";
11-
import { useExternalMessageConverter } from "@assistant-ui/react";
12-
import { BaseMessage, HumanMessage } from "@langchain/core/messages";
13-
import {
14-
convertLangchainMessages,
15-
convertToOpenAIFormat,
16-
} from "@/lib/convert_messages";
17-
import { GraphInput } from "@/hooks/use-graph/useGraph";
1822
import { Toaster } from "./ui/toaster";
19-
import { ProgrammingLanguageOptions, Reflections } from "@/types";
20-
import { Thread as ThreadType } from "@langchain/langgraph-sdk";
21-
import { useToast } from "@/hooks/use-toast";
2223

2324
export interface ContentComposerChatInterfaceProps {
2425
messages: BaseMessage[];
2526
streamMessage: (input: GraphInput) => Promise<void>;
2627
setMessages: React.Dispatch<React.SetStateAction<BaseMessage[]>>;
27-
createThread: () => Promise<ThreadType | undefined>;
28+
createThread: (modelName: ALL_MODEL_NAMES) => Promise<ThreadType | undefined>;
2829
setChatStarted: React.Dispatch<React.SetStateAction<boolean>>;
2930
showNewThreadButton: boolean;
3031
handleQuickStart: (
@@ -41,6 +42,8 @@ export interface ContentComposerChatInterfaceProps {
4142
deleteThread: (id: string) => Promise<void>;
4243
getUserThreads: (id: string) => Promise<void>;
4344
userId: string;
45+
modelName: ALL_MODEL_NAMES;
46+
setModelName: React.Dispatch<React.SetStateAction<ALL_MODEL_NAMES>>;
4447
}
4548

4649
export function ContentComposerChatInterface(
@@ -106,6 +109,8 @@ export function ContentComposerChatInterface(
106109
userThreads={props.userThreads}
107110
switchSelectedThread={props.switchSelectedThread}
108111
deleteThread={props.deleteThread}
112+
modelName={props.modelName}
113+
setModelName={props.setModelName}
109114
/>
110115
</AssistantRuntimeProvider>
111116
<Toaster />

0 commit comments

Comments
 (0)