Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
2688848
Update deps for req (>= 0.5.1 is needed)
stevehodgkiss Jul 6, 2024
f14bb80
Bedrock support for Anthropic
stevehodgkiss Jul 6, 2024
6a4307b
Extract BedrockStreamDecoder from ChatAnthropic
stevehodgkiss Jul 6, 2024
d40eb76
Move function
stevehodgkiss Jul 6, 2024
bb4069c
Use same relevant_event? function as anthropic api
stevehodgkiss Jul 6, 2024
d292e2d
Extract BedrockConfig
stevehodgkiss Jul 6, 2024
184aec4
Extract module var for aws anthropic version
stevehodgkiss Jul 6, 2024
a670f1d
Merge remote-tracking branch 'origin/main' into add-bedrock-support
stevehodgkiss Jul 7, 2024
e58b654
Use same tests as anthropic on bedrock
stevehodgkiss Jul 7, 2024
a2b3c2f
Move config to setup
stevehodgkiss Jul 7, 2024
2a7b3ae
Move anthropic_version to bedrock config
stevehodgkiss Jul 7, 2024
e61330a
Rename function
stevehodgkiss Jul 7, 2024
4365e01
Pull bedrock url functions to BedrockConfig
stevehodgkiss Jul 7, 2024
25a0352
Consistent tag name for anthropic_bedrock
stevehodgkiss Jul 7, 2024
8a192b8
Pass through case where chunk.bytes isn't present in stream
stevehodgkiss Jul 7, 2024
fe6b18c
Pull aws_sigv4_opts into BedrockConfig
stevehodgkiss Jul 7, 2024
945ca40
Improve pattern matching on bedrock config
stevehodgkiss Jul 7, 2024
82ca04f
Handle bedrock http error messages
stevehodgkiss Jul 7, 2024
8fdd70b
Use Mimic & add tests around error cases
stevehodgkiss Jul 7, 2024
a3230cf
Merge remote-tracking branch 'origin/main' into add-bedrock-support
stevehodgkiss Jul 18, 2024
57d5704
Update req
stevehodgkiss Jul 18, 2024
4782ac9
Require latest req for aws path encoding fix + session token support
stevehodgkiss Jul 18, 2024
fb6bd81
Support session token if returned from credentials fn
stevehodgkiss Jul 18, 2024
761b18b
Simplify - pass a keyword list through to req sigv4 opts instead of w…
stevehodgkiss Jul 18, 2024
7dd76bc
Merge remote-tracking branch 'origin/main' into add-bedrock-support
stevehodgkiss Oct 24, 2024
4e43534
Update tests
stevehodgkiss Oct 24, 2024
4f8eb84
Remove duplicate tests
stevehodgkiss Oct 24, 2024
a4a0d8c
Add stub aws creds to github workflow
stevehodgkiss Oct 24, 2024
72cf1b0
Move commented out test back
stevehodgkiss Oct 24, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/elixir.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ env:
OPENAI_API_KEY: invalid
ANTHROPIC_API_KEY: invalid
GOOGLE_API_KEY: invalid
AWS_ACCESS_KEY_ID: invalid
AWS_SECRET_ACCESS_KEY: invalid

permissions:
contents: read
Expand Down
93 changes: 77 additions & 16 deletions lib/chat_models/chat_anthropic.ex
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ defmodule LangChain.ChatModels.ChatAnthropic do
alias LangChain.FunctionParam
alias LangChain.Utils
alias LangChain.Callbacks
alias LangChain.Utils.BedrockStreamDecoder
alias LangChain.Utils.BedrockConfig

@behaviour ChatModel

Expand All @@ -67,6 +69,9 @@ defmodule LangChain.ChatModels.ChatAnthropic do
# API endpoint to use. Defaults to Anthropic's API
field :endpoint, :string, default: "https://api.anthropic.com/v1/messages"

# Configuration for AWS Bedrock. Configure this instead of endpoint & api_key if you want to use Bedrock.
embeds_one :bedrock, BedrockConfig

# API key for Anthropic. If not set, will use global api key. Allows for usage
# of a different API key per-call if desired. For instance, allowing a
# customer to provide their own.
Expand Down Expand Up @@ -131,19 +136,14 @@ defmodule LangChain.ChatModels.ChatAnthropic do
]
@required_fields [:endpoint, :model]

@spec get_api_key(t()) :: String.t()
defp get_api_key(%ChatAnthropic{api_key: api_key}) do
# if no API key is set default to `""` which will raise an error
api_key || Config.resolve(:anthropic_key, "")
end

@doc """
Setup a ChatAnthropic client configuration.
"""
@spec new(attrs :: map()) :: {:ok, t} | {:error, Ecto.Changeset.t()}
def new(%{} = attrs \\ %{}) do
%ChatAnthropic{}
|> cast(attrs, @create_fields)
|> cast_embed(:bedrock)
|> common_validation()
|> apply_action(:insert)
end
Expand Down Expand Up @@ -204,6 +204,15 @@ defmodule LangChain.ChatModels.ChatAnthropic do
|> Utils.conditionally_add_to_map(:max_tokens, anthropic.max_tokens)
|> Utils.conditionally_add_to_map(:top_p, anthropic.top_p)
|> Utils.conditionally_add_to_map(:top_k, anthropic.top_k)
|> maybe_transform_for_bedrock(anthropic.bedrock)
end

defp maybe_transform_for_bedrock(body, nil), do: body

defp maybe_transform_for_bedrock(body, %BedrockConfig{} = bedrock) do
body
|> Map.put(:anthropic_version, bedrock.anthropic_version)
|> Map.drop([:model, :stream])
end

defp get_tools_for_api(nil), do: []
Expand Down Expand Up @@ -287,13 +296,14 @@ defmodule LangChain.ChatModels.ChatAnthropic do
) do
req =
Req.new(
url: anthropic.endpoint,
url: url(anthropic),
json: for_api(anthropic, messages, tools),
headers: headers(get_api_key(anthropic), anthropic.api_version),
headers: headers(anthropic),
receive_timeout: anthropic.receive_timeout,
retry: :transient,
max_retries: 3,
retry_delay: fn attempt -> 300 * attempt end
retry_delay: fn attempt -> 300 * attempt end,
aws_sigv4: aws_sigv4_opts(anthropic.bedrock)
)

req
Expand Down Expand Up @@ -341,14 +351,19 @@ defmodule LangChain.ChatModels.ChatAnthropic do
retry_count
) do
Req.new(
url: anthropic.endpoint,
url: url(anthropic),
json: for_api(anthropic, messages, tools),
headers: headers(get_api_key(anthropic), anthropic.api_version),
receive_timeout: anthropic.receive_timeout
headers: headers(anthropic),
receive_timeout: anthropic.receive_timeout,
aws_sigv4: aws_sigv4_opts(anthropic.bedrock)
)
|> Req.post(
into:
Utils.handle_stream_fn(anthropic, &decode_stream/1, &do_process_response(anthropic, &1))
Utils.handle_stream_fn(
anthropic,
&decode_stream(anthropic, &1),
&do_process_response(anthropic, &1)
)
)
|> case do
{:ok, %Req.Response{body: data} = response} ->
Expand Down Expand Up @@ -379,16 +394,40 @@ defmodule LangChain.ChatModels.ChatAnthropic do
end
end

defp headers(api_key, api_version) do
defp aws_sigv4_opts(nil), do: nil
defp aws_sigv4_opts(%BedrockConfig{} = bedrock), do: BedrockConfig.aws_sigv4_opts(bedrock)

@spec get_api_key(binary() | nil) :: String.t()
defp get_api_key(api_key) do
# if no API key is set default to `""` which will raise an error
api_key || Config.resolve(:anthropic_key, "")
end

defp headers(%ChatAnthropic{bedrock: nil, api_key: api_key, api_version: api_version}) do
%{
"x-api-key" => api_key,
"x-api-key" => get_api_key(api_key),
"content-type" => "application/json",
"anthropic-version" => api_version,
# https://docs.anthropic.com/claude/docs/tool-use - requires this header during beta
"anthropic-beta" => "tools-2024-04-04"
}
end

defp headers(%ChatAnthropic{bedrock: %BedrockConfig{}}) do
%{
"content-type" => "application/json",
"accept" => "application/json"
}
end

defp url(%ChatAnthropic{bedrock: nil} = anthropic) do
anthropic.endpoint
end

defp url(%ChatAnthropic{bedrock: %BedrockConfig{} = bedrock, stream: stream} = anthropic) do
BedrockConfig.url(bedrock, model: anthropic.model, stream: stream)
end

# Parse a new message response
@doc false
@spec do_process_response(t(), data :: %{String.t() => any()} | {:error, any()}) ::
Expand Down Expand Up @@ -513,6 +552,16 @@ defmodule LangChain.ChatModels.ChatAnthropic do
{:error, error_message}
end

def do_process_response(%ChatAnthropic{bedrock: %BedrockConfig{}}, %{"message" => message}) do
{:error, "Received error from API: #{message}"}
end

def do_process_response(%ChatAnthropic{bedrock: %BedrockConfig{}}, %{
bedrock_exception: exceptions
}) do
{:error, "Stream exception received: #{inspect(exceptions)}"}
end

def do_process_response(_model, other) do
Logger.error("Trying to process an unexpected response. #{inspect(other)}")
{:error, "Unexpected response"}
Expand Down Expand Up @@ -583,7 +632,7 @@ defmodule LangChain.ChatModels.ChatAnthropic do
end

@doc false
def decode_stream({chunk, buffer}) do
def decode_stream(%ChatAnthropic{bedrock: nil}, {chunk, buffer}) do
# Combine the incoming data with the buffered incomplete data
combined_data = buffer <> chunk
# Split data by double newline to find complete messages
Expand Down Expand Up @@ -651,6 +700,18 @@ defmodule LangChain.ChatModels.ChatAnthropic do
# assumed the response is JSON. Return as-is
defp extract_data(json), do: json

@doc false
def decode_stream(%ChatAnthropic{bedrock: %BedrockConfig{}}, {chunk, buffer}, chunks \\ []) do
{chunks, remaining} = BedrockStreamDecoder.decode_stream({chunk, buffer}, chunks)

chunks =
Enum.filter(chunks, fn chunk ->
Map.has_key?(chunk, :bedrock_exception) || relevant_event?("event: #{chunk["type"]}\n")
end)

{chunks, remaining}
end

@doc """
Convert a LangChain structure to the expected map of data for the OpenAI API.
"""
Expand Down
43 changes: 43 additions & 0 deletions lib/utils/aws_eventstream_decoder.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
defmodule LangChain.Utils.AwsEventstreamDecoder do
@moduledoc """
Decodes AWS messages in the application/vnd.amazon.eventstream content-type.
Ignores the headers because on Bedrock it's the same content type, event type & message type headers in every message.
"""

def decode(<<
message_length::32,
headers_length::32,
prelude_checksum::32,
headers::binary-size(headers_length),
body::binary-size(message_length - headers_length - 16),
message_checksum::32,
rest::bitstring
>>) do
message_without_checksum =
<<message_length::32, headers_length::32, prelude_checksum::32,
headers::binary-size(headers_length),
body::binary-size(message_length - headers_length - 16)>>

with :ok <-
verify_checksum(<<message_length::32, headers_length::32>>, prelude_checksum, :prelude),
:ok <- verify_checksum(message_without_checksum, message_checksum, :message) do
{:ok, body, rest}
end
end

def decode(<<message_length::32, _message::bitstring>> = data) do
{:incomplete_message, "Expected message length #{message_length} but got #{byte_size(data)}"}
end

def decode(_) do
{:error, "Unable to decode message"}
end

defp verify_checksum(data, checksum, part) do
if :erlang.crc32(data) == checksum do
:ok
else
{:error, "Checksum mismatch for #{part}"}
end
end
end
36 changes: 36 additions & 0 deletions lib/utils/bedrock_config.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
defmodule LangChain.Utils.BedrockConfig do
@moduledoc """
Configuration for AWS Bedrock.
"""
use Ecto.Schema
import Ecto.Changeset

@primary_key false
embedded_schema do
# A function that returns a keyword list including access_key_id, secret_access_key, and optionally token.
# Used to configure Req's aws_sigv4 option.
field :credentials, :any, virtual: true
field :region, :string
field :anthropic_version, :string, default: "bedrock-2023-05-31"
end

def changeset(bedrock, attrs) do
bedrock
|> cast(attrs, [:credentials, :region, :anthropic_version])
|> validate_required([:credentials, :region, :anthropic_version])
end

def aws_sigv4_opts(%__MODULE__{} = bedrock) do
Keyword.merge(bedrock.credentials.(),
region: bedrock.region,
service: :bedrock
)
end

def url(%__MODULE__{region: region}, model: model, stream: stream) do
"https://bedrock-runtime.#{region}.amazonaws.com/model/#{model}/#{action(stream: stream)}"
end

defp action(stream: true), do: "invoke-with-response-stream"
defp action(stream: false), do: "invoke"
end
77 changes: 77 additions & 0 deletions lib/utils/bedrock_stream_decoder.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
defmodule LangChain.Utils.BedrockStreamDecoder do
alias LangChain.Utils.AwsEventstreamDecoder
require Logger

def decode_stream({chunk, buffer}, chunks \\ []) do
combined_data = buffer <> chunk

case decode_chunk(combined_data) do
{:ok, chunk, remaining} ->
chunks = [chunk | chunks]
finish_or_decode_remaining(chunks, remaining)

{:incomplete_message, _} ->
{chunks, combined_data}

{:exception_response, response, remaining} ->
chunks = [response | chunks]
finish_or_decode_remaining(chunks, remaining)

{:error, error} ->
Logger.error("Failed to decode Bedrock chunk: #{inspect(error)}")
{chunks, combined_data}
end
end

defp finish_or_decode_remaining(chunks, remaining) when byte_size(remaining) > 0 do
decode_stream({"", remaining}, chunks)
end

defp finish_or_decode_remaining(chunks, remaining) do
{chunks, remaining}
end

defp decode_chunk(chunk) do
with {:ok, decoded_message, remaining} <- AwsEventstreamDecoder.decode(chunk),
{:ok, response_json} <- decode_json(decoded_message),
{:ok, bytes} <- get_bytes(response_json, remaining),
{:ok, json} <- decode_base64(bytes),
{:ok, payload} <- decode_json(json) do
{:ok, payload, remaining}
end
end

defp decode_json(data) do
case Jason.decode(data) do
{:ok, json} ->
{:ok, json}

{:error, error} ->
{:error, "Unable to decode JSON: #{inspect(error)}"}
end
end

defp decode_base64(bytes) do
case Base.decode64(bytes) do
{:ok, bytes} ->
{:ok, bytes}

:error ->
{:error, "Unable to decode base64 \"bytes\" from Bedrock response"}
end
end

defp get_bytes(%{"bytes" => bytes}, _remaining) do
{:ok, bytes}
end

# bytes is likely missing from the response in exception cases
# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_InvokeModelWithResponseStream.html
defp get_bytes(response, remaining) do
Logger.debug("Bedrock response is an exception: #{inspect(response)}")
exception_message = Map.keys(response) |> Enum.join(", ")
# Make it easier to match on this pattern in process_data fns
response = Map.put(response, :bedrock_exception, exception_message)
{:exception_response, response, remaining}
end
end
2 changes: 1 addition & 1 deletion mix.exs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ defmodule LangChain.MixProject do
[
{:ecto, "~> 3.10 or ~> 3.11"},
{:gettext, "~> 0.20"},
{:req, ">= 0.5.0"},
{:req, ">= 0.5.2"},
{:abacus, "~> 2.1.0"},
{:nx, ">= 0.7.0", optional: true},
{:ex_doc, "~> 0.34", only: :dev, runtime: false},
Expand Down
Loading
Loading