diff --git a/libs/community/langchain_community/chat_models/yuan2.py b/libs/community/langchain_community/chat_models/yuan2.py index df95e9902dcca..dc63eb83d246c 100644 --- a/libs/community/langchain_community/chat_models/yuan2.py +++ b/libs/community/langchain_community/chat_models/yuan2.py @@ -93,7 +93,9 @@ class ChatYuan2(BaseChatModel): ) """Base URL path for API requests, an OpenAI compatible API server.""" - request_timeout: Optional[Union[float, Tuple[float, float]]] = None + request_timeout: Optional[Union[float, Tuple[float, float]]] = Field( + default=None, alias="timeout" + ) """Timeout for requests to yuan2 completion API. Default is 600 seconds.""" max_retries: int = 6 @@ -111,7 +113,7 @@ class ChatYuan2(BaseChatModel): top_p: Optional[float] = 0.9 """The top-p value to use for sampling.""" - stop: Optional[List[str]] = [""] + stop: Optional[List[str]] = Field(default=[""], alias="stop_sequences") """A list of strings to stop generation when encountered.""" repeat_last_n: Optional[int] = 64 diff --git a/libs/community/tests/unit_tests/chat_models/test_yuan2.py b/libs/community/tests/unit_tests/chat_models/test_yuan2.py index 74b2fb84cf530..683b2a013c775 100644 --- a/libs/community/tests/unit_tests/chat_models/test_yuan2.py +++ b/libs/community/tests/unit_tests/chat_models/test_yuan2.py @@ -22,6 +22,22 @@ def test_yuan2_model_param() -> None: assert chat.model_name == "foo" +@pytest.mark.requires("openai") +def test_yuan2_timeout_param() -> None: + chat = ChatYuan2(request_timeout=5) # type: ignore[call-arg] + assert chat.request_timeout == 5 + chat = ChatYuan2(timeout=10) # type: ignore[call-arg] + assert chat.request_timeout == 10 + + +@pytest.mark.requires("openai") +def test_yuan2_stop_sequences_param() -> None: + chat = ChatYuan2(stop=[""]) # type: ignore[call-arg] + assert chat.stop == [""] + chat = ChatYuan2(stop_sequences=[""]) # type: ignore[call-arg] + assert chat.stop == [""] + + def test__convert_message_to_dict_human() -> None: message = HumanMessage(content="foo") result = _convert_message_to_dict(message)