Skip to content

Commit

Permalink
Add Amazon Bedrock LLM endpoint integration (#1)
Browse files Browse the repository at this point in the history
A contribution from Prompt Security

* Kept it as close to the Python implementation as possible
* Followed the guidelines from https://github.com/hwchase17/langchainjs/blob/main/CONTRIBUTING.md and https://github.com/hwchase17/langchainjs/blob/main/.github/contributing/INTEGRATIONS.md
* Supplied with unit test coverage
* Added documentation
  • Loading branch information
vitaly-ps authored Aug 8, 2023
1 parent a256777 commit 59cbe71
Show file tree
Hide file tree
Showing 12 changed files with 1,228 additions and 0 deletions.
16 changes: 16 additions & 0 deletions docs/extras/modules/model_io/models/llms/integrations/bedrock.mdx
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Bedrock

>[Amazon Bedrock](https://aws.amazon.com/bedrock/) is a fully managed service that makes FMs from leading AI startups and Amazon available via an API, so you can choose from a wide range of FMs to find the model that is best suited for your use case.

## Installation and Setup

Install the underlying library from AWS
```bash npm2yarn
npm install aws-sigv4-fetch
```

## LLM example usage

```typescript
import CodeBlock from "@theme/CodeBlock";
import BedrockExample from "@examples/models/llm/bedrock.ts";
```
8 changes: 8 additions & 0 deletions examples/src/llms/bedrock.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import { Bedrock } from "langchain/llms/bedrock";

/**
 * Minimal example: invoke a Bedrock-hosted model with a single prompt.
 *
 * Requires AWS credentials resolvable by the standard AWS credential chain
 * and permission to call the Bedrock service in the chosen region.
 */
async function test() {
  const model = new Bedrock({
    model: "bedrock-model-name",
    regionName: "aws-region",
  });
  // Classic LangChain sample prompt. (Fixed typo: "name a company" ->
  // "name for a company".)
  const res = await model.call(
    "Question: What would be a good company name for a company that makes colorful socks?\nAnswer:"
  );
  console.log(res);
}

// Surface failures instead of leaving an unhandled promise rejection.
test().catch(console.error);
1 change: 1 addition & 0 deletions examples/src/models/llm/bedrock.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
// TODO: add example content here
3 changes: 3 additions & 0 deletions langchain/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@ llms/googlepalm.d.ts
llms/sagemaker_endpoint.cjs
llms/sagemaker_endpoint.js
llms/sagemaker_endpoint.d.ts
llms/bedrock.cjs
llms/bedrock.js
llms/bedrock.d.ts
prompts.cjs
prompts.js
prompts.d.ts
Expand Down
13 changes: 13 additions & 0 deletions langchain/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,9 @@
"llms/sagemaker_endpoint.cjs",
"llms/sagemaker_endpoint.js",
"llms/sagemaker_endpoint.d.ts",
"llms/bedrock.cjs",
"llms/bedrock.js",
"llms/bedrock.d.ts",
"prompts.cjs",
"prompts.js",
"prompts.d.ts",
Expand Down Expand Up @@ -583,6 +586,7 @@
"@xata.io/client": "^0.25.1",
"@zilliz/milvus2-sdk-node": ">=2.2.7",
"apify-client": "^2.7.1",
"aws-sigv4-fetch": "^2.1.1",
"axios": "^0.26.0",
"cheerio": "^1.0.0-rc.12",
"chromadb": "^1.5.3",
Expand Down Expand Up @@ -665,6 +669,7 @@
"@xata.io/client": "^0.25.1",
"@zilliz/milvus2-sdk-node": ">=2.2.7",
"apify-client": "^2.7.1",
"aws-sigv4-fetch": "^2.1.1",
"axios": "*",
"cheerio": "^1.0.0-rc.12",
"chromadb": "^1.5.3",
Expand Down Expand Up @@ -789,6 +794,9 @@
"apify-client": {
"optional": true
},
"aws-sigv4-fetch": {
"optional": true
},
"axios": {
"optional": true
},
Expand Down Expand Up @@ -1141,6 +1149,11 @@
"import": "./llms/sagemaker_endpoint.js",
"require": "./llms/sagemaker_endpoint.cjs"
},
"./llms/bedrock": {
"types": "./llms/bedrock.d.ts",
"import": "./llms/bedrock.js",
"require": "./llms/bedrock.cjs"
},
"./prompts": {
"types": "./prompts.d.ts",
"import": "./prompts.js",
Expand Down
2 changes: 2 additions & 0 deletions langchain/scripts/create-entrypoints.js
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ const entrypoints = {
"llms/googlevertexai": "llms/googlevertexai",
"llms/googlepalm": "llms/googlepalm",
"llms/sagemaker_endpoint": "llms/sagemaker_endpoint",
"llms/bedrock": "llms/bedrock",
// prompts
prompts: "prompts/index",
"prompts/load": "prompts/load",
Expand Down Expand Up @@ -243,6 +244,7 @@ const requiresOptionalDependency = [
"llms/hf",
"llms/replicate",
"llms/sagemaker_endpoint",
"llms/bedrock",
"prompts/load",
"vectorstores/analyticdb",
"vectorstores/chroma",
Expand Down
154 changes: 154 additions & 0 deletions langchain/src/llms/bedrock.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
import { getEnvironmentVariable } from "../util/env.js";
import { LLM, BaseLLMParams } from "./base.js";

// Loose JSON-object shape for provider request/response bodies.
// (`Record<string, unknown>` instead of an `any`-valued index signature.)
type Dict = Record<string, unknown>;

/**
 * Adapter that maps a LangChain prompt to the provider-specific request body
 * expected by Bedrock's `InvokeModel` API, and extracts the generated text
 * from the provider-specific response body.
 */
class LLMInputOutputAdapter {
  /**
   * Build the request body for the given provider prefix
   * (`anthropic` | `ai21` | `amazon` | anything else).
   *
   * @param provider first segment of the model id, e.g. "anthropic".
   * @param prompt the text prompt to send.
   * @returns the provider-shaped request body.
   */
  static prepareInput(provider: string, prompt: string): Dict {
    switch (provider) {
      case "anthropic":
        // Anthropic requires max_tokens_to_sample; default to 50 as before.
        // (The original guarded this with an `in` check on a freshly-built
        // object, which was always true — the guard was dead code.)
        return { prompt, max_tokens_to_sample: 50 };
      case "ai21":
        return { prompt };
      case "amazon":
        return { inputText: prompt, textGenerationConfig: {} };
      default:
        // Unknown providers fall back to the Titan-style field name.
        return { inputText: prompt };
    }
  }

  /**
   * Pull the generated text out of a parsed response body.
   *
   * @param provider first segment of the model id.
   * @param responseBody parsed JSON response from Bedrock (shape varies per
   *   provider, hence `any`).
   * @returns the completion text.
   */
  static prepareOutput(provider: string, responseBody: any): string {
    if (provider === "anthropic") {
      return responseBody.completion;
    }
    if (provider === "ai21") {
      return responseBody.completions[0].data.text;
    }
    // "amazon" and unknown providers share the Titan response shape.
    return responseBody.results[0].outputText;
  }
}

/**
 * Input options for the Bedrock LLM wrapper.
 *
 * Authentication is handled by the underlying signed-fetch client, which
 * resolves AWS credentials automatically (environment variables, shared
 * credential profiles, instance roles, etc.).
 * NOTE(review): the linked guide below is the Python boto3 credential guide;
 * confirm the exact resolution order used by the JS signing client:
 * https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
 * If a specific credential profile should be used, you must pass the name of
 * the profile from the ~/.aws/credentials file that is to be used.
 * Make sure the credentials / roles used have the required policies to access
 * the Bedrock service.
 */
export interface BedrockInput {
  /**
   * Model to use.
   * For example, "amazon.titan-tg1-large" — equivalent to the `modelId`
   * property in the ListFoundationModels API.
   */
  model: string;

  /**
   * The AWS region, e.g. `us-west-2`.
   * Falls back to the AWS_DEFAULT_REGION environment variable when not
   * provided here.
   * NOTE(review): despite the original comment, no ~/.aws/config fallback is
   * implemented in the visible code — only the env variable is read.
   */
  regionName?: string;

  /** Sampling temperature to use. */
  temperature?: number;

  /** Maximum number of tokens to generate. */
  maxTokens?: number;
}

/**
 * LLM wrapper for models hosted on Amazon Bedrock, invoked via SigV4-signed
 * HTTPS calls to the Bedrock `InvokeModel` endpoint.
 */
export class Bedrock extends LLM implements BedrockInput {
  model = "amazon.titan-tg1-large";

  regionName?: string | undefined = undefined;

  temperature?: number | undefined = undefined;

  maxTokens?: number | undefined = undefined;

  get lc_secrets(): { [key: string]: string } | undefined {
    // Credentials come from the AWS credential chain; nothing is stored here.
    return {};
  }

  _llmType() {
    return "bedrock";
  }

  constructor(fields?: Partial<BedrockInput> & BaseLLMParams) {
    super(fields ?? {});

    this.model = fields?.model ?? this.model;
    // Only these provider prefixes have known request/response shapes.
    const allowedModels = ["ai21", "anthropic", "amazon"];
    if (!allowedModels.includes(this.model.split(".")[0])) {
      throw new Error(
        `Unknown model: '${this.model}', only these are supported: ${allowedModels}`
      );
    }
    this.regionName =
      fields?.regionName ?? getEnvironmentVariable("AWS_DEFAULT_REGION");
    this.temperature = fields?.temperature ?? this.temperature;
    this.maxTokens = fields?.maxTokens ?? this.maxTokens;
  }

  /**
   * Call out to the Bedrock service model.
   *
   * @param prompt The prompt to pass into the model.
   * @returns The string generated by the model.
   * @throws If no AWS region is configured, or the HTTP call fails.
   *
   * Example:
   *   const response = await model.call("Tell me a joke.");
   */
  async _call(prompt: string): Promise<string> {
    // Bug fix: fail fast with a clear message instead of building a URL
    // containing the literal string "undefined".
    if (!this.regionName) {
      throw new Error(
        "Please set the AWS region, either via the `regionName` field or the AWS_DEFAULT_REGION environment variable."
      );
    }

    const { createSignedFetcher } = await Bedrock.imports();

    const signedFetcher = createSignedFetcher({
      service: "bedrock",
      region: this.regionName,
    });

    const url = `https://bedrock.${this.regionName}.amazonaws.com/model/${this.model}/invoke`;
    const provider = this.model.split(".")[0];
    const inputBody = LLMInputOutputAdapter.prepareInput(provider, prompt);

    // Bug fix: `temperature` and `maxTokens` were accepted by the
    // constructor but never forwarded to the service. Forward them using
    // each provider's field names (see Bedrock inference-parameter docs).
    if (this.temperature !== undefined || this.maxTokens !== undefined) {
      if (provider === "anthropic") {
        if (this.temperature !== undefined)
          inputBody.temperature = this.temperature;
        if (this.maxTokens !== undefined)
          inputBody.max_tokens_to_sample = this.maxTokens;
      } else if (provider === "ai21") {
        if (this.temperature !== undefined)
          inputBody.temperature = this.temperature;
        if (this.maxTokens !== undefined) inputBody.maxTokens = this.maxTokens;
      } else if (provider === "amazon") {
        inputBody.textGenerationConfig = {
          ...(this.temperature !== undefined && {
            temperature: this.temperature,
          }),
          ...(this.maxTokens !== undefined && {
            maxTokenCount: this.maxTokens,
          }),
        };
      }
    }

    const response = await this.caller.call(
      async () =>
        await signedFetcher(url, {
          method: "post",
          body: JSON.stringify(inputBody),
          headers: {
            "Content-Type": "application/json",
            accept: "application/json",
          },
        })
    );

    // `ok` is true exactly when status is in the 2xx range.
    if (!response.ok) {
      throw Error(
        `Failed to access underlying url '${url}': got ${response.status} ${
          response.statusText
        }: ${await response.text()}`
      );
    }
    const responseJson = await response.json();
    return LLMInputOutputAdapter.prepareOutput(provider, responseJson);
  }

  /** @ignore Lazily import the optional signing dependency. */
  static async imports(): Promise<{ createSignedFetcher: any }> {
    try {
      const { createSignedFetcher } = await import("aws-sigv4-fetch");
      return { createSignedFetcher };
    } catch (e) {
      throw new Error(
        "Please install a dependency for bedrock with, e.g. `yarn add aws-sigv4-fetch`"
      );
    }
  }
}
Loading

0 comments on commit 59cbe71

Please sign in to comment.