Skip to content

Commit

Permalink
Merge branch 'staging' into brace/use-anthropic
Browse files Browse the repository at this point in the history
  • Loading branch information
bracesproul authored Oct 28, 2024
2 parents 4279021 + a306ac2 commit f829417
Show file tree
Hide file tree
Showing 12 changed files with 93 additions and 17 deletions.
2 changes: 2 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ LANGCHAIN_API_KEY=
ANTHROPIC_API_KEY=
# OpenAI used for content generation
OPENAI_API_KEY=
# Optional, only required if using `Gemini 1.5 Flash` as the model.
GOOGLE_API_KEY=

# LangGraph Deployment, or local development server via LangGraph Studio.
# If running locally, this URL should be set in the `constants.ts` file.
Expand Down
1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
"@langchain/anthropic": "^0.3.6",
"@langchain/community": "^0.3.9",
"@langchain/core": "^0.3.14",
"@langchain/google-genai": "^0.1.0",
"@langchain/langgraph": "^0.2.18",
"@langchain/langgraph-sdk": "^0.0.17",
"@langchain/openai": "^0.3.11",
Expand Down
6 changes: 4 additions & 2 deletions src/agent/open-canvas/nodes/customAction.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import {
import {
ensureStoreInConfig,
formatReflections,
getModelNameFromConfig,
getModelNameAndProviderFromConfig,
} from "../../utils";
import {
CUSTOM_QUICK_ACTION_ARTIFACT_CONTENT_PROMPT,
Expand All @@ -39,9 +39,11 @@ export const customAction = async (
throw new Error("No custom quick action ID found.");
}

const modelName = getModelNameFromConfig(config);
const { modelName, modelProvider } =
getModelNameAndProviderFromConfig(config);
const smallModel = await initChatModel(modelName, {
temperature: 0.5,
modelProvider,
});

const store = ensureStoreInConfig(config);
Expand Down
6 changes: 4 additions & 2 deletions src/agent/open-canvas/nodes/generateFollowup.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import { Reflections } from "../../../types";
import {
ensureStoreInConfig,
formatReflections,
getModelNameFromConfig,
getModelNameAndProviderFromConfig,
} from "../../utils";
import { FOLLOWUP_ARTIFACT_PROMPT } from "../prompts";
import { OpenCanvasGraphAnnotation, OpenCanvasGraphReturnType } from "../state";
Expand All @@ -18,10 +18,12 @@ export const generateFollowup = async (
state: typeof OpenCanvasGraphAnnotation.State,
config: LangGraphRunnableConfig
): Promise<OpenCanvasGraphReturnType> => {
const modelName = getModelNameFromConfig(config);
const { modelName, modelProvider } =
getModelNameAndProviderFromConfig(config);
const smallModel = await initChatModel(modelName, {
temperature: 0.5,
maxTokens: 250,
modelProvider,
});

const store = ensureStoreInConfig(config);
Expand Down
6 changes: 4 additions & 2 deletions src/agent/open-canvas/nodes/respondToQuery.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import {
ensureStoreInConfig,
formatArtifactContentWithTemplate,
formatReflections,
getModelNameFromConfig,
getModelNameAndProviderFromConfig,
} from "../../utils";
import { CURRENT_ARTIFACT_PROMPT, NO_ARTIFACT_PROMPT } from "../prompts";
import { OpenCanvasGraphAnnotation, OpenCanvasGraphReturnType } from "../state";
Expand All @@ -18,9 +18,11 @@ export const respondToQuery = async (
state: typeof OpenCanvasGraphAnnotation.State,
config: LangGraphRunnableConfig
): Promise<OpenCanvasGraphReturnType> => {
const modelName = getModelNameFromConfig(config);
const { modelName, modelProvider } =
getModelNameAndProviderFromConfig(config);
const smallModel = await initChatModel(modelName, {
temperature: 0.5,
modelProvider,
});

const prompt = `You are an AI assistant tasked with responding to the users question.
Expand Down
6 changes: 4 additions & 2 deletions src/agent/open-canvas/nodes/rewriteArtifact.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import {
ensureStoreInConfig,
formatArtifactContent,
formatReflections,
getModelNameFromConfig,
getModelNameAndProviderFromConfig,
} from "../../utils";
import {
ArtifactCodeV3,
Expand Down Expand Up @@ -52,9 +52,11 @@ export const rewriteArtifact = async (
"The programming language of the code artifact. ONLY update this if the user is making a request which changes the programming language of the code artifact, or is asking for a code artifact to be generated."
),
});
const modelName = getModelNameFromConfig(config);
const { modelName, modelProvider } =
getModelNameAndProviderFromConfig(config);
const model = await initChatModel(modelName, {
temperature: 0,
modelProvider,
});
const toolCallingModel = model
.bindTools(
Expand Down
6 changes: 4 additions & 2 deletions src/agent/open-canvas/nodes/rewriteArtifactTheme.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import { ArtifactV3, Reflections } from "../../../types";
import {
ensureStoreInConfig,
formatReflections,
getModelNameFromConfig,
getModelNameAndProviderFromConfig,
} from "../../utils";
import {
ADD_EMOJIS_TO_ARTIFACT_PROMPT,
Expand All @@ -21,9 +21,11 @@ export const rewriteArtifactTheme = async (
state: typeof OpenCanvasGraphAnnotation.State,
config: LangGraphRunnableConfig
): Promise<OpenCanvasGraphReturnType> => {
const modelName = getModelNameFromConfig(config);
const { modelName, modelProvider } =
getModelNameAndProviderFromConfig(config);
const smallModel = await initChatModel(modelName, {
temperature: 0.5,
modelProvider,
});

const store = ensureStoreInConfig(config);
Expand Down
6 changes: 4 additions & 2 deletions src/agent/open-canvas/nodes/rewriteCodeArtifactTheme.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { getModelNameFromConfig } from "@/agent/utils";
import { getModelNameAndProviderFromConfig } from "@/agent/utils";
import { LangGraphRunnableConfig } from "@langchain/langgraph";
import { initChatModel } from "langchain/chat_models/universal";
import { getArtifactContent } from "../../../hooks/use-graph/utils";
Expand All @@ -16,9 +16,11 @@ export const rewriteCodeArtifactTheme = async (
state: typeof OpenCanvasGraphAnnotation.State,
config: LangGraphRunnableConfig
): Promise<OpenCanvasGraphReturnType> => {
const modelName = getModelNameFromConfig(config);
const { modelName, modelProvider } =
getModelNameAndProviderFromConfig(config);
const smallModel = await initChatModel(modelName, {
temperature: 0.5,
modelProvider,
});

const currentArtifactContent = state.artifact
Expand Down
35 changes: 35 additions & 0 deletions src/agent/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,41 @@ export const formatArtifactContentWithTemplate = (
);
};

export const getModelNameAndProviderFromConfig = (
config: LangGraphRunnableConfig
): { modelName: string; modelProvider: string } => {
const customModelName = config.configurable?.customModelName as string;
if (!customModelName) {
throw new Error("Model name is missing in config.");
}
if (customModelName.includes("gpt-")) {
return {
modelName: customModelName,
modelProvider: "openai",
};
}
if (customModelName.includes("claude-")) {
return {
modelName: customModelName,
modelProvider: "anthropic",
};
}
if (customModelName.includes("fireworks/")) {
return {
modelName: customModelName,
modelProvider: "fireworks",
};
}
if (customModelName.includes("gemini-")) {
return {
modelName: customModelName,
modelProvider: "google-genai",
};
}

throw new Error("Unknown model provider");
};

export const getModelNameFromConfig = (
config: LangGraphRunnableConfig
): string => {
Expand Down
8 changes: 7 additions & 1 deletion src/components/ModelSelector.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,15 @@ import {
ANTHROPIC_MODELS,
OPENAI_MODELS,
FIREWORKS_MODELS,
GEMINI_MODELS,
} from "@/constants";

const allModels = [...ANTHROPIC_MODELS, ...OPENAI_MODELS, ...FIREWORKS_MODELS];
const allModels = [
...ANTHROPIC_MODELS,
...OPENAI_MODELS,
...FIREWORKS_MODELS,
...GEMINI_MODELS,
];

const modelNameToLabel = (modelName: ALL_MODEL_NAMES) => {
const model = allModels.find((m) => m.name === modelName);
Expand Down
13 changes: 11 additions & 2 deletions src/constants.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
export const LANGGRAPH_API_URL =
process.env.LANGGRAPH_API_URL ?? "http://localhost:49903";
process.env.LANGGRAPH_API_URL ?? "http://localhost:54790";
// v2 is tied to the 'open-canvas-prod' deployment.
export const ASSISTANT_ID_COOKIE = "oc_assistant_id_v2";
// export const ASSISTANT_ID_COOKIE = "oc_assistant_id";
Expand Down Expand Up @@ -43,11 +43,20 @@ export const FIREWORKS_MODELS = [
label: "Fireworks Llama 8B",
},
];

export const GEMINI_MODELS = [
{
name: "gemini-1.5-flash",
label: "Gemini 1.5 Flash",
},
];
export const DEFAULT_MODEL_NAME: ALL_MODEL_NAMES = "claude-3-haiku-20240307";
export type OPENAI_MODEL_NAMES = (typeof OPENAI_MODELS)[number]["name"];
export type ANTHROPIC_MODEL_NAMES = (typeof ANTHROPIC_MODELS)[number]["name"];
export type FIREWORKS_MODEL_NAMES = (typeof FIREWORKS_MODELS)[number]["name"];
export type GEMINI_MODEL_NAMES = (typeof GEMINI_MODELS)[number]["name"];
export type ALL_MODEL_NAMES =
| OPENAI_MODEL_NAMES
| ANTHROPIC_MODEL_NAMES
| FIREWORKS_MODEL_NAMES;
| FIREWORKS_MODEL_NAMES
| GEMINI_MODEL_NAMES;
15 changes: 13 additions & 2 deletions yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,11 @@
resolved "https://registry.yarnpkg.com/@floating-ui/utils/-/utils-0.2.8.tgz#21a907684723bbbaa5f0974cf7730bd797eb8e62"
integrity sha512-kym7SodPp8/wloecOpcmSnWJsK7M0E5Wg8UcFA+uO4B9s5d0ywXOEro/8HM9x0rW+TljRzul/14UYz3TleT3ig==

"@google/generative-ai@^0.7.0":
version "0.7.1"
resolved "https://registry.yarnpkg.com/@google/generative-ai/-/generative-ai-0.7.1.tgz#eb187c75080c0706245699dbc06816c830d8c6a7"
integrity sha512-WTjMLLYL/xfA5BW6xAycRPiAX7FNHKAxrid/ayqC1QMam0KAK0NbMeS9Lubw80gVg5xFMLE+H7pw4wdNzTOlxw==

"@hookform/resolvers@^3.6.0":
version "3.9.0"
resolved "https://registry.yarnpkg.com/@hookform/resolvers/-/resolvers-3.9.0.tgz#cf540ac21c6c0cd24a40cf53d8e6d64391fb753d"
Expand Down Expand Up @@ -705,6 +710,14 @@
zod "^3.22.4"
zod-to-json-schema "^3.22.3"

"@langchain/google-genai@^0.1.0":
version "0.1.0"
resolved "https://registry.yarnpkg.com/@langchain/google-genai/-/google-genai-0.1.0.tgz#89552873210d72a5834de20fcbef3e6753283344"
integrity sha512-6rIba77zJVMj+048tLfkCBrkFbfAMiT+AfLEsu5s+CFoFmXMiI/dbKeDL4vhUWrJVb9uL4ZZyrnl0nKxyEKYgA==
dependencies:
"@google/generative-ai" "^0.7.0"
zod-to-json-schema "^3.22.4"

"@langchain/langgraph-checkpoint@~0.0.10":
version "0.0.11"
resolved "https://registry.yarnpkg.com/@langchain/langgraph-checkpoint/-/langgraph-checkpoint-0.0.11.tgz#65c40bc175faca98ed0901df9e76682585710e8d"
Expand Down Expand Up @@ -7101,7 +7114,6 @@ streamsearch@^1.1.0:
integrity sha512-Mcc5wHehp9aXz1ax6bZUyY5afg9u2rv5cqQI3mRrYkGC8rW2hM02jWuwjtL++LS5qinSyhj2QfLyNsuc+VsExg==

"string-width-cjs@npm:string-width@^4.2.0", string-width@^4.1.0:
name string-width-cjs
version "4.2.3"
resolved "https://registry.yarnpkg.com/string-width/-/string-width-4.2.3.tgz#269c7117d27b05ad2e536830a8ec895ef9c6d010"
integrity sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g==
Expand Down Expand Up @@ -7190,7 +7202,6 @@ stringify-entities@^4.0.0:
character-entities-legacy "^3.0.0"

"strip-ansi-cjs@npm:strip-ansi@^6.0.1", strip-ansi@^6.0.0, strip-ansi@^6.0.1:
name strip-ansi-cjs
version "6.0.1"
resolved "https://registry.yarnpkg.com/strip-ansi/-/strip-ansi-6.0.1.tgz#9e26c63d30f53443e9489495b2105d37b67a85d9"
integrity sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==
Expand Down

0 comments on commit f829417

Please sign in to comment.