From 0019011abd04422c8ecbc4c6626d73182d54adca Mon Sep 17 00:00:00 2001 From: punitda Date: Sun, 3 Nov 2024 15:41:23 +0530 Subject: [PATCH] Update nodes to use initChatModelWithConfig --- src/agent/open-canvas/nodes/customAction.ts | 14 +++++-------- .../open-canvas/nodes/generateArtifact.ts | 10 ++++----- .../open-canvas/nodes/generateFollowup.ts | 14 +++++-------- src/agent/open-canvas/nodes/generatePath.ts | 13 +++++------- .../open-canvas/nodes/replyToGeneralInput.ts | 9 ++++---- .../open-canvas/nodes/rewriteArtifact.ts | 21 +++++++++++-------- .../open-canvas/nodes/rewriteArtifactTheme.ts | 14 +++++-------- .../nodes/rewriteCodeArtifactTheme.ts | 10 ++++----- 8 files changed, 46 insertions(+), 59 deletions(-) diff --git a/src/agent/open-canvas/nodes/customAction.ts b/src/agent/open-canvas/nodes/customAction.ts index 052961d1..e02f200c 100644 --- a/src/agent/open-canvas/nodes/customAction.ts +++ b/src/agent/open-canvas/nodes/customAction.ts @@ -1,6 +1,6 @@ import { BaseMessage } from "@langchain/core/messages"; import { LangGraphRunnableConfig } from "@langchain/langgraph"; -import { initChatModel } from "langchain/chat_models/universal"; +import { initChatModelWithConfig, getModelConfig } from "../../utils"; import { getArtifactContent } from "../../../contexts/utils"; import { isArtifactMarkdownContent } from "../../../lib/artifact_content_types"; import { @@ -10,11 +10,7 @@ import { CustomQuickAction, Reflections, } from "../../../types"; -import { - ensureStoreInConfig, - formatReflections, - getModelNameAndProviderFromConfig, -} from "../../utils"; +import { ensureStoreInConfig, formatReflections } from "../../utils"; import { CUSTOM_QUICK_ACTION_ARTIFACT_CONTENT_PROMPT, CUSTOM_QUICK_ACTION_ARTIFACT_PROMPT_PREFIX, @@ -39,11 +35,11 @@ export const customAction = async ( throw new Error("No custom quick action ID found."); } - const { modelName, modelProvider } = - getModelNameAndProviderFromConfig(config); - const smallModel = await initChatModel(modelName, { + const { modelName, modelProvider, azureConfig } = getModelConfig(config); + const smallModel = await initChatModelWithConfig(modelName, { temperature: 0.5, modelProvider, + azureConfig, }); const store = ensureStoreInConfig(config); diff --git a/src/agent/open-canvas/nodes/generateArtifact.ts b/src/agent/open-canvas/nodes/generateArtifact.ts index bd61bd46..8ce0e536 100644 --- a/src/agent/open-canvas/nodes/generateArtifact.ts +++ b/src/agent/open-canvas/nodes/generateArtifact.ts @@ -11,10 +11,10 @@ import { z } from "zod"; import { ensureStoreInConfig, formatReflections, - getModelNameAndProviderFromConfig, + getModelConfig, + initChatModelWithConfig, } from "../../utils"; import { LangGraphRunnableConfig } from "@langchain/langgraph"; -import { initChatModel } from "langchain/chat_models/universal"; /** * Generate a new artifact based on the user's query. @@ -23,11 +23,11 @@ export const generateArtifact = async ( state: typeof OpenCanvasGraphAnnotation.State, config: LangGraphRunnableConfig ): Promise => { - const { modelName, modelProvider } = - getModelNameAndProviderFromConfig(config); - const smallModel = await initChatModel(modelName, { + const { modelName, modelProvider, azureConfig } = getModelConfig(config); + const smallModel = await initChatModelWithConfig(modelName, { temperature: 0.5, modelProvider, + azureConfig, }); const store = ensureStoreInConfig(config); diff --git a/src/agent/open-canvas/nodes/generateFollowup.ts b/src/agent/open-canvas/nodes/generateFollowup.ts index 63a9d36f..d1cf044c 100644 --- a/src/agent/open-canvas/nodes/generateFollowup.ts +++ b/src/agent/open-canvas/nodes/generateFollowup.ts @@ -1,13 +1,9 @@ import { LangGraphRunnableConfig } from "@langchain/langgraph"; -import { initChatModel } from "langchain/chat_models/universal"; +import { initChatModelWithConfig, getModelConfig } from "../../utils"; import { getArtifactContent } from "../../../contexts/utils"; import { isArtifactMarkdownContent } from "../../../lib/artifact_content_types"; import { Reflections } from "../../../types"; -import { - ensureStoreInConfig, - formatReflections, - getModelNameAndProviderFromConfig, -} from "../../utils"; +import { ensureStoreInConfig, formatReflections } from "../../utils"; import { FOLLOWUP_ARTIFACT_PROMPT } from "../prompts"; import { OpenCanvasGraphAnnotation, OpenCanvasGraphReturnType } from "../state"; @@ -18,12 +14,12 @@ export const generateFollowup = async ( state: typeof OpenCanvasGraphAnnotation.State, config: LangGraphRunnableConfig ): Promise => { - const { modelName, modelProvider } = - getModelNameAndProviderFromConfig(config); - const smallModel = await initChatModel(modelName, { + const { modelName, modelProvider, azureConfig } = getModelConfig(config); + const smallModel = await initChatModelWithConfig(modelName, { temperature: 0.5, maxTokens: 250, modelProvider, + azureConfig, }); const store = ensureStoreInConfig(config); diff --git a/src/agent/open-canvas/nodes/generatePath.ts b/src/agent/open-canvas/nodes/generatePath.ts index 03a0350b..a90d2168 100644 --- a/src/agent/open-canvas/nodes/generatePath.ts +++ b/src/agent/open-canvas/nodes/generatePath.ts @@ -7,12 +7,9 @@ import { } from "../prompts"; import { OpenCanvasGraphAnnotation } from "../state"; import { z } from "zod"; -import { - formatArtifactContentWithTemplate, - getModelNameAndProviderFromConfig, -} from "../../utils"; +import { formatArtifactContentWithTemplate, getModelConfig } from "../../utils"; import { getArtifactContent } from "../../../contexts/utils"; -import { initChatModel } from "langchain/chat_models/universal"; +import { initChatModelWithConfig } from "../../utils"; import { LangGraphRunnableConfig } from "@langchain/langgraph"; /** @@ -93,11 +90,11 @@ export const generatePath = async ( ? "rewriteArtifact" : "generateArtifact"; - const { modelName, modelProvider } = - getModelNameAndProviderFromConfig(config); - const model = await initChatModel(modelName, { + const { modelName, modelProvider, azureConfig } = getModelConfig(config); + const model = await initChatModelWithConfig(modelName, { temperature: 0, modelProvider, + azureConfig, }); const modelWithTool = model.withStructuredOutput( z.object({ diff --git a/src/agent/open-canvas/nodes/replyToGeneralInput.ts b/src/agent/open-canvas/nodes/replyToGeneralInput.ts index 5cd71d2f..31653bf2 100644 --- a/src/agent/open-canvas/nodes/replyToGeneralInput.ts +++ b/src/agent/open-canvas/nodes/replyToGeneralInput.ts @@ -1,12 +1,11 @@ import { LangGraphRunnableConfig } from "@langchain/langgraph"; -import { initChatModel } from "langchain/chat_models/universal"; +import { initChatModelWithConfig, getModelConfig } from "../../utils"; import { getArtifactContent } from "../../../contexts/utils"; import { Reflections } from "../../../types"; import { ensureStoreInConfig, formatArtifactContentWithTemplate, formatReflections, - getModelNameAndProviderFromConfig, } from "../../utils"; import { CURRENT_ARTIFACT_PROMPT, NO_ARTIFACT_PROMPT } from "../prompts"; import { OpenCanvasGraphAnnotation, OpenCanvasGraphReturnType } from "../state"; @@ -18,11 +17,11 @@ export const replyToGeneralInput = async ( state: typeof OpenCanvasGraphAnnotation.State, config: LangGraphRunnableConfig ): Promise => { - const { modelName, modelProvider } = - getModelNameAndProviderFromConfig(config); - const smallModel = await initChatModel(modelName, { + const { modelName, modelProvider, azureConfig } = getModelConfig(config); + const smallModel = await initChatModelWithConfig(modelName, { temperature: 0.5, modelProvider, + azureConfig, }); const prompt = `You are an AI assistant tasked with responding to the users question. diff --git a/src/agent/open-canvas/nodes/rewriteArtifact.ts b/src/agent/open-canvas/nodes/rewriteArtifact.ts index 067a522f..dfc6cde2 100644 --- a/src/agent/open-canvas/nodes/rewriteArtifact.ts +++ b/src/agent/open-canvas/nodes/rewriteArtifact.ts @@ -8,7 +8,7 @@ import { ensureStoreInConfig, formatArtifactContent, formatReflections, - getModelNameAndProviderFromConfig, + getModelConfig, } from "../../utils"; import { ArtifactCodeV3, @@ -24,7 +24,7 @@ import { isArtifactCodeContent, isArtifactMarkdownContent, } from "../../../lib/artifact_content_types"; -import { initChatModel } from "langchain/chat_models/universal"; +import { initChatModelWithConfig } from "../../utils"; export const rewriteArtifact = async ( state: typeof OpenCanvasGraphAnnotation.State, @@ -51,12 +51,15 @@ export const rewriteArtifact = async ( "The language of the code artifact. This should be populated with the programming language if the user is requesting code to be written, or 'other', in all other cases." ), }); - const { modelName, modelProvider } = - getModelNameAndProviderFromConfig(config); + + const { modelName, modelProvider, azureConfig } = getModelConfig(config); + + // Then bind tools to create toolCallingModel const toolCallingModel = ( - await initChatModel(modelName, { + await initChatModelWithConfig(modelName, { temperature: 0, modelProvider, + azureConfig, }) ) .bindTools( @@ -71,14 +74,14 @@ export const rewriteArtifact = async ( ) .withConfig({ runName: "optionally_update_artifact_meta" }); + // Initialize another instance for the second model const smallModelWithConfig = ( - await initChatModel(modelName, { + await initChatModelWithConfig(modelName, { temperature: 0, modelProvider, + azureConfig, }) - ).withConfig({ - runName: "rewrite_artifact_model_call", - }); + ).withConfig({ runName: "rewrite_artifact_model_call" }); const store = ensureStoreInConfig(config); const assistantId = config.configurable?.assistant_id; diff --git a/src/agent/open-canvas/nodes/rewriteArtifactTheme.ts b/src/agent/open-canvas/nodes/rewriteArtifactTheme.ts index 0d5c73e5..65346378 100644 --- a/src/agent/open-canvas/nodes/rewriteArtifactTheme.ts +++ b/src/agent/open-canvas/nodes/rewriteArtifactTheme.ts @@ -1,13 +1,9 @@ import { LangGraphRunnableConfig } from "@langchain/langgraph"; -import { initChatModel } from "langchain/chat_models/universal"; +import { initChatModelWithConfig, getModelConfig } from "../../utils"; import { getArtifactContent } from "../../../contexts/utils"; import { isArtifactMarkdownContent } from "../../../lib/artifact_content_types"; import { ArtifactV3, Reflections } from "../../../types"; -import { - ensureStoreInConfig, - formatReflections, - getModelNameAndProviderFromConfig, -} from "../../utils"; +import { ensureStoreInConfig, formatReflections } from "../../utils"; import { ADD_EMOJIS_TO_ARTIFACT_PROMPT, CHANGE_ARTIFACT_LANGUAGE_PROMPT, @@ -21,11 +17,11 @@ export const rewriteArtifactTheme = async ( state: typeof OpenCanvasGraphAnnotation.State, config: LangGraphRunnableConfig ): Promise => { - const { modelName, modelProvider } = - getModelNameAndProviderFromConfig(config); - const smallModel = await initChatModel(modelName, { + const { modelName, modelProvider, azureConfig } = getModelConfig(config); + const smallModel = await initChatModelWithConfig(modelName, { temperature: 0.5, modelProvider, + azureConfig, }); const store = ensureStoreInConfig(config); diff --git a/src/agent/open-canvas/nodes/rewriteCodeArtifactTheme.ts b/src/agent/open-canvas/nodes/rewriteCodeArtifactTheme.ts index b1203c82..ac71a31f 100644 --- a/src/agent/open-canvas/nodes/rewriteCodeArtifactTheme.ts +++ b/src/agent/open-canvas/nodes/rewriteCodeArtifactTheme.ts @@ -1,6 +1,6 @@ -import { getModelNameAndProviderFromConfig } from "@/agent/utils"; +import { getModelConfig } from "@/agent/utils"; import { LangGraphRunnableConfig } from "@langchain/langgraph"; -import { initChatModel } from "langchain/chat_models/universal"; +import { initChatModelWithConfig } from "../../utils"; import { getArtifactContent } from "../../../contexts/utils"; import { isArtifactCodeContent } from "../../../lib/artifact_content_types"; import { ArtifactCodeV3, ArtifactV3 } from "../../../types"; @@ -16,11 +16,11 @@ export const rewriteCodeArtifactTheme = async ( state: typeof OpenCanvasGraphAnnotation.State, config: LangGraphRunnableConfig ): Promise => { - const { modelName, modelProvider } = - getModelNameAndProviderFromConfig(config); - const smallModel = await initChatModel(modelName, { + const { modelName, modelProvider, azureConfig } = getModelConfig(config); + const smallModel = await initChatModelWithConfig(modelName, { temperature: 0.5, modelProvider, + azureConfig, }); const currentArtifactContent = state.artifact