Skip to content

Commit

Permalink
community[minor]: Add support for bedrock guardrails and trace (#5631)
Browse files Browse the repository at this point in the history
* add support for bedrock guardrails and trace

* updated bedrock tests for guardrails

* resolved guardrail and trace format errors; also resolved fetchFn type errors

* updated chatbedrock test with guardrails

* Format and revert call option changes

* Lint

---------

Co-authored-by: jacoblee93 <jacoblee93@gmail.com>
  • Loading branch information
QuinnGT and jacoblee93 authored Jun 4, 2024
1 parent 294f600 commit 0744255
Show file tree
Hide file tree
Showing 4 changed files with 280 additions and 73 deletions.
121 changes: 112 additions & 9 deletions libs/langchain-community/src/chat_models/bedrock/web.ts
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,92 @@ export function convertMessagesToPrompt(
* Services (AWS). It uses AWS credentials for authentication and can be
* configured with various parameters such as the model to use, the AWS
* region, and the maximum number of tokens to generate.
*
* The `BedrockChat` class supports both synchronous and asynchronous interactions with the model,
* allowing for streaming responses and handling new token callbacks. It can be configured with
* optional parameters like temperature, stop sequences, and guardrail settings for enhanced control
* over the generated responses.
*
* @example
* ```typescript
* import { BedrockChat } from 'path-to-your-bedrock-chat-module';
* import { HumanMessage } from '@langchain/core/messages';
*
* async function run() {
* // Instantiate the BedrockChat model with the desired configuration
* const model = new BedrockChat({
* model: "anthropic.claude-v2",
* region: "us-east-1",
* credentials: {
* accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
* secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
* },
* maxTokens: 150,
* temperature: 0.7,
* stopSequences: ["\n", " Human:", " Assistant:"],
* streaming: false,
* trace: "ENABLED",
* guardrailIdentifier: "your-guardrail-id",
* guardrailVersion: "1.0",
* guardrailConfig: {
* tagSuffix: "example",
* streamProcessingMode: "SYNCHRONOUS",
* },
* });
*
* // Prepare the message to be sent to the model
* const message = new HumanMessage("Tell me a joke");
*
* // Invoke the model with the message
* const res = await model.invoke([message]);
*
* // Output the response from the model
* console.log(res);
* }
*
* run().catch(console.error);
* ```
*
* For streaming responses, use the following example:
* @example
* ```typescript
* import { BedrockChat } from 'path-to-your-bedrock-chat-module';
* import { HumanMessage } from '@langchain/core/messages';
*
* async function runStreaming() {
* // Instantiate the BedrockChat model with the desired configuration
* const model = new BedrockChat({
* model: "anthropic.claude-3-sonnet-20240229-v1:0",
* region: "us-east-1",
* credentials: {
* accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
* secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
* },
* maxTokens: 150,
* temperature: 0.7,
* stopSequences: ["\n", " Human:", " Assistant:"],
* streaming: true,
* trace: "ENABLED",
* guardrailIdentifier: "your-guardrail-id",
* guardrailVersion: "1.0",
* guardrailConfig: {
* tagSuffix: "example",
* streamProcessingMode: "SYNCHRONOUS",
* },
* });
*
* // Prepare the message to be sent to the model
* const message = new HumanMessage("Tell me a joke");
*
* // Stream the response from the model
* const stream = await model.stream([message]);
* for await (const chunk of stream) {
* // Output each chunk of the response
* console.log(chunk);
* }
* }
*
* runStreaming().catch(console.error);
* ```
*/
export class BedrockChat extends BaseChatModel implements BaseBedrockInput {
Expand Down Expand Up @@ -135,6 +213,17 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput {

lc_serializable = true;

trace?: "ENABLED" | "DISABLED";

guardrailIdentifier = "";

guardrailVersion = "";

guardrailConfig?: {
tagSuffix: string;
streamProcessingMode: "SYNCHRONOUS" | "ASYNCHRONOUS";
};

get lc_aliases(): Record<string, string> {
return {
model: "model_id",
Expand Down Expand Up @@ -209,11 +298,16 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput {
this.modelKwargs = fields?.modelKwargs;
this.streaming = fields?.streaming ?? this.streaming;
this.usesMessagesApi = canUseMessagesApi(this.model);
this.trace = fields?.trace ?? this.trace;
this.guardrailVersion = fields?.guardrailVersion ?? this.guardrailVersion;
this.guardrailIdentifier =
fields?.guardrailIdentifier ?? this.guardrailIdentifier;
this.guardrailConfig = fields?.guardrailConfig;
}

async _generate(
messages: BaseMessage[],
options: this["ParsedCallOptions"],
options: Partial<BaseChatModelParams>,
runManager?: CallbackManagerForLLMRun
): Promise<ChatResult> {
const service = "bedrock-runtime";
Expand Down Expand Up @@ -285,7 +379,8 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput {
this.maxTokens,
this.temperature,
options.stop ?? this.stopSequences,
this.modelKwargs
this.modelKwargs,
this.guardrailConfig
)
: BedrockLLMInputOutputAdapter.prepareInput(
provider,
Expand All @@ -294,7 +389,8 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput {
this.temperature,
options.stop ?? this.stopSequences,
this.modelKwargs,
fields.bedrockMethod
fields.bedrockMethod,
this.guardrailConfig
);

const url = new URL(
Expand All @@ -313,6 +409,13 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput {
host: url.host,
accept: "application/json",
"content-type": "application/json",
...(this.trace &&
this.guardrailIdentifier &&
this.guardrailVersion && {
"X-Amzn-Bedrock-Trace": this.trace,
"X-Amzn-Bedrock-GuardrailIdentifier": this.guardrailIdentifier,
"X-Amzn-Bedrock-GuardrailVersion": this.guardrailVersion,
}),
},
});

Expand Down
Loading

0 comments on commit 0744255

Please sign in to comment.