Skip to content

Commit

Permalink
move to genai
Browse files Browse the repository at this point in the history
  • Loading branch information
bracesproul committed Oct 28, 2024
1 parent e122e44 commit 9791b32
Show file tree
Hide file tree
Showing 12 changed files with 87 additions and 212 deletions.
6 changes: 2 additions & 4 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,8 @@ LANGCHAIN_API_KEY=
ANTHROPIC_API_KEY=
# OpenAI used for content generation
OPENAI_API_KEY=

# In Google Cloud Enable GeminiAPI, Create a service account and download the key
# Set the path to the key file
GOOGLE_APPLICATION_CREDENTIALS=
# Optional, only required if using `Gemini 1.5 Flash` as the model.
GOOGLE_API_KEY=

# LangGraph Deployment, or local development server via LangGraph Studio.
# If running locally, this URL should be set in the `constants.ts` file.
Expand Down
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
"@langchain/anthropic": "^0.3.5",
"@langchain/community": "^0.3.9",
"@langchain/core": "^0.3.14",
"@langchain/google-vertexai": "^0.1.0",
"@langchain/google-genai": "^0.1.0",
"@langchain/langgraph": "^0.2.18",
"@langchain/langgraph-sdk": "^0.0.17",
"@langchain/openai": "^0.3.11",
Expand Down
6 changes: 4 additions & 2 deletions src/agent/open-canvas/nodes/customAction.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import {
import {
ensureStoreInConfig,
formatReflections,
getModelNameFromConfig,
getModelNameAndProviderFromConfig,
} from "../../utils";
import {
CUSTOM_QUICK_ACTION_ARTIFACT_CONTENT_PROMPT,
Expand All @@ -39,9 +39,11 @@ export const customAction = async (
throw new Error("No custom quick action ID found.");
}

const modelName = getModelNameFromConfig(config);
const { modelName, modelProvider } =
getModelNameAndProviderFromConfig(config);
const smallModel = await initChatModel(modelName, {
temperature: 0.5,
modelProvider,
});

const store = ensureStoreInConfig(config);
Expand Down
6 changes: 4 additions & 2 deletions src/agent/open-canvas/nodes/generateFollowup.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import { Reflections } from "../../../types";
import {
ensureStoreInConfig,
formatReflections,
getModelNameFromConfig,
getModelNameAndProviderFromConfig,
} from "../../utils";
import { FOLLOWUP_ARTIFACT_PROMPT } from "../prompts";
import { OpenCanvasGraphAnnotation, OpenCanvasGraphReturnType } from "../state";
Expand All @@ -18,10 +18,12 @@ export const generateFollowup = async (
state: typeof OpenCanvasGraphAnnotation.State,
config: LangGraphRunnableConfig
): Promise<OpenCanvasGraphReturnType> => {
const modelName = getModelNameFromConfig(config);
const { modelName, modelProvider } =
getModelNameAndProviderFromConfig(config);
const smallModel = await initChatModel(modelName, {
temperature: 0.5,
maxTokens: 250,
modelProvider,
});

const store = ensureStoreInConfig(config);
Expand Down
6 changes: 4 additions & 2 deletions src/agent/open-canvas/nodes/respondToQuery.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import {
ensureStoreInConfig,
formatArtifactContentWithTemplate,
formatReflections,
getModelNameFromConfig,
getModelNameAndProviderFromConfig,
} from "../../utils";
import { CURRENT_ARTIFACT_PROMPT, NO_ARTIFACT_PROMPT } from "../prompts";
import { OpenCanvasGraphAnnotation, OpenCanvasGraphReturnType } from "../state";
Expand All @@ -18,9 +18,11 @@ export const respondToQuery = async (
state: typeof OpenCanvasGraphAnnotation.State,
config: LangGraphRunnableConfig
): Promise<OpenCanvasGraphReturnType> => {
const modelName = getModelNameFromConfig(config);
const { modelName, modelProvider } =
getModelNameAndProviderFromConfig(config);
const smallModel = await initChatModel(modelName, {
temperature: 0.5,
modelProvider,
});

const prompt = `You are an AI assistant tasked with responding to the users question.
Expand Down
6 changes: 4 additions & 2 deletions src/agent/open-canvas/nodes/rewriteArtifact.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import {
ensureStoreInConfig,
formatArtifactContent,
formatReflections,
getModelNameFromConfig,
getModelNameAndProviderFromConfig,
} from "../../utils";
import {
ArtifactCodeV3,
Expand Down Expand Up @@ -67,9 +67,11 @@ export const rewriteArtifact = async (
])
.withConfig({ runName: "optionally_update_artifact_meta" });

const modelName = getModelNameFromConfig(config);
const { modelName, modelProvider } =
getModelNameAndProviderFromConfig(config);
const smallModel = await initChatModel(modelName, {
temperature: 0.5,
modelProvider,
});

const store = ensureStoreInConfig(config);
Expand Down
6 changes: 4 additions & 2 deletions src/agent/open-canvas/nodes/rewriteArtifactTheme.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import { ArtifactV3, Reflections } from "../../../types";
import {
ensureStoreInConfig,
formatReflections,
getModelNameFromConfig,
getModelNameAndProviderFromConfig,
} from "../../utils";
import {
ADD_EMOJIS_TO_ARTIFACT_PROMPT,
Expand All @@ -21,9 +21,11 @@ export const rewriteArtifactTheme = async (
state: typeof OpenCanvasGraphAnnotation.State,
config: LangGraphRunnableConfig
): Promise<OpenCanvasGraphReturnType> => {
const modelName = getModelNameFromConfig(config);
const { modelName, modelProvider } =
getModelNameAndProviderFromConfig(config);
const smallModel = await initChatModel(modelName, {
temperature: 0.5,
modelProvider,
});

const store = ensureStoreInConfig(config);
Expand Down
6 changes: 4 additions & 2 deletions src/agent/open-canvas/nodes/rewriteCodeArtifactTheme.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { getModelNameFromConfig } from "@/agent/utils";
import { getModelNameAndProviderFromConfig } from "@/agent/utils";
import { LangGraphRunnableConfig } from "@langchain/langgraph";
import { initChatModel } from "langchain/chat_models/universal";
import { getArtifactContent } from "../../../hooks/use-graph/utils";
Expand All @@ -16,9 +16,11 @@ export const rewriteCodeArtifactTheme = async (
state: typeof OpenCanvasGraphAnnotation.State,
config: LangGraphRunnableConfig
): Promise<OpenCanvasGraphReturnType> => {
const modelName = getModelNameFromConfig(config);
const { modelName, modelProvider } =
getModelNameAndProviderFromConfig(config);
const smallModel = await initChatModel(modelName, {
temperature: 0.5,
modelProvider,
});

const currentArtifactContent = state.artifact
Expand Down
35 changes: 35 additions & 0 deletions src/agent/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,41 @@ export const formatArtifactContentWithTemplate = (
);
};

export const getModelNameAndProviderFromConfig = (
config: LangGraphRunnableConfig
): { modelName: string; modelProvider: string } => {
const customModelName = config.configurable?.customModelName as string;
if (!customModelName) {
throw new Error("Model name is missing in config.");
}
if (customModelName.includes("gpt-")) {
return {
modelName: customModelName,
modelProvider: "openai",
};
}
if (customModelName.includes("claude-")) {
return {
modelName: customModelName,
modelProvider: "anthropic",
};
}
if (customModelName.includes("fireworks/")) {
return {
modelName: customModelName,
modelProvider: "fireworks",
};
}
if (customModelName.includes("gemini-")) {
return {
modelName: customModelName,
modelProvider: "google-genai",
};
}

throw new Error("Unknown model provider");
};

export const getModelNameFromConfig = (
config: LangGraphRunnableConfig
): string => {
Expand Down
9 changes: 7 additions & 2 deletions src/components/ModelSelector.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,15 @@ import {
ANTHROPIC_MODELS,
OPENAI_MODELS,
FIREWORKS_MODELS,
GEMINI_MODELS
GEMINI_MODELS,
} from "@/constants";

const allModels = [...ANTHROPIC_MODELS, ...OPENAI_MODELS, ...FIREWORKS_MODELS, ...GEMINI_MODELS];
const allModels = [
...ANTHROPIC_MODELS,
...OPENAI_MODELS,
...FIREWORKS_MODELS,
...GEMINI_MODELS,
];

const modelNameToLabel = (modelName: ALL_MODEL_NAMES) => {
const model = allModels.find((m) => m.name === modelName);
Expand Down
4 changes: 2 additions & 2 deletions src/constants.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
export const LANGGRAPH_API_URL =
process.env.LANGGRAPH_API_URL ?? "http://localhost:59811";
process.env.LANGGRAPH_API_URL ?? "http://localhost:54790";
// v2 is tied to the 'open-canvas-prod' deployment.
export const ASSISTANT_ID_COOKIE = "oc_assistant_id_v2";
// export const ASSISTANT_ID_COOKIE = "oc_assistant_id";
Expand Down Expand Up @@ -48,7 +48,7 @@ export const GEMINI_MODELS = [
{
name: "gemini-1.5-flash",
label: "Gemini 1.5 Flash",
}
},
];
export const DEFAULT_MODEL_NAME: ALL_MODEL_NAMES = "claude-3-haiku-20240307";
export type OPENAI_MODEL_NAMES = (typeof OPENAI_MODELS)[number]["name"];
Expand Down
Loading

0 comments on commit 9791b32

Please sign in to comment.