diff --git a/libs/core/langchain_core/tools.py b/libs/core/langchain_core/tools.py index 2fa65902092cc..94ceb0d709046 100644 --- a/libs/core/langchain_core/tools.py +++ b/libs/core/langchain_core/tools.py @@ -1055,10 +1055,12 @@ def add(a: int, b: int) -> int: ) +# TODO: Type args_schema as TypeBaseModel if we can get mypy to correctly recognize +# pydantic v2 BaseModel classes. def tool( *args: Union[str, Callable, Runnable], return_direct: bool = False, - args_schema: Optional[Type[BaseModel]] = None, + args_schema: Optional[Type] = None, infer_schema: bool = True, response_format: Literal["content", "content_and_artifact"] = "content", parse_docstring: bool = False, diff --git a/libs/core/langchain_core/utils/pydantic.py b/libs/core/langchain_core/utils/pydantic.py index e1ed53d1ed45b..fb2f3488e0c5b 100644 --- a/libs/core/langchain_core/utils/pydantic.py +++ b/libs/core/langchain_core/utils/pydantic.py @@ -29,8 +29,6 @@ def get_pydantic_major_version() -> int: PydanticBaseModel = pydantic.BaseModel TypeBaseModel = Type[BaseModel] elif PYDANTIC_MAJOR_VERSION == 2: - from pydantic.v1 import BaseModel # pydantic: ignore - # Union type needs to be last assignment to PydanticBaseModel to make mypy happy. PydanticBaseModel = Union[BaseModel, pydantic.BaseModel] # type: ignore TypeBaseModel = Union[Type[BaseModel], Type[pydantic.BaseModel]] # type: ignore @@ -199,12 +197,12 @@ def _create_subset_model_v1( def _create_subset_model_v2( name: str, - model: Type[BaseModel], + model: Type[pydantic.BaseModel], field_names: List[str], *, descriptions: Optional[dict] = None, fn_description: Optional[str] = None, -) -> Type[BaseModel]: +) -> Type[pydantic.BaseModel]: """Create a pydantic model with a subset of the model fields.""" from pydantic import create_model # pydantic: ignore from pydantic.fields import FieldInfo # pydantic: ignore @@ -214,10 +212,10 @@ def _create_subset_model_v2( for field_name in field_names: field = model.model_fields[field_name] # type: ignore description = descriptions_.get(field_name, field.description) - fields[field_name] = ( - field.annotation, - FieldInfo(description=description, default=field.default), - ) + field_info = FieldInfo(description=description, default=field.default) + if field.metadata: + field_info.metadata = field.metadata + fields[field_name] = (field.annotation, field_info) rtn = create_model(name, **fields) # type: ignore rtn.__doc__ = textwrap.dedent(fn_description or model.__doc__ or "") @@ -230,7 +228,7 @@ def _create_subset_model_v2( # However, can't find a way to type hint this. def _create_subset_model( name: str, - model: Type[BaseModel], + model: TypeBaseModel, field_names: List[str], *, descriptions: Optional[dict] = None, diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index 0828066d99a7d..19b24b041a000 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -1863,3 +1863,41 @@ class ModelD(ModelC, Generic[D]): } actual = _get_all_basemodel_annotations(ModelD[int]) assert actual == expected + + +@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Testing pydantic v2.") +def test_tool_args_schema_pydantic_v2_with_metadata() -> None: + from pydantic import BaseModel as BaseModelV2 # pydantic: ignore + from pydantic import Field as FieldV2 # pydantic: ignore + from pydantic import ValidationError as ValidationErrorV2 # pydantic: ignore + + class Foo(BaseModelV2): + x: List[int] = FieldV2( + description="List of integers", min_length=10, max_length=15 + ) + + @tool(args_schema=Foo) + def foo(x): # type: ignore[no-untyped-def] + """foo""" + return x + + assert foo.tool_call_schema.schema() == { + "description": "foo", + "properties": { + "x": { + "description": "List of integers", + "items": {"type": "integer"}, + "maxItems": 15, + "minItems": 10, + "title": "X", + "type": "array", + } + }, + "required": ["x"], + "title": "foo", + "type": "object", + } + + assert foo.invoke({"x": [0] * 10}) + with pytest.raises(ValidationErrorV2): + foo.invoke({"x": [0] * 9}) diff --git a/libs/core/tests/unit_tests/utils/test_pydantic.py b/libs/core/tests/unit_tests/utils/test_pydantic.py index 57f4538ed5a59..3f0df33a89b04 100644 --- a/libs/core/tests/unit_tests/utils/test_pydantic.py +++ b/libs/core/tests/unit_tests/utils/test_pydantic.py @@ -1,10 +1,13 @@ """Test for some custom pydantic decorators.""" -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional + +import pytest from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.utils.pydantic import ( PYDANTIC_MAJOR_VERSION, + _create_subset_model_v2, is_basemodel_instance, is_basemodel_subclass, pre_init, @@ -121,3 +124,32 @@ class Bar(BaseModelV1): assert is_basemodel_instance(Bar(x=5)) else: raise ValueError(f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}") + + +@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Only tests Pydantic v2") +def test_with_field_metadata() -> None: + """Test pydantic with field metadata""" + from pydantic import BaseModel as BaseModelV2 # pydantic: ignore + from pydantic import Field as FieldV2 # pydantic: ignore + + class Foo(BaseModelV2): + x: List[int] = FieldV2( + description="List of integers", min_length=10, max_length=15 + ) + + subset_model = _create_subset_model_v2("Foo", Foo, ["x"]) + assert subset_model.model_json_schema() == { + "properties": { + "x": { + "description": "List of integers", + "items": {"type": "integer"}, + "maxItems": 15, + "minItems": 10, + "title": "X", + "type": "array", + } + }, + "required": ["x"], + "title": "Foo", + "type": "object", + } diff --git a/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py b/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py index 32d7af5f5f7a5..ffec9dc90bc46 100644 --- a/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py +++ b/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py @@ -18,6 +18,8 @@ from langchain_core.prompts import ChatPromptTemplate from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.tools import tool +from pydantic import BaseModel as RawBaseModel +from pydantic import Field as RawField from langchain_standard_tests.unit_tests.chat_models import ( ChatModelTests, @@ -26,7 +28,11 @@ from langchain_standard_tests.utils.pydantic import PYDANTIC_MAJOR_VERSION -@tool +class MagicFunctionSchema(RawBaseModel): + input: int = RawField(..., gt=-1000, lt=1000) + + +@tool(args_schema=MagicFunctionSchema) def magic_function(input: int) -> int: """Applies a magic function to an input.""" return input + 2