|
10 | 10 | import typing as t
|
11 | 11 | import uuid
|
12 | 12 | from contextvars import copy_context
|
13 |
| -from types import MethodType |
14 | 13 |
|
15 | 14 | from langchain_core.callbacks import (
|
16 | 15 | AsyncCallbackManager,
|
|
38 | 37 | _handle_tool_error,
|
39 | 38 | _handle_validation_error,
|
40 | 39 | )
|
| 40 | +from langchain_core.utils.function_calling import convert_to_openai_tool |
41 | 41 | from pydantic import PrivateAttr, ValidationError, model_validator
|
42 | 42 |
|
43 | 43 | from .. import _types as ts
|
@@ -346,6 +346,48 @@ async def arun( # noqa: C901
|
346 | 346 | ChatModel = t.TypeVar("ChatModel", bound=BaseChatModel)
|
347 | 347 |
|
348 | 348 |
|
| 349 | +def _validate_tool_choice( |
| 350 | + choice: t.Union[dict, str, t.Literal["auto", "any", "none"], bool], |
| 351 | + tools: t.List[BaseTool], |
| 352 | + schema_list: t.List[t.Dict[str, t.Any]], |
| 353 | +): |
| 354 | + if choice == "any": |
| 355 | + if len(tools) > 1: |
| 356 | + raise ValueError( |
| 357 | + f"Groq does not currently support {choice=}. Should " |
| 358 | + f"be one of 'auto', 'none', or the name of the tool to call." |
| 359 | + ) |
| 360 | + else: |
| 361 | + choice = convert_to_openai_tool(tools[0])["function"]["name"] |
| 362 | + if isinstance(choice, str) and (choice not in ("auto", "any", "none")): |
| 363 | + choice = {"type": "function", "function": {"name": choice}} |
| 364 | + # TODO: Remove this update once 'any' is supported. |
| 365 | + if isinstance(choice, dict) and (len(schema_list) != 1): |
| 366 | + raise ValueError( |
| 367 | + "When specifying `tool_choice`, you must provide exactly one " |
| 368 | + f"tool. Received {len(schema_list)} tools." |
| 369 | + ) |
| 370 | + if isinstance(choice, dict) and ( |
| 371 | + schema_list[0]["function"]["name"] != choice["function"]["name"] |
| 372 | + ): |
| 373 | + raise ValueError( |
| 374 | + f"Tool choice {choice} was specified, but the only " |
| 375 | + f"provided tool was {schema_list[0]['function']['name']}." |
| 376 | + ) |
| 377 | + if isinstance(choice, bool): |
| 378 | + if len(tools) > 1: |
| 379 | + raise ValueError( |
| 380 | + "tool_choice can only be True when there is one tool. Received " |
| 381 | + f"{len(tools)} tools." |
| 382 | + ) |
| 383 | + tool_name = schema_list[0]["function"]["name"] |
| 384 | + choice = { |
| 385 | + "type": "function", |
| 386 | + "function": {"name": tool_name}, |
| 387 | + } |
| 388 | + return choice |
| 389 | + |
| 390 | + |
349 | 391 | @t.overload
|
350 | 392 | def patch_chat_model(__model: ChatModel) -> ChatModel:
|
351 | 393 | """
|
@@ -379,25 +421,35 @@ def patch_chat_model(__model: type[ChatModel]) -> type[ChatModel]:
|
379 | 421 |
|
380 | 422 |
|
381 | 423 | def patch_chat_model(__model: t.Union[ChatModel, type[ChatModel]]):
|
382 |
| - class PatchedModel(BaseChatModel): |
| 424 | + chat_model_cls = __model if isinstance(__model, type) else __model.__class__ |
| 425 | + |
| 426 | + class PatchedModel(chat_model_cls): |
383 | 427 | def bind_tools(
|
384 | 428 | self,
|
385 |
| - tools: t.Sequence[t.Any], |
| 429 | + tools: t.Sequence[t.Union[BaseTool, ExtendedStructuredTool]], |
386 | 430 | **kwargs: t.Any,
|
387 | 431 | ) -> Runnable[LanguageModelInput, BaseMessage]:
|
388 |
| - schema_list = [] |
| 432 | + formatted_tools = [] |
389 | 433 | for tool in tools:
|
390 | 434 | if isinstance(tool, ExtendedStructuredTool):
|
391 |
| - schema_list.append(tool.json_schema) |
| 435 | + formatted_tools.append(tool.json_schema) |
392 | 436 | else:
|
393 |
| - schema_list.extend(super().bind_tools(tools=[tool], **kwargs)) |
394 |
| - return self.bind(tools=schema_list, **kwargs) |
| 437 | + # Use the original bind_tools method for builtin tool types |
| 438 | + formatted_tools.extend( |
| 439 | + super().bind_tools(tools=[tool], **kwargs).kwargs["tools"] |
| 440 | + ) |
| 441 | + |
| 442 | + if tool_choice := kwargs.get("tool_choice", None): |
| 443 | + kwargs["tool_choice"] = _validate_tool_choice( |
| 444 | + choice=tool_choice, tools=tools, schema_list=formatted_tools |
| 445 | + ) |
| 446 | + |
| 447 | + return self.bind(tools=formatted_tools, **kwargs) |
395 | 448 |
|
396 | 449 | if isinstance(__model, type):
|
397 |
| - # Patch the class |
398 |
| - __model.bind_tools = PatchedModel.bind_tools |
| 450 | + # Return the patched class |
| 451 | + return PatchedModel |
399 | 452 | else:
|
400 |
| - # Patch the instance (pydantic is weird) |
401 |
| - object.__setattr__(__model, "bind_tools", MethodType(PatchedModel.bind_tools, __model)) |
402 |
| - |
403 |
| - return __model |
| 453 | + # Patch the instance |
| 454 | + __model.__class__ = PatchedModel |
| 455 | + return __model |
0 commit comments