Skip to content

Commit

Permalink
community:yuan2[patch]: standardize init args (#21462)
Browse files Browse the repository at this point in the history
updated stop and request_timeout so they aliased to stop_sequences, and
timeout respectively. Added test that both continue to set the same
underlying attributes.

Related to
[20085](#20085)

Co-authored-by: ccurme <chester.curme@gmail.com>
  • Loading branch information
burd5 and ccurme authored Aug 23, 2024
1 parent bc557a5 commit f355a98
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 2 deletions.
6 changes: 4 additions & 2 deletions libs/community/langchain_community/chat_models/yuan2.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,9 @@ class ChatYuan2(BaseChatModel):
)
"""Base URL path for API requests, an OpenAI compatible API server."""

request_timeout: Optional[Union[float, Tuple[float, float]]] = None
request_timeout: Optional[Union[float, Tuple[float, float]]] = Field(
default=None, alias="timeout"
)
"""Timeout for requests to yuan2 completion API. Default is 600 seconds."""

max_retries: int = 6
Expand All @@ -111,7 +113,7 @@ class ChatYuan2(BaseChatModel):
top_p: Optional[float] = 0.9
"""The top-p value to use for sampling."""

stop: Optional[List[str]] = ["<eod>"]
stop: Optional[List[str]] = Field(default=["<eod>"], alias="stop_sequences")
"""A list of strings to stop generation when encountered."""

repeat_last_n: Optional[int] = 64
Expand Down
16 changes: 16 additions & 0 deletions libs/community/tests/unit_tests/chat_models/test_yuan2.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,22 @@ def test_yuan2_model_param() -> None:
assert chat.model_name == "foo"


@pytest.mark.requires("openai")
def test_yuan2_timeout_param() -> None:
chat = ChatYuan2(request_timeout=5) # type: ignore[call-arg]
assert chat.request_timeout == 5
chat = ChatYuan2(timeout=10) # type: ignore[call-arg]
assert chat.request_timeout == 10


@pytest.mark.requires("openai")
def test_yuan2_stop_sequences_param() -> None:
chat = ChatYuan2(stop=["<eod>"]) # type: ignore[call-arg]
assert chat.stop == ["<eod>"]
chat = ChatYuan2(stop_sequences=["<eod>"]) # type: ignore[call-arg]
assert chat.stop == ["<eod>"]


def test__convert_message_to_dict_human() -> None:
message = HumanMessage(content="foo")
result = _convert_message_to_dict(message)
Expand Down

0 comments on commit f355a98

Please sign in to comment.