Skip to content

Add basic tools support. #120

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions google/generativeai/generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

from __future__ import annotations

from collections.abc import Iterable
import dataclasses
import textwrap
from typing import Union

# pylint: disable=bad-continuation, line-too-long


from collections.abc import Iterable

from google.ai import generativelanguage as glm
from google.generativeai import client
from google.generativeai import string_utils
Expand Down Expand Up @@ -70,6 +70,7 @@
generation_config: Overrides for the model's generation config.
safety_settings: Overrides for the model's safety settings.
stream: If True, yield response chunks as they are generated.
tools: `glm.Tools` more info coming soon.
"""

_SEND_MESSAGE_ASYNC_DOC = """The async version of `ChatSession.send_message`."""
Expand Down Expand Up @@ -158,6 +159,7 @@ def __init__(
model_name: str = "gemini-m",
safety_settings: safety_types.SafetySettingOptions | None = None,
generation_config: generation_types.GenerationConfigType | None = None,
tools: content_types.ToolsType = None,
):
if "/" not in model_name:
model_name = "models/" + model_name
Expand All @@ -166,6 +168,8 @@ def __init__(
safety_settings, harm_category_set="new"
)
self._generation_config = generation_types.to_generation_config_dict(generation_config)
self._tools = content_types.to_tools(tools)

self._client = None
self._async_client = None

Expand Down Expand Up @@ -213,6 +217,7 @@ def _prepare_request(
contents=contents,
generation_config=merged_gc,
safety_settings=merged_ss,
tools=self._tools,
**kwargs,
)

Expand Down
15 changes: 15 additions & 0 deletions google/generativeai/types/content_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
"ContentType",
"StrictContentType",
"ContentsType",
"ToolsType",
]


Expand Down Expand Up @@ -234,3 +235,17 @@ def to_contents(contents: ContentsType) -> list[glm.Content]:

contents = [to_content(contents)]
return contents


ToolsType = Union[Iterable[glm.Tool], glm.Tool, dict[str, Any], None]


def to_tools(tools: ToolsType) -> list[glm.Tool]:
if tools is None:
return []
elif isinstance(tools, Mapping):
return [glm.Tool(tools)]
elif isinstance(tools, Iterable):
return [glm.Tool(t) for t in tools]
else:
return [glm.Tool(tools)]
42 changes: 42 additions & 0 deletions tests/test_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,48 @@ def test_img_to_contents(self, example):
self.assertEqual(blob.mime_type, "image/png")
self.assertStartsWith(blob.data, b"\x89PNG")

@parameterized.named_parameters(
[
"OneTool",
glm.Tool(
function_declarations=[
glm.FunctionDeclaration(
name="datetime", description="Returns the current UTC date and time."
)
]
),
],
[
"ToolDict",
dict(
function_declarations=[
dict(name="datetime", description="Returns the current UTC date and time.")
]
),
],
[
"ListOfTools",
[
glm.Tool(
function_declarations=[
glm.FunctionDeclaration(
name="datetime",
description="Returns the current UTC date and time.",
)
]
)
],
],
)
def test_img_to_contents(self, tools):
tools = content_types.to_tools(tools)
expected = dict(
function_declarations=[
dict(name="datetime", description="Returns the current UTC date and time.")
]
)
self.assertEqual(type(tools[0]).to_dict(tools[0]), expected)


if __name__ == "__main__":
absltest.main()
28 changes: 28 additions & 0 deletions tests/test_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,34 @@ def test_chat_streaming_unexpected_stop(self):
chat.rewind()
self.assertLen(chat.history, 0)

def test_tools(self):
tools = dict(
function_declarations=[
dict(name="datetime", description="Returns the current UTC date and time.")
]
)
model = generative_models.GenerativeModel("gemini-mm-m", tools=tools)

self.responses["generate_content"] = [
simple_response("a"),
simple_response("b"),
]

response = model.generate_content("Hello")

chat = model.start_chat()
response = chat.send_message("Hello")

expect_tools = dict(
function_declarations=[
dict(name="datetime", description="Returns the current UTC date and time.")
]
)

for obr in self.observed_requests:
self.assertLen(obr.tools, 1)
self.assertEqual(type(obr.tools[0]).to_dict(obr.tools[0]), tools)

@parameterized.named_parameters(
[
"GenerateContentResponse",
Expand Down