Grounding Vertex AI Gemini #11545

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
@@ -79,6 +79,9 @@ class Vertex(FunctionCallingLLM):
    model: str = Field(description="The vertex model to use.")
    temperature: float = Field(description="The temperature to use for sampling.")
    max_tokens: int = Field(description="The maximum number of tokens to generate.")
    datastore: Optional[str] = Field(
        default=None, description="The datastore to use for grounding the model."
    )
    examples: Optional[Sequence[ChatMessage]] = Field(
        description="Example messages for the chat model."
    )
@@ -100,6 +103,7 @@ def __init__(
        project: Optional[str] = None,
        location: Optional[str] = None,
        credentials: Optional[Any] = None,
        datastore: Optional[str] = None,
        examples: Optional[Sequence[ChatMessage]] = None,
        temperature: float = 0.1,
        max_tokens: int = 512,
@@ -153,6 +157,7 @@ def __init__(
        super().__init__(
            temperature=temperature,
            max_tokens=max_tokens,
            datastore=datastore,
            additional_kwargs=additional_kwargs,
            max_retries=max_retries,
            model=model,
@@ -244,6 +249,7 @@ def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
            is_gemini=self._is_gemini,
            params=chat_params,
            max_retries=self.max_retries,
            datastore=self.datastore,
            **params,
        )

@@ -272,6 +278,7 @@ def complete(
            prompt,
            max_retries=self.max_retries,
            is_gemini=self._is_gemini,
            datastore=self.datastore,
            **params,
        )
        return CompletionResponse(text=completion.text, raw=completion.__dict__)
@@ -306,6 +313,7 @@ def stream_chat(
            prompt=question,
            chat=True,
            stream=True,
            datastore=self.datastore,
            is_gemini=self._is_gemini,
            params=chat_params,
            max_retries=self.max_retries,
@@ -341,6 +349,7 @@ def stream_complete(
            stream=True,
            is_gemini=self._is_gemini,
            max_retries=self.max_retries,
            datastore=self.datastore,
            **params,
        )

@@ -386,6 +395,7 @@ async def achat(
            is_gemini=self._is_gemini,
            params=chat_params,
            max_retries=self.max_retries,
            datastore=self.datastore,
            **params,
        )
        # NOTE: due to a bug in Vertex AI, we have to await twice
@@ -415,6 +425,7 @@ async def acomplete(
            prompt=prompt,
            max_retries=self.max_retries,
            is_gemini=self._is_gemini,
            datastore=self.datastore,
            **params,
        )
        return CompletionResponse(text=completion.text)
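For context, a minimal sketch of how the new parameter might be used. The import path matches the llama-index Vertex integration this diff modifies; the project ID and datastore resource name are hypothetical placeholders:

```python
from llama_index.llms.vertex import Vertex

# Hypothetical IDs; Vertex AI Search expects the full datastore resource name,
# i.e. projects/<project>/locations/<location>/collections/<collection>/dataStores/<id>.
llm = Vertex(
    model="gemini-pro",
    project="my-project",
    datastore=(
        "projects/my-project/locations/global/"
        "collections/default_collection/dataStores/my-datastore"
    ),
)
print(llm.complete("What does the datastore say about refunds?").text)
```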
@@ -75,6 +75,7 @@ def completion_with_retry(
    chat: bool = False,
    stream: bool = False,
    is_gemini: bool = False,
    datastore: Optional[str] = None,
    params: Any = {},
    **kwargs: Any,
) -> Any:
@@ -90,7 +91,14 @@ def _completion_with_retry(**kwargs: Any) -> Any:
        tools = kwargs.pop("tools", None) if "tools" in kwargs else []
        tools = to_gemini_tools(tools) if tools else []
        generation_config = kwargs if kwargs else {}

        if datastore:
            # Ground responses in the configured Vertex AI Search datastore.
            from vertexai.preview.generative_models import Tool, grounding

            tool = Tool.from_retrieval(
                grounding.Retrieval(grounding.VertexAISearch(datastore=datastore))
            )
            return generation.send_message(
                prompt, stream=stream, generation_config=generation_config, tools=[tool]
            )
        return generation.send_message(
            prompt, stream=stream, tools=tools, generation_config=generation_config
        )
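For reference, the grounding call above mirrors the plain vertexai SDK pattern; a minimal sketch, assuming hypothetical project and datastore names:

```python
import vertexai
from vertexai.preview.generative_models import GenerativeModel, Tool, grounding

vertexai.init(project="my-project", location="us-central1")  # hypothetical project

# Build a retrieval tool backed by a Vertex AI Search datastore.
tool = Tool.from_retrieval(
    grounding.Retrieval(
        grounding.VertexAISearch(
            datastore=(
                "projects/my-project/locations/global/"
                "collections/default_collection/dataStores/my-datastore"
            )
        )
    )
)
model = GenerativeModel("gemini-pro")
response = model.generate_content("What is our refund policy?", tools=[tool])
print(response.text)
```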
@@ -116,6 +124,7 @@ async def acompletion_with_retry(
    chat: bool = False,
    is_gemini: bool = False,
    params: Any = {},
    datastore: Optional[str] = None,
    **kwargs: Any,
) -> Any:
    """Use tenacity to retry the completion call."""
@@ -130,6 +139,14 @@ async def _completion_with_retry(**kwargs: Any) -> Any:
        tools = kwargs.pop("tools", None) if "tools" in kwargs else []
        tools = to_gemini_tools(tools) if tools else []
        generation_config = kwargs if kwargs else {}

        if datastore:
            # Ground responses in the configured Vertex AI Search datastore.
            from vertexai.preview.generative_models import Tool, grounding

            tool = Tool.from_retrieval(
                grounding.Retrieval(grounding.VertexAISearch(datastore=datastore))
            )
            return await generation.send_message_async(
                prompt, generation_config=generation_config, tools=[tool]
            )
        return await generation.send_message_async(
            prompt, tools=tools, generation_config=generation_config
        )

Collaborator:

Interesting. Do you mind giving me a quick rundown of what this does to API calls? Does it force the LLM to act as an agent with this single retrieval tool?

Does vertexai support tools for any LLM? This would be really nice for agents, for example, similar to OpenAI's tools API.

Author:

Hello Logan,

Thanks for your questions. If you have a Vertex AI datastore defined (see this link), this lets you ground the model in that datastore. It is incredibly convenient for RAG applications, as creating a datastore is very easy.

It doesn't force the LLM to act as an agent. You are just giving the LLM your datastore as a resource, and you can tell it in the prompt to answer strictly from the datastore, to reduce hallucinations.
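To illustrate the prompting advice above, a hedged sketch reusing the hypothetical `llm` from the earlier example (`ChatMessage` comes from llama-index core):

```python
from llama_index.core.llms import ChatMessage

# Instruct the model to answer only from the grounded datastore.
response = llm.chat(
    [
        ChatMessage(
            role="user",
            content=(
                "Answer strictly using the datastore. "
                "If the answer is not in it, say so. "
                "What is our refund policy?"
            ),
        )
    ]
)
print(response.message.content)
```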