Skip to content

Commit

Permalink
Python: fix for kernel function type_object when not using Annotation…
Browse files Browse the repository at this point in the history
…s. Add more unit tests. (#7338)

### Motivation and Context

The current kernel function decorator class parses function params, but
had a miss related to how the underlying param's type is parsed if not
using Annotations.

<!-- Thank you for your contribution to the semantic-kernel repo!
Please help reviewers and future users, providing the following
information:
  1. Why is this change required?
  2. What problem does it solve?
  3. What scenario does it contribute to?
  4. If it fixes an open issue, please link to the issue here.
-->

### Description

This PR properly handles the type_object for params that don't using
typing.Annotations.
- Adds unit tests to exercise the code.
- Adds more kernel.py unit tests for near 100% test coverage.
- Adds some planner tests to improve the test coverage.
- Closes #7300 

<!-- Describe your changes, the overall approach, the underlying design.
These notes will help understanding how your code works. Thanks! -->

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [X] The code builds clean without any errors or warnings
- [X] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [X] All unit tests pass, and I have added new tests where possible
- [X] I didn't break anyone 😄
  • Loading branch information
moonbox3 authored Jul 18, 2024
1 parent f0e1e8c commit 6b01468
Show file tree
Hide file tree
Showing 9 changed files with 759 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
from semantic_kernel.contents.chat_message_content import ChatMessageContent
from semantic_kernel.contents.function_call_content import FunctionCallContent
from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent
from semantic_kernel.core_plugins import MathPlugin, TimePlugin
from semantic_kernel.core_plugins.math_plugin import MathPlugin
from semantic_kernel.core_plugins.time_plugin import TimePlugin
from semantic_kernel.functions import KernelArguments

if TYPE_CHECKING:
Expand Down
7 changes: 5 additions & 2 deletions python/semantic_kernel/functions/kernel_function_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import types
from collections.abc import Callable
from inspect import Parameter, Signature, isasyncgenfunction, isclass, isgeneratorfunction, signature
from typing import Any, ForwardRef, Union, get_args
from typing import Annotated, Any, ForwardRef, Union, get_args, get_origin

NoneType = type(None)
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -103,7 +103,10 @@ def _process_signature(func_sig: Signature) -> list[dict[str, Any]]:
annotation = arg.annotation
default = arg.default if arg.default != arg.empty else None
parsed_annotation = _parse_parameter(arg.name, annotation, default)
underlying_type = _get_underlying_type(annotation)
if get_origin(annotation) is Annotated or get_origin(annotation) in {Union, types.UnionType}:
underlying_type = _get_underlying_type(annotation)
else:
underlying_type = annotation
parsed_annotation["type_object"] = underlying_type
annotations.append(parsed_annotation)

Expand Down
48 changes: 30 additions & 18 deletions python/semantic_kernel/functions/kernel_function_from_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,30 @@ async def _invoke_internal_stream(self, context: FunctionInvocationContext) -> N
function_arguments = self.gather_function_parameters(context)
context.result = FunctionResult(function=self.metadata, value=self.stream_method(**function_arguments))

def _parse_parameter(self, value: Any, param_type: Any) -> Any:
"""Parses the value into the specified param_type, including handling lists of types."""
if isinstance(param_type, type) and hasattr(param_type, "model_validate"):
try:
return param_type.model_validate(value)
except Exception as exc:
raise FunctionExecutionException(
f"Parameter is expected to be parsed to {param_type} but is not."
) from exc
elif hasattr(param_type, "__origin__") and param_type.__origin__ is list:
if isinstance(value, list):
item_type = param_type.__args__[0]
return [self._parse_parameter(item, item_type) for item in value]
raise FunctionExecutionException(f"Expected a list for {param_type}, but got {type(value)}")
else:
try:
if isinstance(value, dict) and hasattr(param_type, "__init__"):
return param_type(**value)
return param_type(value)
except Exception as exc:
raise FunctionExecutionException(
f"Parameter is expected to be parsed to {param_type} but is not."
) from exc

def gather_function_parameters(self, context: FunctionInvocationContext) -> dict[str, Any]:
"""Gathers the function parameters from the arguments."""
function_arguments: dict[str, Any] = {}
Expand All @@ -147,24 +171,12 @@ def gather_function_parameters(self, context: FunctionInvocationContext) -> dict
and param.type_object
and param.type_object is not inspect._empty
):
if hasattr(param.type_object, "model_validate"):
try:
value = param.type_object.model_validate(value)
except Exception as exc:
raise FunctionExecutionException(
f"Parameter {param.name} is expected to be parsed to {param.type_} but is not."
) from exc
else:
try:
if isinstance(value, dict) and hasattr(param.type_object, "__init__"):
value = param.type_object(**value)
else:
value = param.type_object(value)
except Exception as exc:
raise FunctionExecutionException(
f"Parameter {param.name} is expected to be parsed to "
f"{param.type_object} but is not."
) from exc
try:
value = self._parse_parameter(value, param.type_object)
except Exception as exc:
raise FunctionExecutionException(
f"Parameter {param.name} is expected to be parsed to {param.type_object} but is not."
) from exc
function_arguments[param.name] = value
continue
if param.is_required:
Expand Down
21 changes: 21 additions & 0 deletions python/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@
import warnings
from collections.abc import Callable
from typing import TYPE_CHECKING
from unittest.mock import MagicMock

import pytest

from semantic_kernel.contents.function_call_content import FunctionCallContent

if TYPE_CHECKING:
from semantic_kernel.contents.chat_history import ChatHistory
from semantic_kernel.filters.functions.function_invocation_context import FunctionInvocationContext
Expand Down Expand Up @@ -134,6 +137,24 @@ async def _invoke_internal(self, context: "FunctionInvocationContext"):
return create_mock_function


@pytest.fixture(scope="function")
def get_tool_call_mock():
tool_call_mock = MagicMock(spec=FunctionCallContent)
tool_call_mock.split_name_dict.return_value = {"arg_name": "arg_value"}
tool_call_mock.to_kernel_arguments.return_value = {"arg_name": "arg_value"}
tool_call_mock.name = "test-function"
tool_call_mock.function_name = "function"
tool_call_mock.plugin_name = "test"
tool_call_mock.arguments = {"arg_name": "arg_value"}
tool_call_mock.ai_model_id = None
tool_call_mock.metadata = {}
tool_call_mock.index = 0
tool_call_mock.parse_arguments.return_value = {"arg_name": "arg_value"}
tool_call_mock.id = "test_id"

return tool_call_mock


@pytest.fixture(scope="function")
def chat_history() -> "ChatHistory":
from semantic_kernel.contents.chat_history import ChatHistory
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def func_return_type_annotated(self, input: str) -> Annotated[str, "test return"
return input

@kernel_function
def func_return_type_streaming(self, input: str) -> Annotated[AsyncGenerator[str, Any], "test return"]:
def func_return_type_streaming(self, input: str) -> Annotated[AsyncGenerator[str, Any], "test return"]: # type: ignore
yield input

@kernel_function
Expand Down
111 changes: 111 additions & 0 deletions python/tests/unit/functions/test_kernel_function_from_method.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
# Copyright (c) Microsoft. All rights reserved.
from collections.abc import AsyncGenerator, Iterable
from typing import Annotated, Any
from unittest.mock import Mock

import pytest

from semantic_kernel.connectors.ai.open_ai.services.open_ai_chat_completion import OpenAIChatCompletion
from semantic_kernel.exceptions import FunctionExecutionException, FunctionInitializationError
from semantic_kernel.filters.functions.function_invocation_context import FunctionInvocationContext
from semantic_kernel.filters.kernel_filters_extension import _rebuild_function_invocation_context
from semantic_kernel.functions.function_result import FunctionResult
from semantic_kernel.functions.kernel_arguments import KernelArguments
from semantic_kernel.functions.kernel_function import KernelFunction
Expand All @@ -16,6 +19,38 @@
from semantic_kernel.kernel_pydantic import KernelBaseModel


class CustomType(KernelBaseModel):
id: str
name: str


class CustomTypeNonPydantic:
id: str
name: str

def __init__(self, id: str, name: str):
self.id = id
self.name = name


@pytest.fixture
def get_custom_type_function_pydantic():
@kernel_function
def func_default(param: list[CustomType]):
return input

return KernelFunction.from_method(func_default, "test")


@pytest.fixture
def get_custom_type_function_nonpydantic():
@kernel_function
def func_default(param: list[CustomTypeNonPydantic]):
return input

return KernelFunction.from_method(func_default, "test")


def test_init_native_function_with_input_description():
@kernel_function(description="Mock description", name="mock_function")
def mock_function(input: Annotated[str, "input"], arguments: "KernelArguments") -> None:
Expand Down Expand Up @@ -449,3 +484,79 @@ def func_default(base: str, input: str = "test"):

res = await kernel.invoke(func, base="base")
assert str(res) == "test"


def test_parse_list_of_objects(get_custom_type_function_pydantic):
func = get_custom_type_function_pydantic

param_type = list[CustomType]
value = [{"id": "1", "name": "John"}, {"id": "2", "name": "Jane"}]
result = func._parse_parameter(value, param_type)
assert isinstance(result, list)
assert len(result) == 2
assert all(isinstance(item, CustomType) for item in result)


def test_parse_individual_object(get_custom_type_function_pydantic):
value = {"id": "2", "name": "Jane"}
func = get_custom_type_function_pydantic
result = func._parse_parameter(value, CustomType)
assert isinstance(result, CustomType)
assert result.id == "2"
assert result.name == "Jane"


def test_parse_non_list_raises_exception(get_custom_type_function_pydantic):
func = get_custom_type_function_pydantic
param_type = list[CustomType]
value = {"id": "2", "name": "Jane"}
with pytest.raises(FunctionExecutionException, match=r"Expected a list for .*"):
func._parse_parameter(value, param_type)


def test_parse_invalid_dict_raises_exception(get_custom_type_function_pydantic):
func = get_custom_type_function_pydantic
value = {"id": "1"}
with pytest.raises(FunctionExecutionException, match=r"Parameter is expected to be parsed to .*"):
func._parse_parameter(value, CustomType)


def test_parse_invalid_value_raises_exception(get_custom_type_function_pydantic):
func = get_custom_type_function_pydantic
value = "invalid_value"
with pytest.raises(FunctionExecutionException, match=r"Parameter is expected to be parsed to .*"):
func._parse_parameter(value, CustomType)


def test_parse_invalid_list_raises_exception(get_custom_type_function_pydantic):
func = get_custom_type_function_pydantic
param_type = list[CustomType]
value = ["invalid_value"]
with pytest.raises(FunctionExecutionException, match=r"Parameter is expected to be parsed to .*"):
func._parse_parameter(value, param_type)


def test_parse_dict_with_init_non_pydantic(get_custom_type_function_nonpydantic):
func = get_custom_type_function_nonpydantic
value = {"id": "3", "name": "Alice"}
result = func._parse_parameter(value, CustomTypeNonPydantic)
assert isinstance(result, CustomTypeNonPydantic)
assert result.id == "3"
assert result.name == "Alice"


def test_parse_invalid_dict_raises_exception_new(get_custom_type_function_nonpydantic):
func = get_custom_type_function_nonpydantic
value = {"wrong_key": "3", "name": "Alice"}
with pytest.raises(FunctionExecutionException, match=r"Parameter is expected to be parsed to .*"):
func._parse_parameter(value, CustomTypeNonPydantic)


def test_gather_function_parameters_exception_handling(get_custom_type_function_pydantic):
kernel = Mock(spec=Kernel) # Mock kernel
func = get_custom_type_function_pydantic
_rebuild_function_invocation_context()
context = FunctionInvocationContext(kernel=kernel, function=func, arguments=KernelArguments(param="test"))

with pytest.raises(FunctionExecutionException, match=r"Parameter param is expected to be parsed to .* but is not."):
func.gather_function_parameters(context)
Loading

0 comments on commit 6b01468

Please sign in to comment.