Skip to content

Commit

Permalink
aws[patch]: support ChatBedrock(model=...) (#211)
Browse files Browse the repository at this point in the history
per langchain-ai/langchain#20085

---------

Co-authored-by: Chester Curme <chester.curme@gmail.com>
  • Loading branch information
baskaryan and ccurme authored Sep 20, 2024
1 parent 7c20d8b commit 6ad78b7
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 4 deletions.
1 change: 1 addition & 0 deletions libs/aws/langchain_aws/chat_models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,7 @@ def lc_attributes(self) -> Dict[str, Any]:

model_config = ConfigDict(
extra="forbid",
populate_by_name=True,
)

def _get_ls_params(
Expand Down
3 changes: 2 additions & 1 deletion libs/aws/langchain_aws/llms/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ class BedrockBase(BaseLanguageModel, ABC):
not have the provider in them, e.g., custom and provisioned models that have an ARN
associated with them."""

model_id: str
model_id: str = Field(alias="model")
"""Id of the model to call, e.g., amazon.titan-text-express-v1, this is
equivalent to the modelId property in the list-foundation-models api. For custom and
provisioned models, an ARN value is expected."""
Expand Down Expand Up @@ -1065,6 +1065,7 @@ def _get_ls_params(

model_config = ConfigDict(
extra="forbid",
populate_by_name=True,
)

def _stream(
Expand Down
2 changes: 1 addition & 1 deletion libs/aws/tests/integration_tests/llms/test_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


def test_bedrock_llm() -> None:
llm = BedrockLLM(model_id="anthropic.claude-v2:1")
llm = BedrockLLM(model_id="anthropic.claude-v2:1") # type: ignore[call-arg]
response = llm.invoke("Hello")
assert isinstance(response, str)
assert len(response) > 0
8 changes: 6 additions & 2 deletions libs/aws/tests/unit_tests/chat_models/test_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,12 +405,16 @@ def test_anthropic_bind_tools_tool_choice() -> None:

def test_standard_tracing_params() -> None:
llm = ChatBedrock(model_id="foo", region_name="us-west-2") # type: ignore[call-arg]
ls_params = llm._get_ls_params()
assert ls_params == {
expected = {
"ls_provider": "amazon_bedrock",
"ls_model_type": "chat",
"ls_model_name": "foo",
}
assert llm._get_ls_params() == expected

# Test initialization with `model` alias
llm = ChatBedrock(model="foo", region_name="us-west-2")
assert llm._get_ls_params() == expected

llm = ChatBedrock(
model_id="foo", model_kwargs={"temperature": 0.1}, region_name="us-west-2"
Expand Down

0 comments on commit 6ad78b7

Please sign in to comment.