Skip to content

Commit 9d695eb

Browse files
committed
resolved patch_chat_model bug
1 parent 7ed0648 commit 9d695eb

File tree

3 files changed

+91
-24
lines changed

3 files changed

+91
-24
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "tool-parse"
3-
version = "1.1.0"
3+
version = "1.1.1"
44
description = "Making LLM Tool-Calling Simpler."
55
authors = ["Harsh Verma <synacktra.work@gmail.com>"]
66
repository = "https://github.com/synacktraa/tool-parse"

tests/test_integration.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,16 @@
44

55
if sys.version_info >= (3, 9):
66
import asyncio
7-
from typing import Literal, NamedTuple
7+
from typing import Any, Literal, NamedTuple
88

9+
from langchain_core.language_models.fake_chat_models import FakeChatModel
910
from langchain_core.tools.structured import StructuredTool
11+
from langchain_core.utils.function_calling import convert_to_openai_tool
1012

11-
from tool_parse.integrations.langchain import ExtendedStructuredTool
13+
from tool_parse.integrations.langchain import ExtendedStructuredTool, patch_chat_model
1214

1315
@pytest.fixture
14-
def langchain_tools():
16+
def tools():
1517
async def search_web(query: str, safe_search: bool = True):
1618
"""
1719
Search the web.
@@ -32,18 +34,31 @@ class UserInfo(NamedTuple):
3234
ExtendedStructuredTool(func=UserInfo, name="user_info", schema_spec="claude"),
3335
]
3436

35-
def test_langchain_integration(langchain_tools):
37+
def test_langchain_tools(tools):
3638
async def __asyncio__():
37-
assert len(langchain_tools) == 2
39+
assert len(tools) == 2
3840

39-
assert langchain_tools[0].name == "search_web"
40-
assert (await langchain_tools[0].invoke(input={"query": "langchain"})) == "not found"
41+
assert tools[0].name == "search_web"
42+
assert (await tools[0].invoke(input={"query": "langchain"})) == "not found"
4143

42-
assert langchain_tools[1].name == "user_info"
43-
assert "input_schema" in langchain_tools[1].json_schema["function"]
44-
info = langchain_tools[1].invoke(input={"name": "synacktra", "age": "21"})
44+
assert tools[1].name == "user_info"
45+
assert "input_schema" in tools[1].json_schema["function"]
46+
info = tools[1].invoke(input={"name": "synacktra", "age": "21"})
4547
assert info.name == "synacktra"
4648
assert info.age == 21
4749
assert info.role == "tester"
4850

4951
asyncio.run(__asyncio__())
52+
53+
def test_langchain_chat_model(tools):
54+
class ChatMock(FakeChatModel):
55+
def bind_tools(self, tools: Any, **kwargs: Any):
56+
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
57+
return super().bind(tools=formatted_tools, **kwargs)
58+
59+
patched_model = patch_chat_model(ChatMock()).bind_tools(tools=tools)
60+
print(patched_model.kwargs["tools"])
61+
assert len(patched_model.kwargs["tools"]) == 2
62+
63+
patched_model = patch_chat_model(ChatMock)().bind_tools(tools=tools)
64+
assert len(patched_model.kwargs["tools"]) == 2

tool_parse/integrations/langchain.py

Lines changed: 65 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import typing as t
1111
import uuid
1212
from contextvars import copy_context
13-
from types import MethodType
1413

1514
from langchain_core.callbacks import (
1615
AsyncCallbackManager,
@@ -38,6 +37,7 @@
3837
_handle_tool_error,
3938
_handle_validation_error,
4039
)
40+
from langchain_core.utils.function_calling import convert_to_openai_tool
4141
from pydantic import PrivateAttr, ValidationError, model_validator
4242

4343
from .. import _types as ts
@@ -346,6 +346,48 @@ async def arun( # noqa: C901
346346
ChatModel = t.TypeVar("ChatModel", bound=BaseChatModel)
347347

348348

349+
def _validate_tool_choice(
350+
choice: t.Union[dict, str, t.Literal["auto", "any", "none"], bool],
351+
tools: t.List[BaseTool],
352+
schema_list: t.List[t.Dict[str, t.Any]],
353+
):
354+
if choice == "any":
355+
if len(tools) > 1:
356+
raise ValueError(
357+
f"Groq does not currently support {choice=}. Should "
358+
f"be one of 'auto', 'none', or the name of the tool to call."
359+
)
360+
else:
361+
choice = convert_to_openai_tool(tools[0])["function"]["name"]
362+
if isinstance(choice, str) and (choice not in ("auto", "any", "none")):
363+
choice = {"type": "function", "function": {"name": choice}}
364+
# TODO: Remove this update once 'any' is supported.
365+
if isinstance(choice, dict) and (len(schema_list) != 1):
366+
raise ValueError(
367+
"When specifying `tool_choice`, you must provide exactly one "
368+
f"tool. Received {len(schema_list)} tools."
369+
)
370+
if isinstance(choice, dict) and (
371+
schema_list[0]["function"]["name"] != choice["function"]["name"]
372+
):
373+
raise ValueError(
374+
f"Tool choice {choice} was specified, but the only "
375+
f"provided tool was {schema_list[0]['function']['name']}."
376+
)
377+
if isinstance(choice, bool):
378+
if len(tools) > 1:
379+
raise ValueError(
380+
"tool_choice can only be True when there is one tool. Received "
381+
f"{len(tools)} tools."
382+
)
383+
tool_name = schema_list[0]["function"]["name"]
384+
choice = {
385+
"type": "function",
386+
"function": {"name": tool_name},
387+
}
388+
return choice
389+
390+
349391
@t.overload
350392
def patch_chat_model(__model: ChatModel) -> ChatModel:
351393
"""
@@ -379,25 +421,35 @@ def patch_chat_model(__model: type[ChatModel]) -> type[ChatModel]:
379421

380422

381423
def patch_chat_model(__model: t.Union[ChatModel, type[ChatModel]]):
382-
class PatchedModel(BaseChatModel):
424+
chat_model_cls = __model if isinstance(__model, type) else __model.__class__
425+
426+
class PatchedModel(chat_model_cls):
383427
def bind_tools(
384428
self,
385-
tools: t.Sequence[t.Any],
429+
tools: t.Sequence[t.Union[BaseTool, ExtendedStructuredTool]],
386430
**kwargs: t.Any,
387431
) -> Runnable[LanguageModelInput, BaseMessage]:
388-
schema_list = []
432+
formatted_tools = []
389433
for tool in tools:
390434
if isinstance(tool, ExtendedStructuredTool):
391-
schema_list.append(tool.json_schema)
435+
formatted_tools.append(tool.json_schema)
392436
else:
393-
schema_list.extend(super().bind_tools(tools=[tool], **kwargs))
394-
return self.bind(tools=schema_list, **kwargs)
437+
# Use the original bind_tools method for builtin tool types
438+
formatted_tools.extend(
439+
super().bind_tools(tools=[tool], **kwargs).kwargs["tools"]
440+
)
441+
442+
if tool_choice := kwargs.get("tool_choice", None):
443+
kwargs["tool_choice"] = _validate_tool_choice(
444+
choice=tool_choice, tools=tools, schema_list=formatted_tools
445+
)
446+
447+
return self.bind(tools=formatted_tools, **kwargs)
395448

396449
if isinstance(__model, type):
397-
# Patch the class
398-
__model.bind_tools = PatchedModel.bind_tools
450+
# Return the patched class
451+
return PatchedModel
399452
else:
400-
# Patch the instance (pydantic is weird)
401-
object.__setattr__(__model, "bind_tools", MethodType(PatchedModel.bind_tools, __model))
402-
403-
return __model
453+
# Patch the instance
454+
__model.__class__ = PatchedModel
455+
return __model

0 commit comments

Comments
 (0)