Skip to content

Commit

Permalink
community[patch]:sparkllm standardize init args (#20194)
Browse files Browse the repository at this point in the history
Related to #20085
@baskaryan
  • Loading branch information
liugddx authored Apr 13, 2024
1 parent 7d7a08e commit 4be7ca7
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 1 deletion.
7 changes: 6 additions & 1 deletion libs/community/langchain_community/chat_models/sparkllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,16 @@ def lc_secrets(self) -> Dict[str, str]:
spark_llm_domain: Optional[str] = None
spark_user_id: str = "lc_user"
streaming: bool = False
request_timeout: int = 30
request_timeout: int = Field(30, alias="timeout")
temperature: float = 0.5
top_k: int = 4
model_kwargs: Dict[str, Any] = Field(default_factory=dict)

class Config:
"""Configuration for this pydantic object."""

allow_population_by_field_name = True

@root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra kwargs from additional params that were passed in."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,15 @@
from langchain_community.chat_models.sparkllm import ChatSparkLLM


def test_initialization() -> None:
"""Test chat model initialization."""
for model in [
ChatSparkLLM(timeout=30),
ChatSparkLLM(request_timeout=30),
]:
assert model.request_timeout == 30


def test_chat_spark_llm() -> None:
chat = ChatSparkLLM()
message = HumanMessage(content="Hello")
Expand Down

0 comments on commit 4be7ca7

Please sign in to comment.