Skip to content

Commit 9a65c12

Browse files
aymeric-roucherelvircrn
authored andcommitted
Tool calling: support more types (huggingface#35776)
* Tool calling: support NoneType for function return type
1 parent 83d2a94 commit 9a65c12

File tree

2 files changed

+27
-0
lines changed

2 files changed

+27
-0
lines changed

src/transformers/utils/chat_template_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from contextlib import contextmanager
2020
from datetime import datetime
2121
from functools import lru_cache
22+
from types import NoneType
2223
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, get_args, get_origin, get_type_hints
2324

2425
from packaging import version
@@ -77,6 +78,7 @@ def _get_json_schema_type(param_type: str) -> Dict[str, str]:
7778
float: {"type": "number"},
7879
str: {"type": "string"},
7980
bool: {"type": "boolean"},
81+
NoneType: {"type": "null"},
8082
Any: {},
8183
}
8284
if is_vision_available():

tests/utils/test_chat_template_utils.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,31 @@ def fn(x: int, y: int):
419419

420420
self.assertEqual(schema["function"], expected_schema)
421421

422+
def test_return_none(self):
423+
def fn(x: int) -> None:
424+
"""
425+
Test function
426+
427+
Args:
428+
x: The first input
429+
"""
430+
pass
431+
432+
schema = get_json_schema(fn)
433+
expected_schema = {
434+
"name": "fn",
435+
"description": "Test function",
436+
"parameters": {
437+
"type": "object",
438+
"properties": {
439+
"x": {"type": "integer", "description": "The first input"},
440+
},
441+
"required": ["x"],
442+
},
443+
"return": {"type": "null"},
444+
}
445+
self.assertEqual(schema["function"], expected_schema)
446+
422447
def test_everything_all_at_once(self):
423448
def fn(
424449
x: str, y: Optional[List[Union[str, int]]], z: Tuple[Union[str, int], str] = (42, "hello")

0 commit comments

Comments
 (0)