Skip to content

Commit

Permalink
Update nodes to use initChatModelWithConfig
Browse files Browse the repository at this point in the history
  • Loading branch information
punitda committed Nov 3, 2024
1 parent cd74585 commit 0019011
Show file tree
Hide file tree
Showing 8 changed files with 46 additions and 59 deletions.
14 changes: 5 additions & 9 deletions src/agent/open-canvas/nodes/customAction.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { BaseMessage } from "@langchain/core/messages";
import { LangGraphRunnableConfig } from "@langchain/langgraph";
import { initChatModel } from "langchain/chat_models/universal";
import { initChatModelWithConfig, getModelConfig } from "../../utils";
import { getArtifactContent } from "../../../contexts/utils";
import { isArtifactMarkdownContent } from "../../../lib/artifact_content_types";
import {
Expand All @@ -10,11 +10,7 @@ import {
CustomQuickAction,
Reflections,
} from "../../../types";
import {
ensureStoreInConfig,
formatReflections,
getModelNameAndProviderFromConfig,
} from "../../utils";
import { ensureStoreInConfig, formatReflections } from "../../utils";
import {
CUSTOM_QUICK_ACTION_ARTIFACT_CONTENT_PROMPT,
CUSTOM_QUICK_ACTION_ARTIFACT_PROMPT_PREFIX,
Expand All @@ -39,11 +35,11 @@ export const customAction = async (
throw new Error("No custom quick action ID found.");
}

const { modelName, modelProvider } =
getModelNameAndProviderFromConfig(config);
const smallModel = await initChatModel(modelName, {
const { modelName, modelProvider, azureConfig } = getModelConfig(config);
const smallModel = await initChatModelWithConfig(modelName, {
temperature: 0.5,
modelProvider,
azureConfig,
});

const store = ensureStoreInConfig(config);
Expand Down
10 changes: 5 additions & 5 deletions src/agent/open-canvas/nodes/generateArtifact.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ import { z } from "zod";
import {
ensureStoreInConfig,
formatReflections,
getModelNameAndProviderFromConfig,
getModelConfig,
initChatModelWithConfig,
} from "../../utils";
import { LangGraphRunnableConfig } from "@langchain/langgraph";
import { initChatModel } from "langchain/chat_models/universal";

/**
* Generate a new artifact based on the user's query.
Expand All @@ -23,11 +23,11 @@ export const generateArtifact = async (
state: typeof OpenCanvasGraphAnnotation.State,
config: LangGraphRunnableConfig
): Promise<OpenCanvasGraphReturnType> => {
const { modelName, modelProvider } =
getModelNameAndProviderFromConfig(config);
const smallModel = await initChatModel(modelName, {
const { modelName, modelProvider, azureConfig } = getModelConfig(config);
const smallModel = await initChatModelWithConfig(modelName, {
temperature: 0.5,
modelProvider,
azureConfig,
});

const store = ensureStoreInConfig(config);
Expand Down
14 changes: 5 additions & 9 deletions src/agent/open-canvas/nodes/generateFollowup.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
import { LangGraphRunnableConfig } from "@langchain/langgraph";
import { initChatModel } from "langchain/chat_models/universal";
import { initChatModelWithConfig, getModelConfig } from "../../utils";
import { getArtifactContent } from "../../../contexts/utils";
import { isArtifactMarkdownContent } from "../../../lib/artifact_content_types";
import { Reflections } from "../../../types";
import {
ensureStoreInConfig,
formatReflections,
getModelNameAndProviderFromConfig,
} from "../../utils";
import { ensureStoreInConfig, formatReflections } from "../../utils";
import { FOLLOWUP_ARTIFACT_PROMPT } from "../prompts";
import { OpenCanvasGraphAnnotation, OpenCanvasGraphReturnType } from "../state";

Expand All @@ -18,12 +14,12 @@ export const generateFollowup = async (
state: typeof OpenCanvasGraphAnnotation.State,
config: LangGraphRunnableConfig
): Promise<OpenCanvasGraphReturnType> => {
const { modelName, modelProvider } =
getModelNameAndProviderFromConfig(config);
const smallModel = await initChatModel(modelName, {
const { modelName, modelProvider, azureConfig } = getModelConfig(config);
const smallModel = await initChatModelWithConfig(modelName, {
temperature: 0.5,
maxTokens: 250,
modelProvider,
azureConfig,
});

const store = ensureStoreInConfig(config);
Expand Down
13 changes: 5 additions & 8 deletions src/agent/open-canvas/nodes/generatePath.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,9 @@ import {
} from "../prompts";
import { OpenCanvasGraphAnnotation } from "../state";
import { z } from "zod";
import {
formatArtifactContentWithTemplate,
getModelNameAndProviderFromConfig,
} from "../../utils";
import { formatArtifactContentWithTemplate, getModelConfig } from "../../utils";
import { getArtifactContent } from "../../../contexts/utils";
import { initChatModel } from "langchain/chat_models/universal";
import { initChatModelWithConfig } from "../../utils";
import { LangGraphRunnableConfig } from "@langchain/langgraph";

/**
Expand Down Expand Up @@ -93,11 +90,11 @@ export const generatePath = async (
? "rewriteArtifact"
: "generateArtifact";

const { modelName, modelProvider } =
getModelNameAndProviderFromConfig(config);
const model = await initChatModel(modelName, {
const { modelName, modelProvider, azureConfig } = getModelConfig(config);
const model = await initChatModelWithConfig(modelName, {
temperature: 0,
modelProvider,
azureConfig,
});
const modelWithTool = model.withStructuredOutput(
z.object({
Expand Down
9 changes: 4 additions & 5 deletions src/agent/open-canvas/nodes/replyToGeneralInput.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import { LangGraphRunnableConfig } from "@langchain/langgraph";
import { initChatModel } from "langchain/chat_models/universal";
import { initChatModelWithConfig, getModelConfig } from "../../utils";
import { getArtifactContent } from "../../../contexts/utils";
import { Reflections } from "../../../types";
import {
ensureStoreInConfig,
formatArtifactContentWithTemplate,
formatReflections,
getModelNameAndProviderFromConfig,
} from "../../utils";
import { CURRENT_ARTIFACT_PROMPT, NO_ARTIFACT_PROMPT } from "../prompts";
import { OpenCanvasGraphAnnotation, OpenCanvasGraphReturnType } from "../state";
Expand All @@ -18,11 +17,11 @@ export const replyToGeneralInput = async (
state: typeof OpenCanvasGraphAnnotation.State,
config: LangGraphRunnableConfig
): Promise<OpenCanvasGraphReturnType> => {
const { modelName, modelProvider } =
getModelNameAndProviderFromConfig(config);
const smallModel = await initChatModel(modelName, {
const { modelName, modelProvider, azureConfig } = getModelConfig(config);
const smallModel = await initChatModelWithConfig(modelName, {
temperature: 0.5,
modelProvider,
azureConfig,
});

const prompt = `You are an AI assistant tasked with responding to the users question.
Expand Down
21 changes: 12 additions & 9 deletions src/agent/open-canvas/nodes/rewriteArtifact.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import {
ensureStoreInConfig,
formatArtifactContent,
formatReflections,
getModelNameAndProviderFromConfig,
getModelConfig,
} from "../../utils";
import {
ArtifactCodeV3,
Expand All @@ -24,7 +24,7 @@ import {
isArtifactCodeContent,
isArtifactMarkdownContent,
} from "../../../lib/artifact_content_types";
import { initChatModel } from "langchain/chat_models/universal";
import { initChatModelWithConfig } from "../../utils";

export const rewriteArtifact = async (
state: typeof OpenCanvasGraphAnnotation.State,
Expand All @@ -51,12 +51,15 @@ export const rewriteArtifact = async (
"The language of the code artifact. This should be populated with the programming language if the user is requesting code to be written, or 'other', in all other cases."
),
});
const { modelName, modelProvider } =
getModelNameAndProviderFromConfig(config);

const { modelName, modelProvider, azureConfig } = getModelConfig(config);

// Then bind tools to create toolCallingModel
const toolCallingModel = (
await initChatModel(modelName, {
await initChatModelWithConfig(modelName, {
temperature: 0,
modelProvider,
azureConfig,
})
)
.bindTools(
Expand All @@ -71,14 +74,14 @@ export const rewriteArtifact = async (
)
.withConfig({ runName: "optionally_update_artifact_meta" });

// Initialize another instance for the second model
const smallModelWithConfig = (
await initChatModel(modelName, {
await initChatModelWithConfig(modelName, {
temperature: 0,
modelProvider,
azureConfig,
})
).withConfig({
runName: "rewrite_artifact_model_call",
});
).withConfig({ runName: "rewrite_artifact_model_call" });

const store = ensureStoreInConfig(config);
const assistantId = config.configurable?.assistant_id;
Expand Down
14 changes: 5 additions & 9 deletions src/agent/open-canvas/nodes/rewriteArtifactTheme.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
import { LangGraphRunnableConfig } from "@langchain/langgraph";
import { initChatModel } from "langchain/chat_models/universal";
import { initChatModelWithConfig, getModelConfig } from "../../utils";
import { getArtifactContent } from "../../../contexts/utils";
import { isArtifactMarkdownContent } from "../../../lib/artifact_content_types";
import { ArtifactV3, Reflections } from "../../../types";
import {
ensureStoreInConfig,
formatReflections,
getModelNameAndProviderFromConfig,
} from "../../utils";
import { ensureStoreInConfig, formatReflections } from "../../utils";
import {
ADD_EMOJIS_TO_ARTIFACT_PROMPT,
CHANGE_ARTIFACT_LANGUAGE_PROMPT,
Expand All @@ -21,11 +17,11 @@ export const rewriteArtifactTheme = async (
state: typeof OpenCanvasGraphAnnotation.State,
config: LangGraphRunnableConfig
): Promise<OpenCanvasGraphReturnType> => {
const { modelName, modelProvider } =
getModelNameAndProviderFromConfig(config);
const smallModel = await initChatModel(modelName, {
const { modelName, modelProvider, azureConfig } = getModelConfig(config);
const smallModel = await initChatModelWithConfig(modelName, {
temperature: 0.5,
modelProvider,
azureConfig,
});

const store = ensureStoreInConfig(config);
Expand Down
10 changes: 5 additions & 5 deletions src/agent/open-canvas/nodes/rewriteCodeArtifactTheme.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { getModelNameAndProviderFromConfig } from "@/agent/utils";
import { getModelConfig } from "@/agent/utils";
import { LangGraphRunnableConfig } from "@langchain/langgraph";
import { initChatModel } from "langchain/chat_models/universal";
import { initChatModelWithConfig } from "../../utils";
import { getArtifactContent } from "../../../contexts/utils";
import { isArtifactCodeContent } from "../../../lib/artifact_content_types";
import { ArtifactCodeV3, ArtifactV3 } from "../../../types";
Expand All @@ -16,11 +16,11 @@ export const rewriteCodeArtifactTheme = async (
state: typeof OpenCanvasGraphAnnotation.State,
config: LangGraphRunnableConfig
): Promise<OpenCanvasGraphReturnType> => {
const { modelName, modelProvider } =
getModelNameAndProviderFromConfig(config);
const smallModel = await initChatModel(modelName, {
const { modelName, modelProvider, azureConfig } = getModelConfig(config);
const smallModel = await initChatModelWithConfig(modelName, {
temperature: 0.5,
modelProvider,
azureConfig,
});

const currentArtifactContent = state.artifact
Expand Down

0 comments on commit 0019011

Please sign in to comment.