Skip to content

Commit

Permalink
genai[patch]: support max_tokens init arg (#517)
Browse files Browse the repository at this point in the history
  • Loading branch information
baskaryan authored Oct 4, 2024
1 parent 9ebd8c1 commit f7d96ef
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 4 deletions.
7 changes: 5 additions & 2 deletions libs/genai/langchain_google_genai/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from langchain_core.language_models.llms import BaseLLM, create_base_retry_decorator
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
from langchain_core.utils import secret_from_env
from pydantic import BaseModel, Field, SecretStr, model_validator
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self

from langchain_google_genai._enums import (
Expand Down Expand Up @@ -139,7 +139,7 @@ class _BaseGoogleGenerativeAI(BaseModel):
top_k: Optional[int] = None
"""Decode using top-k sampling: consider the set of top_k most probable tokens.
Must be positive."""
max_output_tokens: Optional[int] = None
max_output_tokens: Optional[int] = Field(default=None, alias="max_tokens")
"""Maximum number of tokens to include in a candidate. Must be greater than zero.
If unset, will default to 64."""
n: int = 1
Expand Down Expand Up @@ -216,6 +216,9 @@ class GoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseLLM):
"""

client: Any = None #: :meta private:
model_config = ConfigDict(
populate_by_name=True,
)

@model_validator(mode="after")
def validate_environment(self) -> Self:
Expand Down
4 changes: 2 additions & 2 deletions libs/genai/tests/integration_tests/test_llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
def test_google_generativeai_call(model_name: str) -> None:
"""Test valid call to Google GenerativeAI text API."""
if model_name:
llm = GoogleGenerativeAI(max_output_tokens=10, model=model_name)
llm = GoogleGenerativeAI(max_tokens=10, model=model_name)
else:
llm = GoogleGenerativeAI(max_output_tokens=10) # type: ignore[call-arg]
llm = GoogleGenerativeAI(max_tokens=10) # type: ignore[call-arg]
output = llm("Say foo:")
assert isinstance(output, str)
assert llm._llm_type == "google_palm"
Expand Down
2 changes: 2 additions & 0 deletions libs/genai/tests/unit_tests/__snapshots__/test_standard.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
'lc': 1,
'type': 'secret',
}),
'max_output_tokens': 100,
'max_retries': 2,
'model': 'models/gemini-1.0-pro-001',
'n': 1,
Expand Down Expand Up @@ -44,6 +45,7 @@
'lc': 1,
'type': 'secret',
}),
'max_output_tokens': 100,
'max_retries': 2,
'model': 'models/gemini-1.5-pro-001',
'n': 1,
Expand Down

0 comments on commit f7d96ef

Please sign in to comment.