Skip to content

Commit

Permalink
fix(community): Add support for Bedrock cross-region inference models (
Browse files Browse the repository at this point in the history
…#6682)

Co-authored-by: jacoblee93 <jacoblee93@gmail.com>
  • Loading branch information
keremnalbant and jacoblee93 authored Sep 4, 2024
1 parent bc968ad commit 29c5b8c
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 26 deletions.
60 changes: 46 additions & 14 deletions libs/langchain-community/src/chat_models/bedrock/web.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,28 @@ type AnthropicTool = Record<string, unknown>;

type BedrockChatToolType = BindToolsInput | AnthropicTool;

// Geographic prefixes that may appear as the first dot-separated segment of
// a Bedrock cross-region inference profile model id (e.g. "eu.anthropic...").
// Used only for membership checks when extracting the provider segment.
const AWS_REGIONS = ["us", "sa", "me", "il", "eu", "cn", "ca", "ap", "af", "us-gov"];

// Model providers this Bedrock integration knows how to format requests for.
// NOTE: this array is interpolated directly into the constructor's error
// message, so element order is user-visible — keep it stable.
const ALLOWED_MODEL_PROVIDERS = ["ai21", "anthropic", "amazon", "cohere", "meta", "mistral"];

const PRELUDE_TOTAL_LENGTH_BYTES = 4;

function convertOneMessageToText(
Expand Down Expand Up @@ -473,6 +495,8 @@ export class BedrockChat
{
model = "amazon.titan-tg1-large";

modelProvider: string;

region: string;

credentials: CredentialType;
Expand Down Expand Up @@ -545,17 +569,11 @@ export class BedrockChat
super(fields ?? {});

this.model = fields?.model ?? this.model;
const allowedModels = [
"ai21",
"anthropic",
"amazon",
"cohere",
"meta",
"mistral",
];
if (!allowedModels.includes(this.model.split(".")[0])) {
this.modelProvider = getModelProvider(this.model);

if (!ALLOWED_MODEL_PROVIDERS.includes(this.modelProvider)) {
throw new Error(
`Unknown model: '${this.model}', only these are supported: ${allowedModels}`
`Unknown model provider: '${this.modelProvider}', only these are supported: ${ALLOWED_MODEL_PROVIDERS}`
);
}
const region =
Expand Down Expand Up @@ -655,7 +673,7 @@ export class BedrockChat
const service = "bedrock-runtime";
const endpointHost =
this.endpointHost ?? `${service}.${this.region}.amazonaws.com`;
const provider = this.model.split(".")[0];
const provider = this.modelProvider;
const response = await this._signedFetch(messages, options, {
bedrockMethod: "invoke",
endpointHost,
Expand Down Expand Up @@ -776,7 +794,7 @@ export class BedrockChat
options: this["ParsedCallOptions"],
runManager?: CallbackManagerForLLMRun
): AsyncGenerator<ChatGenerationChunk> {
const provider = this.model.split(".")[0];
const provider = this.modelProvider;
const service = "bedrock-runtime";

const endpointHost =
Expand Down Expand Up @@ -956,7 +974,7 @@ export class BedrockChat
BaseMessageChunk,
this["ParsedCallOptions"]
> {
const provider = this.model.split(".")[0];
const provider = this.modelProvider;
if (provider !== "anthropic") {
throw new Error(
"Currently, tool calling through Bedrock is only supported for Anthropic models."
Expand All @@ -977,7 +995,7 @@ function isChatGenerationChunk(
}

function canUseMessagesApi(model: string): boolean {
const modelProviderName = model.split(".")[0];
const modelProviderName = getModelProvider(model);

if (
modelProviderName === "anthropic" &&
Expand All @@ -999,6 +1017,20 @@ function canUseMessagesApi(model: string): boolean {
return false;
}

/**
 * Whether `modelId` is a cross-region inference profile id, i.e. its first
 * dot-separated segment is one of the known geographic prefixes in
 * {@link AWS_REGIONS} (e.g. "eu.anthropic.claude-3-5-sonnet-...").
 *
 * @param modelId - The Bedrock model identifier to inspect.
 * @returns True when the id carries a leading geo prefix.
 */
function isInferenceModel(modelId: string): boolean {
  // `includes` replaces the hand-rolled `.some((r) => parts[0] === r)`
  // membership test; only the first segment is relevant.
  return AWS_REGIONS.includes(modelId.split(".")[0]);
}

/**
 * Extract the provider segment ("anthropic", "amazon", ...) from a Bedrock
 * model id, skipping over a leading geographic prefix when the id is a
 * cross-region inference profile id.
 *
 * @param modelId - The Bedrock model identifier.
 * @returns The provider portion of the id.
 */
function getModelProvider(modelId: string): string {
  const [head, afterPrefix] = modelId.split(".");
  // Inference profile ids look like "<geo>.<provider>....", plain model
  // ids like "<provider>....".
  return isInferenceModel(modelId) ? afterPrefix : head;
}

/**
* @deprecated Use `BedrockChat` instead.
*/
Expand Down
56 changes: 44 additions & 12 deletions libs/langchain-community/src/llms/bedrock/web.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,28 @@ import {
} from "../../utils/bedrock/index.js";
import type { SerializedFields } from "../../load/map_keys.js";

// Geographic prefixes that may appear as the first dot-separated segment of
// a Bedrock cross-region inference profile model id (e.g. "eu.anthropic...").
// Used only for membership checks when extracting the provider segment.
const AWS_REGIONS = ["us", "sa", "me", "il", "eu", "cn", "ca", "ap", "af", "us-gov"];

// Model providers this Bedrock integration knows how to format requests for.
// NOTE: this array is interpolated directly into the constructor's error
// message, so element order is user-visible — keep it stable.
const ALLOWED_MODEL_PROVIDERS = ["ai21", "anthropic", "amazon", "cohere", "meta", "mistral"];

const PRELUDE_TOTAL_LENGTH_BYTES = 4;

/**
Expand All @@ -31,6 +53,8 @@ const PRELUDE_TOTAL_LENGTH_BYTES = 4;
export class Bedrock extends LLM implements BaseBedrockInput {
model = "amazon.titan-tg1-large";

modelProvider: string;

region: string;

credentials: CredentialType;
Expand Down Expand Up @@ -84,17 +108,11 @@ export class Bedrock extends LLM implements BaseBedrockInput {
super(fields ?? {});

this.model = fields?.model ?? this.model;
const allowedModels = [
"ai21",
"anthropic",
"amazon",
"cohere",
"meta",
"mistral",
];
if (!allowedModels.includes(this.model.split(".")[0])) {
this.modelProvider = getModelProvider(this.model);

if (!ALLOWED_MODEL_PROVIDERS.includes(this.modelProvider)) {
throw new Error(
`Unknown model: '${this.model}', only these are supported: ${allowedModels}`
`Unknown model provider: '${this.modelProvider}', only these are supported: ${ALLOWED_MODEL_PROVIDERS}`
);
}
const region =
Expand Down Expand Up @@ -141,7 +159,7 @@ export class Bedrock extends LLM implements BaseBedrockInput {
const service = "bedrock-runtime";
const endpointHost =
this.endpointHost ?? `${service}.${this.region}.amazonaws.com`;
const provider = this.model.split(".")[0];
const provider = this.modelProvider;
if (this.streaming) {
const stream = this._streamResponseChunks(prompt, options, runManager);
let finalResult: GenerationChunk | undefined;
Expand Down Expand Up @@ -246,7 +264,7 @@ export class Bedrock extends LLM implements BaseBedrockInput {
options: this["ParsedCallOptions"],
runManager?: CallbackManagerForLLMRun
): AsyncGenerator<GenerationChunk> {
const provider = this.model.split(".")[0];
const provider = this.modelProvider;
const bedrockMethod =
provider === "anthropic" ||
provider === "cohere" ||
Expand Down Expand Up @@ -371,3 +389,17 @@ export class Bedrock extends LLM implements BaseBedrockInput {
};
}
}

/**
 * Whether `modelId` is a cross-region inference profile id, i.e. its first
 * dot-separated segment is one of the known geographic prefixes in
 * {@link AWS_REGIONS} (e.g. "eu.anthropic.claude-3-5-sonnet-...").
 *
 * @param modelId - The Bedrock model identifier to inspect.
 * @returns True when the id carries a leading geo prefix.
 */
function isInferenceModel(modelId: string): boolean {
  // `includes` replaces the hand-rolled `.some((r) => parts[0] === r)`
  // membership test; only the first segment is relevant.
  return AWS_REGIONS.includes(modelId.split(".")[0]);
}

/**
 * Extract the provider segment ("anthropic", "amazon", ...) from a Bedrock
 * model id, skipping over a leading geographic prefix when the id is a
 * cross-region inference profile id.
 *
 * @param modelId - The Bedrock model identifier.
 * @returns The provider portion of the id.
 */
function getModelProvider(modelId: string): string {
  const [head, afterPrefix] = modelId.split(".");
  // Inference profile ids look like "<geo>.<provider>....", plain model
  // ids like "<provider>....".
  return isInferenceModel(modelId) ? afterPrefix : head;
}
22 changes: 22 additions & 0 deletions libs/langchain-community/src/llms/tests/bedrock.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -180,3 +180,25 @@ test("Test Bedrock LLM streaming: Claude-v2", async () => {
}
expect(chunks.length).toBeGreaterThan(1);
});

test("Test Bedrock LLM: Inference Models", async () => {
const region = process.env.BEDROCK_AWS_REGION!;
const model = "eu.anthropic.claude-3-5-sonnet-20240620-v1:0";
const prompt = "Human: What is your name?\n\nAssistant:";

const bedrock = new Bedrock({
maxTokens: 20,
region,
model,
maxRetries: 0,
credentials: {
accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
sessionToken: process.env.BEDROCK_AWS_SESSION_TOKEN,
},
});

const res = await bedrock.invoke(prompt);
expect(typeof res).toBe("string");
// console.log(res);
});

0 comments on commit 29c5b8c

Please sign in to comment.