vertexai[patch]: standardize model params #121

Merged · 9 commits · Apr 29, 2024
13 changes: 9 additions & 4 deletions libs/vertexai/langchain_google_vertexai/_base.py
@@ -50,11 +50,16 @@ class _VertexAIBase(BaseModel):
     max_retries: int = 6
     """The maximum number of retries to make when generating."""
     task_executor: ClassVar[Optional[Executor]] = Field(default=None, exclude=True)
-    stop: Optional[List[str]] = None
+    stop: Optional[List[str]] = Field(default=None, alias="stop_sequences")
     "Optional list of stop words to use when generating."
-    model_name: Optional[str] = None
+    model_name: Optional[str] = Field(default=None, alias="model")
     "Underlying model name."

+    class Config:
+        """Configuration for this pydantic object."""
+
+        allow_population_by_field_name = True
+
     @root_validator(pre=True)
     def validate_params(cls, values: dict) -> dict:
         if "model" in values and "model_name" not in values:
@@ -64,11 +69,11 @@ def validate_params(cls, values: dict) -> dict:

 class _VertexAICommon(_VertexAIBase):
     client_preview: Any = None  #: :meta private:
-    model_name: str
+    model_name: str = Field(default=None, alias="model")
     "Underlying model name."
     temperature: Optional[float] = None
     "Sampling temperature, it controls the degree of randomness in token selection."
-    max_output_tokens: Optional[int] = None
+    max_output_tokens: Optional[int] = Field(default=None, alias="max_tokens")
     "Token limit determines the maximum amount of text output from one prompt."
     top_p: Optional[float] = None
     "Tokens are selected from most probable to least until the sum of their "
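The `_base.py` change is the heart of the PR: each standardized parameter keeps its Vertex-style field name but gains the LangChain-standard spelling as a pydantic alias, and `allow_population_by_field_name = True` lets callers use either one. A minimal self-contained sketch of the pattern (the class name here is illustrative, not from the PR):

```python
from typing import List, Optional

from langchain_core.pydantic_v1 import BaseModel, Field


class _ExampleParams(BaseModel):
    # Field name stays Vertex-style; the alias is the standardized spelling.
    model_name: Optional[str] = Field(default=None, alias="model")
    stop: Optional[List[str]] = Field(default=None, alias="stop_sequences")
    max_output_tokens: Optional[int] = Field(default=None, alias="max_tokens")

    class Config:
        # Without this flag, pydantic v1 accepts only the alias at init time;
        # with it, both the alias and the field name populate the field.
        allow_population_by_field_name = True


assert _ExampleParams(model="gemini-pro").model_name == "gemini-pro"
assert _ExampleParams(model_name="gemini-pro").model_name == "gemini-pro"
assert _ExampleParams(max_tokens=10).max_output_tokens == 10
assert _ExampleParams(stop_sequences=["bar"]).stop == ["bar"]
```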
16 changes: 14 additions & 2 deletions libs/vertexai/langchain_google_vertexai/chat_models.py
@@ -57,7 +57,7 @@
 )
 from langchain_core.output_parsers.openai_tools import parse_tool_calls
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
-from langchain_core.pydantic_v1 import BaseModel, root_validator
+from langchain_core.pydantic_v1 import BaseModel, root_validator, Field
 from langchain_core.runnables import Runnable, RunnablePassthrough
 from vertexai.generative_models import (  # type: ignore
     Candidate,
@@ -498,7 +498,7 @@ async def _completion_with_retry_inner(
 class ChatVertexAI(_VertexAICommon, BaseChatModel):
     """`Vertex AI` Chat large language models API."""

-    model_name: str = "chat-bison"
+    model_name: str = Field(default="chat-bison", alias="model")
     "Underlying model name."
     examples: Optional[List[BaseMessage]] = None
     tuned_model_name: Optional[str] = None
@@ -510,6 +510,18 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
     setting this parameter to True is discouraged.
     """

+    def __init__(self, *, model_name: Optional[str] = None, **kwargs: Any) -> None:
+        """Needed for mypy typing to recognize model_name as a valid arg."""
+        if model_name:
+            kwargs["model_name"] = model_name
+        super().__init__(**kwargs)
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        allow_population_by_field_name = True
+        arbitrary_types_allowed = True
+
     @classmethod
     def is_lc_serializable(self) -> bool:
         return True
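The explicit `__init__` override adds nothing at runtime; it exists so static type checkers accept `model_name=` (and, through `**kwargs`, the `model` alias) without the `# type: ignore[call-arg]` the old unit tests carried. Both spellings then construct the same object — a sketch mirroring the unit tests below, with a placeholder project id:

```python
from langchain_google_vertexai import ChatVertexAI

# Both constructor spellings resolve to the same field; "test-project" is a
# placeholder, matching how this PR's unit tests construct the class.
llm_by_name = ChatVertexAI(model_name="gemini-pro", project="test-project")
llm_by_alias = ChatVertexAI(model="gemini-pro", project="test-project")
assert llm_by_name.model_name == llm_by_alias.model_name == "gemini-pro"
```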
5 changes: 3 additions & 2 deletions libs/vertexai/langchain_google_vertexai/embeddings.py
@@ -78,7 +78,7 @@ def validate_environment(cls, values: Dict) -> Dict:

     def __init__(
         self,
-        model_name: str,
+        model_name: Optional[str] = None,
         project: Optional[str] = None,
         location: str = "us-central1",
         request_parallelism: int = 5,
@@ -87,13 +87,14 @@ def __init__(
         **kwargs: Any,
     ):
         """Initialize the sentence_transformer."""
+        if model_name:
+            kwargs["model_name"] = model_name
         super().__init__(
             project=project,
             location=location,
             credentials=credentials,
             request_parallelism=request_parallelism,
             max_retries=max_retries,
-            model_name=model_name,
             **kwargs,
         )
         self.instance["max_batch_size"] = kwargs.get("max_batch_size", _MAX_BATCH_SIZE)
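Note the conditional forwarding: `model_name` is copied into `kwargs` only when provided, so a caller who passes `model=` (or nothing) never hands pydantic an explicit `model_name=None` in place of the field default. A rough sketch of why that guard matters — the class and helper are hypothetical stand-ins, not the real `VertexAIEmbeddings`:

```python
from typing import Any, Optional

from langchain_core.pydantic_v1 import BaseModel, Field


class _FakeEmbeddings(BaseModel):  # hypothetical stand-in for the real class
    model_name: str = Field(default="textembedding-gecko", alias="model")

    class Config:
        allow_population_by_field_name = True


def _init(model_name: Optional[str] = None, **kwargs: Any) -> _FakeEmbeddings:
    if model_name:  # forward only when actually provided
        kwargs["model_name"] = model_name
    return _FakeEmbeddings(**kwargs)


assert _init().model_name == "textembedding-gecko"          # default survives
assert _init("gecko@001").model_name == "gecko@001"         # positional still works
assert _init(model="gecko@003").model_name == "gecko@003"   # alias spelling works
```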
35 changes: 32 additions & 3 deletions libs/vertexai/langchain_google_vertexai/gemma.py
@@ -21,7 +21,7 @@
     Generation,
     LLMResult,
 )
-from langchain_core.pydantic_v1 import BaseModel, root_validator
+from langchain_core.pydantic_v1 import BaseModel, Field, root_validator

 from langchain_google_vertexai._base import _BaseVertexAIModelGarden
 from langchain_google_vertexai._utils import enforce_stop_tokens
@@ -115,6 +115,17 @@ class GemmaChatVertexAIModelGarden(_GemmaBase, _BaseVertexAIModelGarden, BaseCha
     """Whether to post-process the chat response and clean repetitions """
     """or multi-turn statements."""

+    def __init__(self, *, model_name: Optional[str] = None, **kwargs: Any) -> None:
+        """Needed for mypy typing to recognize model_name as a valid arg."""
+        if model_name:
+            kwargs["model_name"] = model_name
+        super().__init__(**kwargs)
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        allow_population_by_field_name = True
+
     @property
     def _llm_type(self) -> str:
         return "gemma_vertexai_model_garden"
@@ -178,9 +189,15 @@ class _GemmaLocalKaggleBase(_GemmaBase):

     client: Any = None  #: :meta private:
     keras_backend: str = "jax"
-    model_name: str = "gemma_2b_en"
+    model_name: str = Field(default="gemma_2b_en", alias="model")
     """Gemma model name."""

+    def __init__(self, *, model_name: Optional[str] = None, **kwargs: Any) -> None:
+        """Needed for mypy typing to recognize model_name as a valid arg."""
+        if model_name:
+            kwargs["model_name"] = model_name
+        super().__init__(**kwargs)
+
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that llama-cpp-python library is installed."""
@@ -212,6 +229,12 @@ def _get_params(self, **kwargs) -> Dict[str, Any]:
 class GemmaLocalKaggle(_GemmaLocalKaggleBase, BaseLLM):
     """Local gemma chat model loaded from Kaggle."""

+    def __init__(self, *, model_name: Optional[str] = None, **kwargs: Any) -> None:
+        """Only needed for typing."""
+        if model_name:
+            kwargs["model_name"] = model_name
+        super().__init__(**kwargs)
+
     def _generate(
         self,
         prompts: List[str],
@@ -238,6 +261,12 @@ class GemmaChatLocalKaggle(_GemmaLocalKaggleBase, BaseChatModel):
     """Whether to post-process the chat response and clean repetitions """
     """or multi-turn statements."""

+    def __init__(self, *, model_name: Optional[str] = None, **kwargs: Any) -> None:
+        """Needed for mypy typing to recognize model_name as a valid arg."""
+        if model_name:
+            kwargs["model_name"] = model_name
+        super().__init__(**kwargs)
+
     def _generate(
         self,
         messages: List[BaseMessage],
@@ -268,7 +297,7 @@ class _GemmaLocalHFBase(_GemmaBase):
     client: Any = None  #: :meta private:
     hf_access_token: str
     cache_dir: Optional[str] = None
-    model_name: str = "gemma_2b_en"
+    model_name: str = Field(default="gemma_2b_en", alias="model")
     """Gemma model name."""

     @root_validator()
15 changes: 13 additions & 2 deletions libs/vertexai/langchain_google_vertexai/llms.py
@@ -9,7 +9,7 @@
 )
 from langchain_core.language_models.llms import BaseLLM
 from langchain_core.outputs import Generation, GenerationChunk, LLMResult
-from langchain_core.pydantic_v1 import root_validator
+from langchain_core.pydantic_v1 import Field, root_validator
 from vertexai.generative_models import (  # type: ignore[import-untyped]
     Candidate,
     GenerativeModel,
@@ -110,13 +110,24 @@ async def _acompletion_with_retry_inner(
 class VertexAI(_VertexAICommon, BaseLLM):
     """Google Vertex AI large language models."""

-    model_name: str = "text-bison"
+    model_name: str = Field(default="text-bison", alias="model")
     "The name of the Vertex AI large language model."
     tuned_model_name: Optional[str] = None
     """The name of a tuned model. If tuned_model_name is passed
     model_name will be used to determine the model family
     """

+    def __init__(self, *, model_name: Optional[str] = None, **kwargs: Any) -> None:
+        """Needed for mypy typing to recognize model_name as a valid arg."""
+        if model_name:
+            kwargs["model_name"] = model_name
+        super().__init__(**kwargs)
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        allow_population_by_field_name = True
+
     @classmethod
     def is_lc_serializable(self) -> bool:
         return True
16 changes: 13 additions & 3 deletions libs/vertexai/langchain_google_vertexai/model_garden.py
@@ -25,7 +25,7 @@
     Generation,
     LLMResult,
 )
-from langchain_core.pydantic_v1 import root_validator
+from langchain_core.pydantic_v1 import Field, root_validator

 from langchain_google_vertexai._anthropic_utils import _format_messages_anthropic
 from langchain_google_vertexai._base import _BaseVertexAIModelGarden, _VertexAICommon
@@ -34,6 +34,11 @@
 class VertexAIModelGarden(_BaseVertexAIModelGarden, BaseLLM):
     """Large language models served from Vertex AI Model Garden."""

+    class Config:
+        """Configuration for this pydantic object."""
+
+        allow_population_by_field_name = True
+
     def _generate(
         self,
         prompts: List[str],
@@ -92,9 +97,14 @@ async def _agenerate(

 class ChatAnthropicVertex(_VertexAICommon, BaseChatModel):
     async_client: Any = None  #: :meta private:
-    model_name: Optional[str] = None  # type: ignore[assignment]
+    model_name: Optional[str] = Field(default=None, alias="model")  # type: ignore[assignment]
     "Underlying model name."
-    max_output_tokens: int = 1024
+    max_output_tokens: int = Field(default=1024, alias="max_tokens")

+    class Config:
+        """Configuration for this pydantic object."""
+
+        allow_population_by_field_name = True
+
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
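With the alias in place, `ChatAnthropicVertex` accepts the Anthropic-style `max_tokens` alongside the Vertex-style `max_output_tokens`. A hedged usage sketch — the model id and project are placeholders, and since the class validates its environment at construction, building it may require the optional Anthropic dependency and credentials:

```python
from langchain_google_vertexai.model_garden import ChatAnthropicVertex

# "claude-3-sonnet@20240229" and "test-project" are placeholder values.
llm = ChatAnthropicVertex(
    model="claude-3-sonnet@20240229",  # alias for model_name
    project="test-project",
    max_tokens=512,                    # alias for max_output_tokens
)
assert llm.max_output_tokens == 512
assert llm.model_name == "claude-3-sonnet@20240229"
```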
2 changes: 1 addition & 1 deletion libs/vertexai/pyproject.toml
@@ -106,7 +106,7 @@ build-backend = "poetry.core.masonry.api"
 #
 # https://github.com/tophat/syrupy
 # --snapshot-warn-unused Prints a warning on unused snapshots rather than fail the test suite.
-addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5"
+addopts = "--strict-markers --strict-config --durations=5"
 # Registering custom markers.
 # https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
 markers = [
8 changes: 7 additions & 1 deletion libs/vertexai/tests/integration_tests/test_embeddings.py
@@ -16,7 +16,13 @@
 @pytest.mark.release
 def test_initialization() -> None:
     """Test embedding model initialization."""
-    VertexAIEmbeddings(model_name="textembedding-gecko@001")
+    for embeddings in [
+        VertexAIEmbeddings(
+            model_name="textembedding-gecko",
+        ),
+        VertexAIEmbeddings(model="textembedding-gecko"),
+    ]:
+        assert embeddings.model_name == "textembedding-gecko"


 @pytest.mark.release
18 changes: 15 additions & 3 deletions libs/vertexai/tests/unit_tests/test_chat_models.py
@@ -52,12 +52,24 @@
 )


-def test_model_name() -> None:
+def test_init() -> None:
     for llm in [
-        ChatVertexAI(model_name="gemini-pro", project="test-project"),
-        ChatVertexAI(model="gemini-pro", project="test-project"),  # type: ignore[call-arg]
+        ChatVertexAI(
+            model_name="gemini-pro",
+            project="test-project",
+            max_output_tokens=10,
+            stop=["bar"],
+        ),
+        ChatVertexAI(
+            model="gemini-pro",
+            project="test-project",
+            max_tokens=10,
+            stop_sequences=["bar"],
+        ),
     ]:
         assert llm.model_name == "gemini-pro"
+        assert llm.max_output_tokens == 10
+        assert llm.stop == ["bar"]


 def test_tuned_model_name() -> None:
5 changes: 3 additions & 2 deletions libs/vertexai/tests/unit_tests/test_llm.py
@@ -7,10 +7,11 @@

 def test_model_name() -> None:
     for llm in [
-        VertexAI(model_name="gemini-pro", project="test-project"),
-        VertexAI(model="gemini-pro", project="test-project"),  # type: ignore[call-arg]
+        VertexAI(model_name="gemini-pro", project="test-project", max_output_tokens=10),
+        VertexAI(model="gemini-pro", project="test-project", max_tokens=10),
     ]:
         assert llm.model_name == "gemini-pro"
+        assert llm.max_output_tokens == 10


 def test_tuned_model_name() -> None: