diff --git a/.env.example b/.env.example index 0f02317a..d8e99e24 100644 --- a/.env.example +++ b/.env.example @@ -16,6 +16,7 @@ NEXT_PUBLIC_FIREWORKS_ENABLED=true NEXT_PUBLIC_GEMINI_ENABLED=false NEXT_PUBLIC_ANTHROPIC_ENABLED=true NEXT_PUBLIC_OPENAI_ENABLED=true +NEXT_PUBLIC_AZURE_ENABLED=false # LangGraph Deployment, or local development server via LangGraph Studio. # If running locally, this URL should be set in the `constants.ts` file. @@ -25,3 +26,12 @@ NEXT_PUBLIC_OPENAI_ENABLED=true # Public keys NEXT_PUBLIC_SUPABASE_URL= NEXT_PUBLIC_SUPABASE_ANON_KEY= + +# Azure OpenAI Configuration +# ENSURE THEY ARE PREFIXED WITH AN UNDERSCORE. +_AZURE_OPENAI_API_KEY=your-azure-openai-api-key +_AZURE_OPENAI_API_INSTANCE_NAME=your-instance-name +_AZURE_OPENAI_API_DEPLOYMENT_NAME=your-deployment-name +_AZURE_OPENAI_API_VERSION=2024-08-01-preview +# Optional: Azure OpenAI Base Path (if using a different domain) +# _AZURE_OPENAI_API_BASE_PATH=https://your-custom-domain.com/openai/deployments diff --git a/package.json b/package.json index 2188acd7..dbed37ba 100644 --- a/package.json +++ b/package.json @@ -69,8 +69,9 @@ "langsmith": "^0.1.61", "lodash": "^4.17.21", "lucide-react": "^0.441.0", - "next": "14.2.7", + "next": "14.2.10", "react": "^18", + "react-colorful": "^5.6.1", "react-dom": "^18", "react-icons": "^5.3.0", "react-json-view": "^1.21.3", @@ -97,7 +98,7 @@ "@typescript-eslint/eslint-plugin": "^8.12.2", "@typescript-eslint/parser": "^8.8.1", "eslint": "^8", - "eslint-config-next": "14.2.7", + "eslint-config-next": "14.2.10", "postcss": "^8", "prettier": "^3.3.3", "tailwind-scrollbar": "^3.1.0", diff --git a/src/agent/open-canvas/index.ts b/src/agent/open-canvas/index.ts index 2dd9d194..3947ad68 100644 --- a/src/agent/open-canvas/index.ts +++ b/src/agent/open-canvas/index.ts @@ -1,11 +1,11 @@ import { END, Send, START, StateGraph } from "@langchain/langgraph"; import { DEFAULT_INPUTS } from "../../constants"; import { customAction } from "./nodes/customAction"; 
-import { generateArtifact } from "./nodes/generateArtifact"; +import { generateArtifact } from "./nodes/generate-artifact"; import { generateFollowup } from "./nodes/generateFollowup"; import { generatePath } from "./nodes/generatePath"; import { reflectNode } from "./nodes/reflect"; -import { rewriteArtifact } from "./nodes/rewriteArtifact"; +import { rewriteArtifact } from "./nodes/rewrite-artifact"; import { rewriteArtifactTheme } from "./nodes/rewriteArtifactTheme"; import { updateArtifact } from "./nodes/updateArtifact"; import { replyToGeneralInput } from "./nodes/replyToGeneralInput"; diff --git a/src/agent/open-canvas/nodes/customAction.ts b/src/agent/open-canvas/nodes/customAction.ts index 052961d1..31e2a610 100644 --- a/src/agent/open-canvas/nodes/customAction.ts +++ b/src/agent/open-canvas/nodes/customAction.ts @@ -1,6 +1,6 @@ import { BaseMessage } from "@langchain/core/messages"; import { LangGraphRunnableConfig } from "@langchain/langgraph"; -import { initChatModel } from "langchain/chat_models/universal"; +import { getModelFromConfig } from "../../utils"; import { getArtifactContent } from "../../../contexts/utils"; import { isArtifactMarkdownContent } from "../../../lib/artifact_content_types"; import { @@ -10,11 +10,7 @@ import { CustomQuickAction, Reflections, } from "../../../types"; -import { - ensureStoreInConfig, - formatReflections, - getModelNameAndProviderFromConfig, -} from "../../utils"; +import { ensureStoreInConfig, formatReflections } from "../../utils"; import { CUSTOM_QUICK_ACTION_ARTIFACT_CONTENT_PROMPT, CUSTOM_QUICK_ACTION_ARTIFACT_PROMPT_PREFIX, @@ -39,11 +35,8 @@ export const customAction = async ( throw new Error("No custom quick action ID found."); } - const { modelName, modelProvider } = - getModelNameAndProviderFromConfig(config); - const smallModel = await initChatModel(modelName, { + const smallModel = await getModelFromConfig(config, { temperature: 0.5, - modelProvider, }); const store = ensureStoreInConfig(config); diff 
--git a/src/agent/open-canvas/nodes/generate-artifact/index.ts b/src/agent/open-canvas/nodes/generate-artifact/index.ts new file mode 100644 index 00000000..4a477ec2 --- /dev/null +++ b/src/agent/open-canvas/nodes/generate-artifact/index.ts @@ -0,0 +1,63 @@ +import { + OpenCanvasGraphAnnotation, + OpenCanvasGraphReturnType, +} from "../../state"; +import { LangGraphRunnableConfig } from "@langchain/langgraph"; +import { + getFormattedReflections, + getModelFromConfig, + getModelConfig, + optionallyGetSystemPromptFromConfig, +} from "@/agent/utils"; +import { ARTIFACT_TOOL_SCHEMA } from "./schemas"; +import { ArtifactV3 } from "@/types"; +import { createArtifactContent, formatNewArtifactPrompt } from "./utils"; + +/** + * Generate a new artifact based on the user's query. + */ +export const generateArtifact = async ( + state: typeof OpenCanvasGraphAnnotation.State, + config: LangGraphRunnableConfig +): Promise => { + const { modelName } = getModelConfig(config); + const smallModel = await getModelFromConfig(config, { + temperature: 0.5, + }); + + const modelWithArtifactTool = smallModel.bindTools( + [ + { + name: "generate_artifact", + schema: ARTIFACT_TOOL_SCHEMA, + }, + ], + { tool_choice: "generate_artifact" } + ); + + const memoriesAsString = await getFormattedReflections(config); + const formattedNewArtifactPrompt = formatNewArtifactPrompt( + memoriesAsString, + modelName + ); + + const userSystemPrompt = optionallyGetSystemPromptFromConfig(config); + const fullSystemPrompt = userSystemPrompt + ? 
`${userSystemPrompt}\n${formattedNewArtifactPrompt}` + : formattedNewArtifactPrompt; + + const response = await modelWithArtifactTool.invoke( + [{ role: "system", content: fullSystemPrompt }, ...state.messages], + { runName: "generate_artifact" } + ); + + const newArtifactContent = createArtifactContent(response.tool_calls?.[0]); + const newArtifact: ArtifactV3 = { + currentIndex: 1, + contents: [newArtifactContent], + }; + + return { + artifact: newArtifact, + }; +}; diff --git a/src/agent/open-canvas/nodes/generate-artifact/schemas.ts b/src/agent/open-canvas/nodes/generate-artifact/schemas.ts new file mode 100644 index 00000000..2e2d4a4f --- /dev/null +++ b/src/agent/open-canvas/nodes/generate-artifact/schemas.ts @@ -0,0 +1,33 @@ +import { PROGRAMMING_LANGUAGES } from "@/types"; +import { z } from "zod"; + +export const ARTIFACT_TOOL_SCHEMA = z.object({ + type: z + .enum(["code", "text"]) + .describe("The content type of the artifact generated."), + language: z + .enum( + PROGRAMMING_LANGUAGES.map((lang) => lang.language) as [ + string, + ...string[], + ] + ) + .optional() + .describe( + "The language/programming language of the artifact generated.\n" + + "If generating code, it should be one of the options, or 'other'.\n" + + "If not generating code, the language should ALWAYS be 'other'." + ), + isValidReact: z + .boolean() + .optional() + .describe( + "Whether or not the generated code is valid React code. Only populate this field if generating code." + ), + artifact: z.string().describe("The content of the artifact to generate."), + title: z + .string() + .describe( + "A short title to give to the artifact. Should be less than 5 words." 
+ ), +}); diff --git a/src/agent/open-canvas/nodes/generate-artifact/utils.ts b/src/agent/open-canvas/nodes/generate-artifact/utils.ts new file mode 100644 index 00000000..4e6aca8e --- /dev/null +++ b/src/agent/open-canvas/nodes/generate-artifact/utils.ts @@ -0,0 +1,39 @@ +import { NEW_ARTIFACT_PROMPT } from "../../prompts"; +import { ArtifactCodeV3, ArtifactMarkdownV3 } from "@/types"; +import { ToolCall } from "@langchain/core/messages/tool"; + +export const formatNewArtifactPrompt = ( + memoriesAsString: string, + modelName: string +): string => { + return NEW_ARTIFACT_PROMPT.replace("{reflections}", memoriesAsString).replace( + "{disableChainOfThought}", + modelName.includes("claude") + ? "\n\nIMPORTANT: Do NOT perform chain of thought beforehand. Instead, go STRAIGHT to generating the tool response. This is VERY important." + : "" + ); +}; + +export const createArtifactContent = ( + toolCall: ToolCall | undefined +): ArtifactCodeV3 | ArtifactMarkdownV3 => { + const toolArgs = toolCall?.args; + const artifactType = toolArgs?.type; + + if (artifactType === "code") { + return { + index: 1, + type: "code", + title: toolArgs?.title, + code: toolArgs?.artifact, + language: toolArgs?.language, + }; + } + + return { + index: 1, + type: "text", + title: toolArgs?.title, + fullMarkdown: toolArgs?.artifact, + }; +}; diff --git a/src/agent/open-canvas/nodes/generateArtifact.ts b/src/agent/open-canvas/nodes/generateArtifact.ts deleted file mode 100644 index bd61bd46..00000000 --- a/src/agent/open-canvas/nodes/generateArtifact.ts +++ /dev/null @@ -1,131 +0,0 @@ -import { OpenCanvasGraphAnnotation, OpenCanvasGraphReturnType } from "../state"; -import { NEW_ARTIFACT_PROMPT } from "../prompts"; -import { - ArtifactCodeV3, - ArtifactMarkdownV3, - ArtifactV3, - PROGRAMMING_LANGUAGES, - Reflections, -} from "../../../types"; -import { z } from "zod"; -import { - ensureStoreInConfig, - formatReflections, - getModelNameAndProviderFromConfig, -} from "../../utils"; -import {
LangGraphRunnableConfig } from "@langchain/langgraph"; -import { initChatModel } from "langchain/chat_models/universal"; - -/** - * Generate a new artifact based on the user's query. - */ -export const generateArtifact = async ( - state: typeof OpenCanvasGraphAnnotation.State, - config: LangGraphRunnableConfig -): Promise => { - const { modelName, modelProvider } = - getModelNameAndProviderFromConfig(config); - const smallModel = await initChatModel(modelName, { - temperature: 0.5, - modelProvider, - }); - - const store = ensureStoreInConfig(config); - const assistantId = config.configurable?.assistant_id; - if (!assistantId) { - throw new Error("`assistant_id` not found in configurable"); - } - const memoryNamespace = ["memories", assistantId]; - const memoryKey = "reflection"; - const memories = await store.get(memoryNamespace, memoryKey); - const memoriesAsString = memories?.value - ? formatReflections(memories.value as Reflections) - : "No reflections found."; - - const modelWithArtifactTool = smallModel.bindTools( - [ - { - name: "generate_artifact", - schema: z.object({ - type: z - .enum(["code", "text"]) - .describe("The content type of the artifact generated."), - language: z - .enum( - PROGRAMMING_LANGUAGES.map((lang) => lang.language) as [ - string, - ...string[], - ] - ) - .optional() - .describe( - "The language/programming language of the artifact generated.\n" + - "If generating code, it should be one of the options, or 'other'.\n" + - "If not generating code, the language should ALWAYS be 'other'." - ), - isValidReact: z - .boolean() - .optional() - .describe( - "Whether or not the generated code is valid React code. Only populate this field if generating code." - ), - artifact: z - .string() - .describe("The content of the artifact to generate."), - title: z - .string() - .describe( - "A short title to give to the artifact. Should be less than 5 words." 
- ), - }), - }, - ], - { tool_choice: "generate_artifact" } - ); - - const formattedNewArtifactPrompt = NEW_ARTIFACT_PROMPT.replace( - "{reflections}", - memoriesAsString - ).replace( - "{disableChainOfThought}", - modelName.includes("claude") - ? "\n\nIMPORTANT: Do NOT preform chain of thought beforehand. Instead, go STRAIGHT to generating the tool response. This is VERY important." - : "" - ); - - const response = await modelWithArtifactTool.invoke( - [ - { role: "system", content: formattedNewArtifactPrompt }, - ...state.messages, - ], - { runName: "generate_artifact" } - ); - - const newArtifactType = response.tool_calls?.[0]?.args.type; - let newArtifactContent: ArtifactCodeV3 | ArtifactMarkdownV3; - if (newArtifactType === "code") { - newArtifactContent = { - index: 1, - type: "code", - title: response.tool_calls?.[0]?.args.title, - code: response.tool_calls?.[0]?.args.artifact, - language: response.tool_calls?.[0]?.args.language, - }; - } else { - newArtifactContent = { - index: 1, - type: "text", - title: response.tool_calls?.[0]?.args.title, - fullMarkdown: response.tool_calls?.[0]?.args.artifact, - }; - } - - const newArtifact: ArtifactV3 = { - currentIndex: 1, - contents: [newArtifactContent], - }; - - return { - artifact: newArtifact, - }; -}; diff --git a/src/agent/open-canvas/nodes/generateFollowup.ts b/src/agent/open-canvas/nodes/generateFollowup.ts index 63a9d36f..407b3fdf 100644 --- a/src/agent/open-canvas/nodes/generateFollowup.ts +++ b/src/agent/open-canvas/nodes/generateFollowup.ts @@ -1,13 +1,9 @@ import { LangGraphRunnableConfig } from "@langchain/langgraph"; -import { initChatModel } from "langchain/chat_models/universal"; +import { getModelFromConfig } from "../../utils"; import { getArtifactContent } from "../../../contexts/utils"; import { isArtifactMarkdownContent } from "../../../lib/artifact_content_types"; import { Reflections } from "../../../types"; -import { - ensureStoreInConfig, - formatReflections, - 
getModelNameAndProviderFromConfig, -} from "../../utils"; +import { ensureStoreInConfig, formatReflections } from "../../utils"; import { FOLLOWUP_ARTIFACT_PROMPT } from "../prompts"; import { OpenCanvasGraphAnnotation, OpenCanvasGraphReturnType } from "../state"; @@ -18,12 +14,8 @@ export const generateFollowup = async ( state: typeof OpenCanvasGraphAnnotation.State, config: LangGraphRunnableConfig ): Promise => { - const { modelName, modelProvider } = - getModelNameAndProviderFromConfig(config); - const smallModel = await initChatModel(modelName, { - temperature: 0.5, + const smallModel = await getModelFromConfig(config, { maxTokens: 250, - modelProvider, }); const store = ensureStoreInConfig(config); diff --git a/src/agent/open-canvas/nodes/generatePath.ts b/src/agent/open-canvas/nodes/generatePath.ts index 03a0350b..9ee358ed 100644 --- a/src/agent/open-canvas/nodes/generatePath.ts +++ b/src/agent/open-canvas/nodes/generatePath.ts @@ -7,12 +7,9 @@ import { } from "../prompts"; import { OpenCanvasGraphAnnotation } from "../state"; import { z } from "zod"; -import { - formatArtifactContentWithTemplate, - getModelNameAndProviderFromConfig, -} from "../../utils"; +import { formatArtifactContentWithTemplate } from "../../utils"; import { getArtifactContent } from "../../../contexts/utils"; -import { initChatModel } from "langchain/chat_models/universal"; +import { getModelFromConfig } from "../../utils"; import { LangGraphRunnableConfig } from "@langchain/langgraph"; /** @@ -22,6 +19,6 @@ export const generatePath = async ( state: typeof OpenCanvasGraphAnnotation.State, config: LangGraphRunnableConfig ) => { if (state.highlightedCode) { return { next: "updateArtifact", @@ -93,11 +90,8 @@ export const generatePath = async ( ?
"rewriteArtifact" : "generateArtifact"; - const { modelName, modelProvider } = - getModelNameAndProviderFromConfig(config); - const model = await initChatModel(modelName, { + const model = await getModelFromConfig(config, { temperature: 0, - modelProvider, }); const modelWithTool = model.withStructuredOutput( z.object({ diff --git a/src/agent/open-canvas/nodes/replyToGeneralInput.ts b/src/agent/open-canvas/nodes/replyToGeneralInput.ts index 5cd71d2f..ad9ebd19 100644 --- a/src/agent/open-canvas/nodes/replyToGeneralInput.ts +++ b/src/agent/open-canvas/nodes/replyToGeneralInput.ts @@ -1,12 +1,11 @@ import { LangGraphRunnableConfig } from "@langchain/langgraph"; -import { initChatModel } from "langchain/chat_models/universal"; +import { getModelFromConfig } from "../../utils"; import { getArtifactContent } from "../../../contexts/utils"; import { Reflections } from "../../../types"; import { ensureStoreInConfig, formatArtifactContentWithTemplate, formatReflections, - getModelNameAndProviderFromConfig, } from "../../utils"; import { CURRENT_ARTIFACT_PROMPT, NO_ARTIFACT_PROMPT } from "../prompts"; import { OpenCanvasGraphAnnotation, OpenCanvasGraphReturnType } from "../state"; @@ -18,12 +17,7 @@ export const replyToGeneralInput = async ( state: typeof OpenCanvasGraphAnnotation.State, config: LangGraphRunnableConfig ): Promise => { - const { modelName, modelProvider } = - getModelNameAndProviderFromConfig(config); - const smallModel = await initChatModel(modelName, { - temperature: 0.5, - modelProvider, - }); + const smallModel = await getModelFromConfig(config); const prompt = `You are an AI assistant tasked with responding to the users question. 
diff --git a/src/agent/open-canvas/nodes/rewrite-artifact/index.ts b/src/agent/open-canvas/nodes/rewrite-artifact/index.ts new file mode 100644 index 00000000..7615366d --- /dev/null +++ b/src/agent/open-canvas/nodes/rewrite-artifact/index.ts @@ -0,0 +1,68 @@ +import { + OpenCanvasGraphAnnotation, + OpenCanvasGraphReturnType, +} from "../../state"; +import { LangGraphRunnableConfig } from "@langchain/langgraph"; +import { optionallyUpdateArtifactMeta } from "./update-meta"; +import { buildPrompt, createNewArtifactContent, validateState } from "./utils"; +import { + getFormattedReflections, + getModelFromConfig, + optionallyGetSystemPromptFromConfig, +} from "@/agent/utils"; +import { isArtifactMarkdownContent } from "@/lib/artifact_content_types"; + +export const rewriteArtifact = async ( + state: typeof OpenCanvasGraphAnnotation.State, + config: LangGraphRunnableConfig +): Promise => { + const smallModelWithConfig = (await getModelFromConfig(config)).withConfig({ + runName: "rewrite_artifact_model_call", + }); + const memoriesAsString = await getFormattedReflections(config); + const { currentArtifactContent, recentHumanMessage } = validateState(state); + + const artifactMetaToolCall = await optionallyUpdateArtifactMeta( + state, + config + ); + const artifactType = artifactMetaToolCall?.args?.type; + const isNewType = artifactType !== currentArtifactContent.type; + + const artifactContent = isArtifactMarkdownContent(currentArtifactContent) + ? currentArtifactContent.fullMarkdown + : currentArtifactContent.code; + + const formattedPrompt = buildPrompt({ + artifactContent, + memoriesAsString, + isNewType, + artifactMetaToolCall, + }); + + const userSystemPrompt = optionallyGetSystemPromptFromConfig(config); + const fullSystemPrompt = userSystemPrompt + ? 
`${userSystemPrompt}\n${formattedPrompt}` + : formattedPrompt; + + const newArtifactResponse = await smallModelWithConfig.invoke([ + { role: "system", content: fullSystemPrompt }, + recentHumanMessage, + ]); + + const newArtifactContent = createNewArtifactContent({ + artifactType, + state, + currentArtifactContent, + artifactMetaToolCall, + newContent: newArtifactResponse.content as string, + }); + + return { + artifact: { + ...state.artifact, + currentIndex: state.artifact.contents.length + 1, + contents: [...state.artifact.contents, newArtifactContent], + }, + }; +}; diff --git a/src/agent/open-canvas/nodes/rewrite-artifact/schemas.ts b/src/agent/open-canvas/nodes/rewrite-artifact/schemas.ts new file mode 100644 index 00000000..080e704a --- /dev/null +++ b/src/agent/open-canvas/nodes/rewrite-artifact/schemas.ts @@ -0,0 +1,22 @@ +import { PROGRAMMING_LANGUAGES } from "@/types"; +import { z } from "zod"; + +export const OPTIONALLY_UPDATE_ARTIFACT_META_SCHEMA = z.object({ + type: z.enum(["text", "code"]).describe("The type of the artifact content."), + title: z + .string() + .optional() + .describe( + "The new title to give the artifact. ONLY update this if the user is making a request which changes the subject/topic of the artifact." + ), + language: z + .enum( + PROGRAMMING_LANGUAGES.map((lang) => lang.language) as [ + string, + ...string[], + ] + ) + .describe( + "The language of the code artifact. This should be populated with the programming language if the user is requesting code to be written, or 'other', in all other cases." 
+ ), +}); diff --git a/src/agent/open-canvas/nodes/rewrite-artifact/update-meta.ts b/src/agent/open-canvas/nodes/rewrite-artifact/update-meta.ts new file mode 100644 index 00000000..e0e2b11c --- /dev/null +++ b/src/agent/open-canvas/nodes/rewrite-artifact/update-meta.ts @@ -0,0 +1,55 @@ +import { LangGraphRunnableConfig } from "@langchain/langgraph"; +import { OpenCanvasGraphAnnotation } from "../../state"; +import { formatArtifactContent, getModelFromConfig } from "@/agent/utils"; +import { getArtifactContent } from "@/contexts/utils"; +import { GET_TITLE_TYPE_REWRITE_ARTIFACT } from "../../prompts"; +import { OPTIONALLY_UPDATE_ARTIFACT_META_SCHEMA } from "./schemas"; +import { ToolCall } from "@langchain/core/messages/tool"; +import { getFormattedReflections } from "../../../utils"; + +export async function optionallyUpdateArtifactMeta( + state: typeof OpenCanvasGraphAnnotation.State, + config: LangGraphRunnableConfig +): Promise { + const toolCallingModel = (await getModelFromConfig(config)) + .bindTools( + [ + { + name: "optionallyUpdateArtifactMeta", + schema: OPTIONALLY_UPDATE_ARTIFACT_META_SCHEMA, + description: "Update the artifact meta information, if necessary.", + }, + ], + { tool_choice: "optionallyUpdateArtifactMeta" } + ) + .withConfig({ runName: "optionally_update_artifact_meta" }); + + const memoriesAsString = await getFormattedReflections(config); + + const currentArtifactContent = state.artifact + ? 
getArtifactContent(state.artifact) + : undefined; + if (!currentArtifactContent) { + throw new Error("No artifact found"); + } + + const optionallyUpdateArtifactMetaPrompt = + GET_TITLE_TYPE_REWRITE_ARTIFACT.replace( + "{artifact}", + formatArtifactContent(currentArtifactContent, true) + ).replace("{reflections}", memoriesAsString); + + const recentHumanMessage = state.messages.findLast( + (message) => message.getType() === "human" + ); + if (!recentHumanMessage) { + throw new Error("No recent human message found"); + } + + const optionallyUpdateArtifactResponse = await toolCallingModel.invoke([ + { role: "system", content: optionallyUpdateArtifactMetaPrompt }, + recentHumanMessage, + ]); + + return optionallyUpdateArtifactResponse.tool_calls?.[0]; +} diff --git a/src/agent/open-canvas/nodes/rewrite-artifact/utils.ts b/src/agent/open-canvas/nodes/rewrite-artifact/utils.ts new file mode 100644 index 00000000..3dbac450 --- /dev/null +++ b/src/agent/open-canvas/nodes/rewrite-artifact/utils.ts @@ -0,0 +1,110 @@ +import { getArtifactContent } from "@/contexts/utils"; +import { isArtifactCodeContent } from "@/lib/artifact_content_types"; +import { ArtifactCodeV3, ArtifactMarkdownV3 } from "@/types"; +import { + OPTIONALLY_UPDATE_META_PROMPT, + UPDATE_ENTIRE_ARTIFACT_PROMPT, +} from "../../prompts"; +import { OpenCanvasGraphAnnotation } from "../../state"; +import { ToolCall } from "@langchain/core/messages/tool"; + +export const validateState = ( + state: typeof OpenCanvasGraphAnnotation.State +) => { + const currentArtifactContent = state.artifact + ? 
getArtifactContent(state.artifact) + : undefined; + if (!currentArtifactContent) { + throw new Error("No artifact found"); + } + + const recentHumanMessage = state.messages.findLast( + (message) => message.getType() === "human" + ); + if (!recentHumanMessage) { + throw new Error("No recent human message found"); + } + + return { currentArtifactContent, recentHumanMessage }; +}; + +const buildMetaPrompt = (artifactMetaToolCall: ToolCall | undefined) => { + const titleSection = + artifactMetaToolCall?.args?.title && + artifactMetaToolCall?.args?.type !== "code" + ? `And its title is (do NOT include this in your response):\n${artifactMetaToolCall.args.title}` + : ""; + + return OPTIONALLY_UPDATE_META_PROMPT.replace( + "{artifactType}", + artifactMetaToolCall?.args?.type + ).replace("{artifactTitle}", titleSection); +}; + +interface BuildPromptArgs { + artifactContent: string; + memoriesAsString: string; + isNewType: boolean; + artifactMetaToolCall: ToolCall | undefined; +} + +export const buildPrompt = ({ + artifactContent, + memoriesAsString, + isNewType, + artifactMetaToolCall, +}: BuildPromptArgs) => { + const metaPrompt = isNewType ? 
buildMetaPrompt(artifactMetaToolCall) : ""; + + return UPDATE_ENTIRE_ARTIFACT_PROMPT.replace( + "{artifactContent}", + artifactContent + ) + .replace("{reflections}", memoriesAsString) + .replace("{updateMetaPrompt}", metaPrompt); +}; + +interface CreateNewArtifactContentArgs { + artifactType: string; + state: typeof OpenCanvasGraphAnnotation.State; + currentArtifactContent: ArtifactCodeV3 | ArtifactMarkdownV3; + artifactMetaToolCall: ToolCall | undefined; + newContent: string; +} + +export const createNewArtifactContent = ({ + artifactType, + state, + currentArtifactContent, + artifactMetaToolCall, + newContent, +}: CreateNewArtifactContentArgs): ArtifactCodeV3 | ArtifactMarkdownV3 => { + const baseContent = { + index: state.artifact.contents.length + 1, + title: artifactMetaToolCall?.args?.title || currentArtifactContent.title, + }; + + if (artifactType === "code") { + return { + ...baseContent, + type: "code", + language: getLanguage(artifactMetaToolCall, currentArtifactContent), + code: newContent, + }; + } + + return { + ...baseContent, + type: "text", + fullMarkdown: newContent, + }; +}; + +const getLanguage = ( + artifactMetaToolCall: ToolCall | undefined, + currentArtifactContent: ArtifactCodeV3 | ArtifactMarkdownV3 +) => + artifactMetaToolCall?.args?.language || + (isArtifactCodeContent(currentArtifactContent) + ?
currentArtifactContent.language + : "other"); diff --git a/src/agent/open-canvas/nodes/rewriteArtifact.ts b/src/agent/open-canvas/nodes/rewriteArtifact.ts deleted file mode 100644 index 067a522f..00000000 --- a/src/agent/open-canvas/nodes/rewriteArtifact.ts +++ /dev/null @@ -1,184 +0,0 @@ -import { OpenCanvasGraphAnnotation, OpenCanvasGraphReturnType } from "../state"; -import { - GET_TITLE_TYPE_REWRITE_ARTIFACT, - OPTIONALLY_UPDATE_META_PROMPT, - UPDATE_ENTIRE_ARTIFACT_PROMPT, -} from "../prompts"; -import { - ensureStoreInConfig, - formatArtifactContent, - formatReflections, - getModelNameAndProviderFromConfig, -} from "../../utils"; -import { - ArtifactCodeV3, - ArtifactMarkdownV3, - ArtifactV3, - PROGRAMMING_LANGUAGES, - Reflections, -} from "../../../types"; -import { LangGraphRunnableConfig } from "@langchain/langgraph"; -import { z } from "zod"; -import { getArtifactContent } from "../../../contexts/utils"; -import { - isArtifactCodeContent, - isArtifactMarkdownContent, -} from "../../../lib/artifact_content_types"; -import { initChatModel } from "langchain/chat_models/universal"; - -export const rewriteArtifact = async ( - state: typeof OpenCanvasGraphAnnotation.State, - config: LangGraphRunnableConfig -): Promise => { - const optionallyUpdateArtifactMetaSchema = z.object({ - type: z - .enum(["text", "code"]) - .describe("The type of the artifact content."), - title: z - .string() - .optional() - .describe( - "The new title to give the artifact. ONLY update this if the user is making a request which changes the subject/topic of the artifact." - ), - language: z - .enum( - PROGRAMMING_LANGUAGES.map((lang) => lang.language) as [ - string, - ...string[], - ] - ) - .describe( - "The language of the code artifact. This should be populated with the programming language if the user is requesting code to be written, or 'other', in all other cases." 
- ), - }); - const { modelName, modelProvider } = - getModelNameAndProviderFromConfig(config); - const toolCallingModel = ( - await initChatModel(modelName, { - temperature: 0, - modelProvider, - }) - ) - .bindTools( - [ - { - name: "optionallyUpdateArtifactMeta", - schema: optionallyUpdateArtifactMetaSchema, - description: "Update the artifact meta information, if necessary.", - }, - ], - { tool_choice: "optionallyUpdateArtifactMeta" } - ) - .withConfig({ runName: "optionally_update_artifact_meta" }); - - const smallModelWithConfig = ( - await initChatModel(modelName, { - temperature: 0, - modelProvider, - }) - ).withConfig({ - runName: "rewrite_artifact_model_call", - }); - - const store = ensureStoreInConfig(config); - const assistantId = config.configurable?.assistant_id; - if (!assistantId) { - throw new Error("`assistant_id` not found in configurable"); - } - const memoryNamespace = ["memories", assistantId]; - const memoryKey = "reflection"; - const memories = await store.get(memoryNamespace, memoryKey); - const memoriesAsString = memories?.value - ? formatReflections(memories.value as Reflections) - : "No reflections found."; - - const currentArtifactContent = state.artifact - ? 
getArtifactContent(state.artifact) - : undefined; - if (!currentArtifactContent) { - throw new Error("No artifact found"); - } - - const optionallyUpdateArtifactMetaPrompt = - GET_TITLE_TYPE_REWRITE_ARTIFACT.replace( - "{artifact}", - formatArtifactContent(currentArtifactContent, true) - ).replace("{reflections}", memoriesAsString); - - const recentHumanMessage = state.messages.findLast( - (message) => message.getType() === "human" - ); - if (!recentHumanMessage) { - throw new Error("No recent human message found"); - } - - const optionallyUpdateArtifactResponse = await toolCallingModel.invoke([ - { role: "system", content: optionallyUpdateArtifactMetaPrompt }, - recentHumanMessage, - ]); - const artifactMetaToolCall = optionallyUpdateArtifactResponse.tool_calls?.[0]; - const artifactType = artifactMetaToolCall?.args?.type; - const isNewType = artifactType !== currentArtifactContent.type; - - const artifactContent = isArtifactMarkdownContent(currentArtifactContent) - ? currentArtifactContent.fullMarkdown - : currentArtifactContent.code; - - const formattedPrompt = UPDATE_ENTIRE_ARTIFACT_PROMPT.replace( - "{artifactContent}", - artifactContent - ) - .replace("{reflections}", memoriesAsString) - .replace( - "{updateMetaPrompt}", - isNewType - ? OPTIONALLY_UPDATE_META_PROMPT.replace( - "{artifactType}", - artifactMetaToolCall?.args?.type - ).replace( - "{artifactTitle}", - artifactMetaToolCall?.args?.title && - artifactMetaToolCall?.args?.type !== "code" - ? 
`And its title is (do NOT include this in your response):\n${artifactMetaToolCall?.args?.title}` - : "" - ) - : "" - ); - - const newArtifactResponse = await smallModelWithConfig.invoke([ - { role: "system", content: formattedPrompt }, - recentHumanMessage, - ]); - - let newArtifactContent: ArtifactCodeV3 | ArtifactMarkdownV3; - if (artifactType === "code") { - newArtifactContent = { - index: state.artifact.contents.length + 1, - type: "code", - title: artifactMetaToolCall?.args?.title || currentArtifactContent.title, - language: - artifactMetaToolCall?.args?.programmingLanguage || - (isArtifactCodeContent(currentArtifactContent) - ? currentArtifactContent.language - : "other"), - code: newArtifactResponse.content as string, - }; - } else { - newArtifactContent = { - index: state.artifact.contents.length + 1, - type: "text", - title: artifactMetaToolCall?.args?.title || currentArtifactContent.title, - fullMarkdown: newArtifactResponse.content as string, - }; - } - - const newArtifact: ArtifactV3 = { - ...state.artifact, - currentIndex: state.artifact.contents.length + 1, - contents: [...state.artifact.contents, newArtifactContent], - }; - - return { - artifact: newArtifact, - }; -}; diff --git a/src/agent/open-canvas/nodes/rewriteArtifactTheme.ts b/src/agent/open-canvas/nodes/rewriteArtifactTheme.ts index 0d5c73e5..3bb39e44 100644 --- a/src/agent/open-canvas/nodes/rewriteArtifactTheme.ts +++ b/src/agent/open-canvas/nodes/rewriteArtifactTheme.ts @@ -1,13 +1,9 @@ import { LangGraphRunnableConfig } from "@langchain/langgraph"; -import { initChatModel } from "langchain/chat_models/universal"; +import { getModelFromConfig } from "../../utils"; import { getArtifactContent } from "../../../contexts/utils"; import { isArtifactMarkdownContent } from "../../../lib/artifact_content_types"; import { ArtifactV3, Reflections } from "../../../types"; -import { - ensureStoreInConfig, - formatReflections, - getModelNameAndProviderFromConfig, -} from "../../utils"; +import { 
ensureStoreInConfig, formatReflections } from "../../utils"; import { ADD_EMOJIS_TO_ARTIFACT_PROMPT, CHANGE_ARTIFACT_LANGUAGE_PROMPT, @@ -21,12 +17,7 @@ export const rewriteArtifactTheme = async ( state: typeof OpenCanvasGraphAnnotation.State, config: LangGraphRunnableConfig ): Promise => { - const { modelName, modelProvider } = - getModelNameAndProviderFromConfig(config); - const smallModel = await initChatModel(modelName, { - temperature: 0.5, - modelProvider, - }); + const smallModel = await getModelFromConfig(config); const store = ensureStoreInConfig(config); const assistantId = config.configurable?.assistant_id; diff --git a/src/agent/open-canvas/nodes/rewriteCodeArtifactTheme.ts b/src/agent/open-canvas/nodes/rewriteCodeArtifactTheme.ts index b1203c82..d8fe34e7 100644 --- a/src/agent/open-canvas/nodes/rewriteCodeArtifactTheme.ts +++ b/src/agent/open-canvas/nodes/rewriteCodeArtifactTheme.ts @@ -1,6 +1,5 @@ -import { getModelNameAndProviderFromConfig } from "@/agent/utils"; import { LangGraphRunnableConfig } from "@langchain/langgraph"; -import { initChatModel } from "langchain/chat_models/universal"; +import { getModelFromConfig } from "../../utils"; import { getArtifactContent } from "../../../contexts/utils"; import { isArtifactCodeContent } from "../../../lib/artifact_content_types"; import { ArtifactCodeV3, ArtifactV3 } from "../../../types"; @@ -16,12 +15,7 @@ export const rewriteCodeArtifactTheme = async ( state: typeof OpenCanvasGraphAnnotation.State, config: LangGraphRunnableConfig ): Promise => { - const { modelName, modelProvider } = - getModelNameAndProviderFromConfig(config); - const smallModel = await initChatModel(modelName, { - temperature: 0.5, - modelProvider, - }); + const smallModel = await getModelFromConfig(config); const currentArtifactContent = state.artifact ? 
getArtifactContent(state.artifact) diff --git a/src/agent/open-canvas/nodes/updateArtifact.ts b/src/agent/open-canvas/nodes/updateArtifact.ts index 74f97b60..918b705b 100644 --- a/src/agent/open-canvas/nodes/updateArtifact.ts +++ b/src/agent/open-canvas/nodes/updateArtifact.ts @@ -1,7 +1,11 @@ -import { ChatOpenAI } from "@langchain/openai"; import { OpenCanvasGraphAnnotation, OpenCanvasGraphReturnType } from "../state"; import { UPDATE_HIGHLIGHTED_ARTIFACT_PROMPT } from "../prompts"; -import { ensureStoreInConfig, formatReflections } from "../../utils"; +import { + ensureStoreInConfig, + formatReflections, + getModelConfig, + getModelFromConfig, +} from "../../utils"; import { ArtifactCodeV3, ArtifactV3, Reflections } from "../../../types"; import { LangGraphRunnableConfig } from "@langchain/langgraph"; import { getArtifactContent } from "../../../contexts/utils"; @@ -14,10 +18,27 @@ export const updateArtifact = async ( state: typeof OpenCanvasGraphAnnotation.State, config: LangGraphRunnableConfig ): Promise => { - const smallModel = new ChatOpenAI({ - model: "gpt-4o", - temperature: 0, - }); + const { modelProvider } = getModelConfig(config); + let smallModel: Awaited>; + if (modelProvider.includes("openai")) { + // Custom model is OpenAI/Azure OpenAI + smallModel = await getModelFromConfig(config, { + temperature: 0, + }); + } else { + // Custom model is not set to OpenAI/Azure OpenAI. 
Use GPT-4o + smallModel = await getModelFromConfig( + { + ...config, + configurable: { + customModelName: "gpt-4o", + }, + }, + { + temperature: 0, + } + ); + } const store = ensureStoreInConfig(config); const assistantId = config.configurable?.assistant_id; diff --git a/src/agent/open-canvas/nodes/updateHighlightedText.ts b/src/agent/open-canvas/nodes/updateHighlightedText.ts index e37f441c..8c135e93 100644 --- a/src/agent/open-canvas/nodes/updateHighlightedText.ts +++ b/src/agent/open-canvas/nodes/updateHighlightedText.ts @@ -1,8 +1,13 @@ -import { ChatOpenAI } from "@langchain/openai"; import { OpenCanvasGraphAnnotation, OpenCanvasGraphReturnType } from "../state"; import { ArtifactMarkdownV3 } from "../../../types"; import { getArtifactContent } from "../../../contexts/utils"; import { isArtifactMarkdownContent } from "../../../lib/artifact_content_types"; +import { getModelConfig, getModelFromConfig } from "@/agent/utils"; +import { LangGraphRunnableConfig } from "@langchain/langgraph"; +import { RunnableBinding } from "@langchain/core/runnables"; +import { BaseLanguageModelInput } from "@langchain/core/language_models/base"; +import { AIMessageChunk } from "@langchain/core/messages"; +import { ConfigurableChatModelCallOptions } from "langchain/chat_models/universal"; const PROMPT = `You are an expert AI writing assistant, tasked with rewriting some text a user has selected. The selected text is nested inside a larger 'block'. You should always respond with ONLY the updated text block in accordance with the user's request. You should always respond with the full markdown text block, as it will simply replace the existing block in the artifact. @@ -27,12 +32,38 @@ Ensure you reply with the FULL text block, including the updated selected text. * Update an existing artifact based on the user's query. 
*/ export const updateHighlightedText = async ( - state: typeof OpenCanvasGraphAnnotation.State + state: typeof OpenCanvasGraphAnnotation.State, + config: LangGraphRunnableConfig ): Promise => { - const model = new ChatOpenAI({ - model: "gpt-4o", - temperature: 0, - }).withConfig({ runName: "update_highlighted_markdown" }); + const { modelProvider } = getModelConfig(config); + let model: RunnableBinding< + BaseLanguageModelInput, + AIMessageChunk, + ConfigurableChatModelCallOptions + >; + if (modelProvider.includes("openai")) { + // Custom model is OpenAI/Azure OpenAI + model = ( + await getModelFromConfig(config, { + temperature: 0, + }) + ).withConfig({ runName: "update_highlighted_markdown" }); + } else { + // Custom model is not set to OpenAI/Azure OpenAI. Use GPT-4o + model = ( + await getModelFromConfig( + { + ...config, + configurable: { + customModelName: "gpt-4o", + }, + }, + { + temperature: 0, + } + ) + ).withConfig({ runName: "update_highlighted_markdown" }); + } const currentArtifactContent = state.artifact ? 
getArtifactContent(state.artifact) diff --git a/src/agent/utils.ts b/src/agent/utils.ts index e58c64cc..5bbe9364 100644 --- a/src/agent/utils.ts +++ b/src/agent/utils.ts @@ -1,6 +1,7 @@ import { isArtifactCodeContent } from "@/lib/artifact_content_types"; import { BaseStore, LangGraphRunnableConfig } from "@langchain/langgraph"; import { ArtifactCodeV3, ArtifactMarkdownV3, Reflections } from "../types"; +import { initChatModel } from "langchain/chat_models/universal"; export const formatReflections = ( reflections: Reflections, @@ -74,6 +75,24 @@ export const formatReflections = ( return styleString + "\n\n" + contentString; }; +export async function getFormattedReflections( + config: LangGraphRunnableConfig +): Promise { + const store = ensureStoreInConfig(config); + const assistantId = config.configurable?.assistant_id; + if (!assistantId) { + throw new Error("`assistant_id` not found in configurable"); + } + const memoryNamespace = ["memories", assistantId]; + const memoryKey = "reflection"; + const memories = await store.get(memoryNamespace, memoryKey); + const memoriesAsString = memories?.value + ? 
formatReflections(memories.value as Reflections) + : "No reflections found."; + + return memoriesAsString; +} + export const ensureStoreInConfig = ( config: LangGraphRunnableConfig ): BaseStore => { @@ -112,13 +131,42 @@ export const formatArtifactContentWithTemplate = ( ); }; -export const getModelNameAndProviderFromConfig = ( +export const getModelConfig = ( config: LangGraphRunnableConfig -): { modelName: string; modelProvider: string } => { +): { + modelName: string; + modelProvider: string; + azureConfig?: { + azureOpenAIApiKey: string; + azureOpenAIApiInstanceName: string; + azureOpenAIApiDeploymentName: string; + azureOpenAIApiVersion: string; + azureOpenAIBasePath?: string; + }; +} => { const customModelName = config.configurable?.customModelName as string; if (!customModelName) { throw new Error("Model name is missing in config."); } + + if (customModelName.startsWith("azure/")) { + const actualModelName = customModelName.replace("azure/", ""); + return { + modelName: actualModelName, + modelProvider: "azure_openai", + azureConfig: { + azureOpenAIApiKey: process.env._AZURE_OPENAI_API_KEY || "", + azureOpenAIApiInstanceName: + process.env._AZURE_OPENAI_API_INSTANCE_NAME || "", + azureOpenAIApiDeploymentName: + process.env._AZURE_OPENAI_API_DEPLOYMENT_NAME || "", + azureOpenAIApiVersion: + process.env._AZURE_OPENAI_API_VERSION || "2024-08-01-preview", + azureOpenAIBasePath: process.env._AZURE_OPENAI_API_BASE_PATH, + }, + }; + } + if (customModelName.includes("gpt-")) { return { modelName: customModelName, @@ -146,3 +194,35 @@ export const getModelNameAndProviderFromConfig = ( throw new Error("Unknown model provider"); }; + +export function optionallyGetSystemPromptFromConfig( + config: LangGraphRunnableConfig +): string | undefined { + return config.configurable?.systemPrompt as string | undefined; +} + +export async function getModelFromConfig( + config: LangGraphRunnableConfig, + extra?: { + temperature?: number; + maxTokens?: number; + } +) { + const { 
temperature = 0.5, maxTokens } = extra || {}; + const { modelName, modelProvider, azureConfig } = getModelConfig(config); + return await initChatModel(modelName, { + modelProvider, + temperature, + maxTokens, + ...(azureConfig != null + ? { + azureOpenAIApiKey: azureConfig.azureOpenAIApiKey, + azureOpenAIApiInstanceName: azureConfig.azureOpenAIApiInstanceName, + azureOpenAIApiDeploymentName: + azureConfig.azureOpenAIApiDeploymentName, + azureOpenAIApiVersion: azureConfig.azureOpenAIApiVersion, + azureOpenAIBasePath: azureConfig.azureOpenAIBasePath, + } + : {}), + }); +} diff --git a/src/components/artifacts/ArtifactRenderer.tsx b/src/components/artifacts/ArtifactRenderer.tsx index 5d4e5e87..ee79da2e 100644 --- a/src/components/artifacts/ArtifactRenderer.tsx +++ b/src/components/artifacts/ArtifactRenderer.tsx @@ -117,7 +117,7 @@ function NavigateArtifactHistory(props: NavigateArtifactHistoryProps) { function ArtifactRendererComponent(props: ArtifactRendererProps) { const { graphData, - threadData: { assistantId }, + assistantsData: { selectedAssistant }, userData: { user }, } = useGraphContext(); const { @@ -354,7 +354,7 @@ function ArtifactRendererComponent(props: ArtifactRendererProps) { />
- +
diff --git a/src/components/artifacts/actions_toolbar/custom/NewCustomQuickActionDialog.tsx b/src/components/artifacts/actions_toolbar/custom/NewCustomQuickActionDialog.tsx index c1ed520e..f6210264 100644 --- a/src/components/artifacts/actions_toolbar/custom/NewCustomQuickActionDialog.tsx +++ b/src/components/artifacts/actions_toolbar/custom/NewCustomQuickActionDialog.tsx @@ -20,7 +20,7 @@ import { useState, } from "react"; import { FullPrompt } from "./FullPrompt"; -import { InlineContextTooltip } from "./PromptContextTooltip"; +import { InlineContextTooltip } from "@/components/ui/inline-context-tooltip"; import { useStore } from "@/hooks/useStore"; import { useToast } from "@/hooks/use-toast"; import { v4 as uuidv4 } from "uuid"; @@ -28,6 +28,9 @@ import { CustomQuickAction } from "@/types"; import { TighterText } from "@/components/ui/header"; import { User } from "@supabase/supabase-js"; +const CUSTOM_INSTRUCTIONS_TOOLTIP_TEXT = `This field contains the custom instructions you set, which will then be used to instruct the LLM on how to re-generate the selected artifact.`; +const FULL_PROMPT_TOOLTIP_TEXT = `This is the full prompt that will be set to the LLM when you invoke this quick action, including your custom instructions and other default context.`; + interface NewCustomQuickActionDialogProps { user: User | undefined; isEditing: boolean; @@ -224,7 +227,11 @@ export function NewCustomQuickActionDialog(
Custom instructions - + +

+ {CUSTOM_INSTRUCTIONS_TOOLTIP_TEXT} +

+