diff --git a/panel/chat/feed.py b/panel/chat/feed.py index 147c04ef7bb..0536bf22bd7 100644 --- a/panel/chat/feed.py +++ b/panel/chat/feed.py @@ -10,8 +10,8 @@ from enum import Enum from inspect import ( - isasyncgen, isasyncgenfunction, isawaitable, iscoroutinefunction, - isgenerator, isgeneratorfunction, + getfullargspec, isasyncgen, isasyncgenfunction, isawaitable, + iscoroutinefunction, isgenerator, isgeneratorfunction, ) from io import BytesIO from typing import ( @@ -430,7 +430,15 @@ def _gather_callback_args(self, message: ChatMessage) -> Any: contents = value.value else: contents = value - return contents, message.user, self + + callback_args = getfullargspec(self.callback).args + if len(callback_args) > 3: + raise ValueError('Function should have at most 3 arguments') + elif len(callback_args) == 0: + raise ValueError('Function should have at least one argument') + + input_args = (contents, message.user, self) + return input_args[:len(callback_args)] async def _serialize_response(self, response: Any) -> ChatMessage | None: """ diff --git a/panel/tests/chat/test_feed.py b/panel/tests/chat/test_feed.py index a696a01c063..458c54c9c58 100644 --- a/panel/tests/chat/test_feed.py +++ b/panel/tests/chat/test_feed.py @@ -888,6 +888,59 @@ def callback(contents, user, instance): assert feed.objects[-1].object == "helloooo" assert chat_feed._placeholder not in chat_feed._chat_log + def test_callback_one_argument(self, chat_feed): + def callback(contents): + return contents + + chat_feed.callback = callback + chat_feed.send("Message", respond=True) + wait_until(lambda: len(chat_feed.objects) == 2) + assert chat_feed.objects[1].object == "Message" + + def test_callback_two_arguments(self, chat_feed): + def callback(contents, user): + return f"{user}: {contents}" + + chat_feed.callback = callback + chat_feed.send("Message", respond=True) + wait_until(lambda: len(chat_feed.objects) == 2) + assert chat_feed.objects[1].object == "User: Message" + + def test_callback_two_arguments_with_keyword(self, chat_feed): + def callback(contents, user=None): + return f"{user}: {contents}" + + chat_feed.callback = callback + chat_feed.send("Message", respond=True) + wait_until(lambda: len(chat_feed.objects) == 2) + assert chat_feed.objects[1].object == "User: Message" + + def test_callback_three_arguments_with_keyword(self, chat_feed): + def callback(contents, user=None, instance=None): + return f"{user}: {contents}" + + chat_feed.callback = callback + chat_feed.send("Message", respond=True) + wait_until(lambda: len(chat_feed.objects) == 2) + assert chat_feed.objects[1].object == "User: Message" + + def test_callback_two_arguments_yield(self, chat_feed): + def callback(contents, user): + yield f"{user}: {contents}" + + chat_feed.callback = callback + chat_feed.send("Message", respond=True) + wait_until(lambda: len(chat_feed.objects) == 2) + assert chat_feed.objects[1].object == "User: Message" + + def test_callback_two_arguments_async_yield(self, chat_feed): + async def callback(contents, user): + yield f"{user}: {contents}" + + chat_feed.callback = callback + chat_feed.send("Message", respond=True) + wait_until(lambda: len(chat_feed.objects) == 2) + assert chat_feed.objects[1].object == "User: Message" @pytest.mark.xdist_group("chat") class TestChatFeedSerializeForTransformers: