diff --git a/libs/aws/langchain_aws/chat_models/bedrock.py b/libs/aws/langchain_aws/chat_models/bedrock.py index 475b8f0f..fab5fbaf 100644 --- a/libs/aws/langchain_aws/chat_models/bedrock.py +++ b/libs/aws/langchain_aws/chat_models/bedrock.py @@ -418,6 +418,7 @@ def lc_attributes(self) -> Dict[str, Any]: model_config = ConfigDict( extra="forbid", + populate_by_name=True, ) def _get_ls_params( diff --git a/libs/aws/langchain_aws/llms/bedrock.py b/libs/aws/langchain_aws/llms/bedrock.py index dad71de4..0425145a 100644 --- a/libs/aws/langchain_aws/llms/bedrock.py +++ b/libs/aws/langchain_aws/llms/bedrock.py @@ -514,7 +514,7 @@ class BedrockBase(BaseLanguageModel, ABC): not have the provider in them, e.g., custom and provisioned models that have an ARN associated with them.""" - model_id: str + model_id: str = Field(alias="model") """Id of the model to call, e.g., amazon.titan-text-express-v1, this is equivalent to the modelId property in the list-foundation-models api. For custom and provisioned models, an ARN value is expected.""" @@ -1065,6 +1065,7 @@ def _get_ls_params( model_config = ConfigDict( extra="forbid", + populate_by_name=True, ) def _stream( diff --git a/libs/aws/tests/integration_tests/llms/test_bedrock.py b/libs/aws/tests/integration_tests/llms/test_bedrock.py index 791f938b..b3c2961f 100644 --- a/libs/aws/tests/integration_tests/llms/test_bedrock.py +++ b/libs/aws/tests/integration_tests/llms/test_bedrock.py @@ -2,7 +2,7 @@ def test_bedrock_llm() -> None: - llm = BedrockLLM(model_id="anthropic.claude-v2:1") + llm = BedrockLLM(model_id="anthropic.claude-v2:1") # type: ignore[call-arg] response = llm.invoke("Hello") assert isinstance(response, str) assert len(response) > 0 diff --git a/libs/aws/tests/unit_tests/chat_models/test_bedrock.py b/libs/aws/tests/unit_tests/chat_models/test_bedrock.py index d2fbaf4e..2a869f36 100644 --- a/libs/aws/tests/unit_tests/chat_models/test_bedrock.py +++ b/libs/aws/tests/unit_tests/chat_models/test_bedrock.py @@ -405,12 +405,16 @@ def test_anthropic_bind_tools_tool_choice() -> None: def test_standard_tracing_params() -> None: llm = ChatBedrock(model_id="foo", region_name="us-west-2") # type: ignore[call-arg] - ls_params = llm._get_ls_params() - assert ls_params == { + expected = { "ls_provider": "amazon_bedrock", "ls_model_type": "chat", "ls_model_name": "foo", } + assert llm._get_ls_params() == expected + + # Test initialization with `model` alias + llm = ChatBedrock(model="foo", region_name="us-west-2") + assert llm._get_ls_params() == expected llm = ChatBedrock( model_id="foo", model_kwargs={"temperature": 0.1}, region_name="us-west-2"