Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

google-genai[patch]: match function call interface #17213

Merged
merged 9 commits into from
Feb 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 135 additions & 0 deletions libs/partners/google-genai/langchain_google_genai/_function_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
from __future__ import annotations

from typing import (
Dict,
List,
Type,
Union,
)

from langchain_core.pydantic_v1 import BaseModel
from langchain_core.tools import BaseTool
from langchain_core.utils.json_schema import dereference_refs

# Accepted input formats for a function-call spec: a LangChain tool,
# a pydantic model class, or a raw JSON-schema-like dict.
FunctionCallType = Union[BaseTool, Type[BaseModel], Dict]

# JSON-schema "type" name -> google.generativeai proto Type enum value
# (STRING=1, NUMBER=2, INTEGER=3, BOOLEAN=4, ARRAY=5, OBJECT=6).
TYPE_ENUM = {
    "string": 1,
    "number": 2,
    "integer": 3,
    "boolean": 4,
    "array": 5,
    "object": 6,
}


def convert_to_genai_function_declarations(
    function_calls: List[FunctionCallType],
) -> Dict:
    """Build the genai ``tools`` payload from a list of function-call specs.

    Each entry may be a ``BaseTool``, a pydantic model class, or a plain
    dict; all are normalized by ``_convert_to_genai_function``.
    """
    return {
        "function_declarations": [
            _convert_to_genai_function(fc) for fc in function_calls
        ],
    }


def _convert_to_genai_function(fc: FunctionCallType) -> Dict:
    """Convert a single function-call spec into a genai function declaration.

    Produce

        {
            "name": "get_weather",
            "description": "Determine weather in my location",
            "parameters": {
                "properties": {
                    "location": {
                        "description": "The city and state e.g. San Francisco, CA",
                        "type_": 1
                    },
                    "unit": { "enum": ["c", "f"], "type_": 1 }
                },
                "required": ["location"],
                "type_": 6
            }
        }

    Raises:
        ValueError: if ``fc`` is not a BaseTool, pydantic model class, or dict.
    """
    if isinstance(fc, BaseTool):
        return _convert_tool_to_genai_function(fc)
    elif isinstance(fc, type) and issubclass(fc, BaseModel):
        return _convert_pydantic_to_genai_function(fc)
    elif isinstance(fc, dict):
        return {
            **fc,
            "parameters": {
                "properties": {
                    k: {
                        # Bug fix: keep extra constraint keys (e.g. "enum",
                        # shown in the docstring example) instead of dropping
                        # everything but type/description.
                        **{pk: pv for pk, pv in v.items() if pk != "type"},
                        "type_": TYPE_ENUM[v["type"]],
                        "description": v.get("description"),
                    }
                    for k, v in fc["parameters"]["properties"].items()
                },
                "required": fc["parameters"].get("required", []),
                "type_": TYPE_ENUM[fc["parameters"]["type"]],
            },
        }
    else:
        raise ValueError(f"Unsupported function call type {fc}")


def _convert_tool_to_genai_function(tool: BaseTool) -> Dict:
    """Convert a LangChain ``BaseTool`` into a genai function declaration dict.

    Tools with an ``args_schema`` are converted from the pydantic JSON
    schema; tools without one are declared with a single required string
    argument ``__arg1``.
    """
    if tool.args_schema:
        schema = dereference_refs(tool.args_schema.schema())
        schema.pop("definitions", None)

        return {
            "name": tool.name or schema["title"],
            # pydantic omits "description" when the model has no docstring;
            # default to "" rather than raising KeyError.
            "description": tool.description or schema.get("description", ""),
            "parameters": {
                "properties": {
                    k: {
                        "type_": TYPE_ENUM[v["type"]],
                        "description": v.get("description"),
                    }
                    for k, v in schema["properties"].items()
                },
                # "required" is absent when every field has a default.
                "required": schema.get("required", []),
                "type_": TYPE_ENUM[schema["type"]],
            },
        }
    else:
        return {
            "name": tool.name,
            "description": tool.description,
            "parameters": {
                "properties": {
                    # Bug fix: use the proto field name and numeric enum
                    # ("type_") like every other property, not the raw
                    # JSON-schema {"type": "string"}.
                    "__arg1": {"type_": TYPE_ENUM["string"]},
                },
                "required": ["__arg1"],
                "type_": TYPE_ENUM["object"],
            },
        }


def _convert_pydantic_to_genai_function(
    pydantic_model: Type[BaseModel],
) -> Dict:
    """Convert a pydantic model class into a genai function declaration dict.

    The model's JSON schema supplies the name (title), description, and
    per-field parameter types mapped through ``TYPE_ENUM``.
    """
    schema = dereference_refs(pydantic_model.schema())
    schema.pop("definitions", None)

    return {
        "name": schema["title"],
        "description": schema.get("description", ""),
        "parameters": {
            "properties": {
                k: {
                    "type_": TYPE_ENUM[v["type"]],
                    "description": v.get("description"),
                }
                for k, v in schema["properties"].items()
            },
            # Bug fix: pydantic omits "required" when every field has a
            # default, which made schema["required"] raise KeyError.
            "required": schema.get("required", []),
            "type_": TYPE_ENUM[schema["type"]],
        },
    }
73 changes: 11 additions & 62 deletions libs/partners/google-genai/langchain_google_genai/chat_models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import base64
import json
import logging
import os
from io import BytesIO
Expand Down Expand Up @@ -54,6 +55,9 @@
)

from langchain_google_genai._common import GoogleGenerativeAIError
from langchain_google_genai._function_utils import (
convert_to_genai_function_declarations,
)
from langchain_google_genai.llms import GoogleModelFamily, _BaseGoogleGenerativeAI

IMAGE_TYPES: Tuple = ()
Expand Down Expand Up @@ -351,69 +355,14 @@ def _retrieve_function_call_response(
return {
"function_call": {
"name": fc.name,
"arguments": dict(fc.args.items()),
"arguments": json.dumps(
dict(fc.args.items())
), # dump to match other function calling llms for now
}
}
return None


def _convert_function_call_req(function_calls: Union[Dict, List[Dict]]) -> Dict:
    """Normalize one or many function-call dicts into a genai tools payload."""
    if isinstance(function_calls, dict):
        function_calls = [function_calls]
    return {
        "function_declarations": [
            _convert_fc_type(fc) for fc in function_calls
        ],
    }


def _convert_fc_type(fc: Dict) -> Dict:
# type_: "Type"
# format_: str
# description: str
# nullable: bool
# enum: MutableSequence[str]
# items: "Schema"
# properties: MutableMapping[str, "Schema"]
# required: MutableSequence[str]
if "parameters" in fc:
fc["parameters"] = _convert_fc_type(fc["parameters"])
if "properties" in fc:
for k, v in fc["properties"].items():
fc["properties"][k] = _convert_fc_type(v)
if "type" in fc:
# STRING = 1
# NUMBER = 2
# INTEGER = 3
# BOOLEAN = 4
# ARRAY = 5
# OBJECT = 6
if fc["type"] == "string":
fc["type_"] = 1
elif fc["type"] == "number":
fc["type_"] = 2
elif fc["type"] == "integer":
fc["type_"] = 3
elif fc["type"] == "boolean":
fc["type_"] = 4
elif fc["type"] == "array":
fc["type_"] = 5
elif fc["type"] == "object":
fc["type_"] = 6
del fc["type"]
if "format" in fc:
fc["format_"] = fc["format"]
del fc["format"]

for k, v in fc.items():
if isinstance(v, dict):
fc[k] = _convert_fc_type(v)

return fc


def _parts_to_content(
parts: List[genai.types.PartType],
) -> Tuple[Union[str, List[Union[Dict[Any, Any], str]]], Optional[Dict]]:
Expand Down Expand Up @@ -708,19 +657,19 @@ def _prepare_chat(
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> Tuple[Dict[str, Any], genai.ChatSession, genai.types.ContentDict]:
cli = self.client
client = self.client
functions = kwargs.pop("functions", None)
if functions:
tools = _convert_function_call_req(functions)
cli = genai.GenerativeModel(model_name=self.model, tools=tools)
tools = convert_to_genai_function_declarations(functions)
client = genai.GenerativeModel(model_name=self.model, tools=tools)

params = self._prepare_params(stop, **kwargs)
history = _parse_chat_history(
messages,
convert_system_message_to_human=self.convert_system_message_to_human,
)
message = history.pop()
chat = cli.start_chat(history=history)
chat = client.start_chat(history=history)
return params, chat, message

def get_num_tokens(self, text: str) -> int:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
"""Test ChatGoogleGenerativeAI function call."""

import json

from langchain_core.messages import AIMessage
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.tools import tool

from langchain_google_genai.chat_models import (
ChatGoogleGenerativeAI,
)
Expand Down Expand Up @@ -29,6 +35,50 @@ def test_function_call() -> None:
assert res.additional_kwargs
assert "function_call" in res.additional_kwargs
assert "get_weather" == res.additional_kwargs["function_call"]["name"]
arguments = res.additional_kwargs["function_call"]["arguments"]
assert isinstance(arguments, dict)
arguments_str = res.additional_kwargs["function_call"]["arguments"]
assert isinstance(arguments_str, str)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

switched to string to match others

arguments = json.loads(arguments_str)
assert "location" in arguments


def test_tool_call() -> None:
    """A @tool-decorated function is exposed and its arguments round-trip as JSON."""

    @tool
    def search_tool(query: str) -> str:
        """Searches the web for `query` and returns the result."""
        raise NotImplementedError

    model = ChatGoogleGenerativeAI(model="gemini-pro").bind(functions=[search_tool])
    result = model.invoke("weather in san francisco")
    assert isinstance(result, AIMessage)
    assert isinstance(result.content, str)
    assert result.content == ""
    fn_call = result.additional_kwargs.get("function_call")
    assert fn_call
    assert fn_call["name"] == "search_tool"
    raw_args = fn_call.get("arguments")
    assert raw_args
    parsed_args = json.loads(raw_args)
    assert "query" in parsed_args


# Minimal pydantic model used to exercise function calling with a model class.
# No docstring on purpose: it would become the declaration's description.
class MyModel(BaseModel):
    name: str
    age: int


def test_pydantic_call() -> None:
    """A pydantic model class is exposed as a function and extracts both fields."""
    model = ChatGoogleGenerativeAI(model="gemini-pro").bind(functions=[MyModel])
    result = model.invoke("my name is Erick and I am 27 years old")
    assert isinstance(result, AIMessage)
    assert isinstance(result.content, str)
    assert result.content == ""
    fn_call = result.additional_kwargs.get("function_call")
    assert fn_call
    assert fn_call["name"] == "MyModel"
    raw_args = fn_call.get("arguments")
    assert raw_args
    assert json.loads(raw_args) == {
        "name": "Erick",
        "age": 27.0,
    }