Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Python: Allow enums to be passed in to on_function_result. Improve handling FRC result so it can be hashed. #10316

Merged
merged 2 commits into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from semantic_kernel.contents.kernel_content import KernelContent
from semantic_kernel.contents.text_content import TextContent
from semantic_kernel.contents.utils.author_role import AuthorRole
from semantic_kernel.contents.utils.hashing import make_hashable
from semantic_kernel.exceptions.content_exceptions import ContentInitializationError

if TYPE_CHECKING:
Expand Down Expand Up @@ -194,10 +195,11 @@ def serialize_result(self, value: Any) -> str:

def __hash__(self) -> int:
"""Return the hash of the function result content."""
hashable_result = make_hashable(self.result)
return hash((
self.tag,
self.id,
tuple(self.result) if isinstance(self.result, list) else self.result,
hashable_result,
self.name,
self.function_name,
self.plugin_name,
Expand Down
25 changes: 25 additions & 0 deletions python/semantic_kernel/contents/utils/hashing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Copyright (c) Microsoft. All rights reserved.

from typing import Any

from pydantic import BaseModel


def make_hashable(input: Any) -> Any:
"""Recursively convert unhashable types to hashable equivalents.

Args:
input: The input to convert to a hashable type.

Returns:
Any: The input converted to a hashable type.
"""
if isinstance(input, dict):
return tuple(sorted((k, make_hashable(v)) for k, v in input.items()))
if isinstance(input, (list, set, tuple)):
# Convert lists, sets, and tuples to tuples so they can be hashed
return tuple(make_hashable(item) for item in input)
if isinstance(input, BaseModel):
# If obj is a Pydantic model, convert it to a dict and process
return make_hashable(input.model_dump())
return input # Return the input if it's already hashable
15 changes: 12 additions & 3 deletions python/semantic_kernel/processes/process_step_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,9 +189,18 @@ def build_step(self) -> "KernelProcessStepInfo":
# Return an instance of KernelProcessStepInfo with the built state and edges.
return KernelProcessStepInfo(inner_step_type=step_cls, state=state_object, output_edges=built_edges)

def on_function_result(self, function_name: str) -> "ProcessStepEdgeBuilder":
"""Creates a new ProcessStepEdgeBuilder for the function result."""
return self.on_event(f"{function_name}.OnResult")
def on_function_result(self, function_name: str | Enum) -> "ProcessStepEdgeBuilder":
"""Creates a new ProcessStepEdgeBuilder for the function result.

Args:
function_name: The function name as a string or Enum.

Returns:
ProcessStepEdgeBuilder: The ProcessStepEdgeBuilder instance.
"""
function_name_str: str = function_name.value if isinstance(function_name, Enum) else function_name

return self.on_event(f"{function_name_str}.OnResult")

def get_function_metadata_map(
self, plugin_type, name: str | None = None, kernel: "Kernel | None" = None
Expand Down
93 changes: 93 additions & 0 deletions python/tests/unit/contents/test_function_result_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from unittest.mock import Mock

import pytest
from pydantic import BaseModel, ConfigDict, Field

from semantic_kernel.contents.chat_message_content import ChatMessageContent
from semantic_kernel.contents.function_call_content import FunctionCallContent
Expand All @@ -12,6 +13,7 @@
from semantic_kernel.contents.text_content import TextContent
from semantic_kernel.functions.function_result import FunctionResult
from semantic_kernel.functions.kernel_function_metadata import KernelFunctionMetadata
from semantic_kernel.kernel_pydantic import KernelBaseModel


class CustomResultClass:
Expand All @@ -34,6 +36,17 @@ def __str__(self):
return f"CustomObjectWithList({self.items})"


class AccountBalanceFrozen(KernelBaseModel):
# Make the model frozen so it's hashable
balance: int = Field(..., alias="account_balance")
model_config = ConfigDict(frozen=True)


class AccountBalanceNonFrozen(KernelBaseModel):
# This model is not frozen and thus not hashable by default
balance: int = Field(..., alias="account_balance")


def test_init():
frc = FunctionResultContent(id="test", name="test-function", result="test-result", metadata={"test": "test"})
assert frc.name == "test-function"
Expand Down Expand Up @@ -124,3 +137,83 @@ def __str__(self) -> str:
frc.model_dump_json(exclude_none=True)
== """{"metadata":{},"content_type":"function_result","id":"test","result":"test","name":"test-function","function_name":"function","plugin_name":"test"}""" # noqa: E501
)


def test_hash_with_frozen_account_balance():
balance = AccountBalanceFrozen(account_balance=100)
content = FunctionResultContent(
id="test_id",
result=balance,
function_name="TestFunction",
)
_ = hash(content)
assert True, "Hashing FunctionResultContent with frozen model should not raise errors."


def test_hash_with_dict_result():
balance_dict = {"account_balance": 100}
content = FunctionResultContent(
id="test_id",
result=balance_dict,
function_name="TestFunction",
)
_ = hash(content)
assert True, "Hashing FunctionResultContent with dict result should not raise errors."


def test_hash_with_nested_dict_result():
nested_dict = {"account_balance": 100, "details": {"currency": "USD", "last_updated": "2025-01-28"}}
content = FunctionResultContent(
id="test_id_nested",
result=nested_dict,
function_name="TestFunctionNested",
)
_ = hash(content)
assert True, "Hashing FunctionResultContent with nested dict result should not raise errors."


def test_hash_with_list_result():
balance_list = [100, 200, 300]
content = FunctionResultContent(
id="test_id_list",
result=balance_list,
function_name="TestFunctionList",
)
_ = hash(content)
assert True, "Hashing FunctionResultContent with list result should not raise errors."


def test_hash_with_set_result():
balance_set = {100, 200, 300}
content = FunctionResultContent(
id="test_id_set",
result=balance_set,
function_name="TestFunctionSet",
)
_ = hash(content)
assert True, "Hashing FunctionResultContent with set result should not raise errors."


def test_hash_with_custom_object_result():
class CustomObject(BaseModel):
field1: str
field2: int

custom_obj = CustomObject(field1="value1", field2=42)
content = FunctionResultContent(
id="test_id_custom",
result=custom_obj,
function_name="TestFunctionCustom",
)
_ = hash(content)
assert True, "Hashing FunctionResultContent with custom object result should not raise errors."


def test_unhashable_non_frozen_model_raises_type_error():
balance = AccountBalanceNonFrozen(account_balance=100)
content = FunctionResultContent(
id="test_id_unhashable",
result=balance,
function_name="TestFunctionUnhashable",
)
_ = hash(content)
26 changes: 26 additions & 0 deletions python/tests/unit/processes/test_process_step_builder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) Microsoft. All rights reserved.

from enum import Enum
from unittest.mock import MagicMock

import pytest
Expand All @@ -14,6 +15,10 @@
from semantic_kernel.processes.process_step_edge_builder import ProcessStepEdgeBuilder


class TestFunctionEnum(Enum):
MY_FUNCTION = "my_function"


class MockKernelProcessStep(KernelProcessStep):
"""A mock class to use as a step type."""

Expand Down Expand Up @@ -177,3 +182,24 @@ def test_link_to_multiple_edges():

# Assert
assert step_builder.edges[event_id] == [edge_builder_1, edge_builder_2]


@pytest.mark.parametrize(
"function_name, expected_function_name",
[
("my_function", "my_function"),
(TestFunctionEnum.MY_FUNCTION, TestFunctionEnum.MY_FUNCTION.value),
],
)
def test_on_function_result(function_name, expected_function_name):
# Arrange
name = "test_step"
step_builder = ProcessStepBuilder(name=name)

# Act
edge_builder = step_builder.on_function_result(function_name=function_name)

# Assert
assert isinstance(edge_builder, ProcessStepEdgeBuilder)
assert edge_builder.source == step_builder
assert edge_builder.event_id == f"{step_builder.event_namespace}.{expected_function_name}.OnResult"
Loading