Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow callbacks after append and stream #6805

Merged
merged 4 commits into from
May 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/reference/chat/ChatFeed.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
"* **`placeholder_text`** (str): The text to display next to the placeholder icon.\n",
"* **`placeholder_params`** (dict) Defaults to `{\"user\": \" \", \"reaction_icons\": {}, \"show_copy_icon\": False, \"show_timestamp\": False}` Params to pass to the placeholder `ChatMessage`, like `reaction_icons`, `timestamp_format`, `show_avatar`, `show_user`, `show_timestamp`.\n",
"* **`placeholder_threshold`** (float): Min duration in seconds of buffering before displaying the placeholder. If 0, the placeholder will be disabled. Defaults to 0.2.\n",
"* **`post_hook`** (callable): A hook to execute after a new message is *completely* added, i.e. the generator is exhausted. The `stream` method will trigger this callback on every call. The signature must include the `message` and `instance` arguments.\n",
"* **`auto_scroll_limit`** (int): Max pixel distance from the latest object in the Column to activate automatic scrolling upon update. Setting to 0 disables auto-scrolling.\n",
"* **`scroll_button_threshold`** (int): Min pixel distance from the latest object in the Column to display the scroll button. Setting to 0 disables the scroll button.\n",
"* **`load_buffer`** (int): The number of objects loaded on each side of the visible objects. When scrolled halfway into the buffer, the feed will automatically load additional objects while unloading objects on the opposite side.\n",
Expand Down
39 changes: 33 additions & 6 deletions panel/chat/feed.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class ChatFeed(ListPanel):
auto_scroll_limit = param.Integer(default=200, bounds=(0, None), doc="""
Max pixel distance from the latest object in the Column to
activate automatic scrolling upon update. Setting to 0
disables auto-scrolling.""",)
disables auto-scrolling.""")

callback = param.Callable(allow_refs=False, doc="""
Callback to execute when a user sends a message or
Expand Down Expand Up @@ -133,6 +133,11 @@ class ChatFeed(ListPanel):
`help` as the user. This is useful for providing instructions,
and will not be included in the `serialize` method by default.""")

load_buffer = param.Integer(default=50, bounds=(0, None), doc="""
The number of objects loaded on each side of the visible objects.
When scrolled halfway into the buffer, the feed will automatically
load additional objects while unloading objects on the opposite side.""")

placeholder_text = param.String(default="", doc="""
The text to display next to the placeholder icon.""")

Expand All @@ -148,18 +153,19 @@ class ChatFeed(ListPanel):
Min duration in seconds of buffering before displaying the placeholder.
If 0, the placeholder will be disabled.""")

post_hook = param.Callable(allow_refs=False, doc="""
A hook to execute after a new message is *completely* added,
i.e. the generator is exhausted. The `stream` method will trigger
this callback on every call. The signature must include the
`message` and `instance` arguments.""")

renderers = param.HookList(doc="""
A callable or list of callables that accept the value and return a
Panel object to render the value. If a list is provided, will
attempt to use the first renderer that does not raise an
exception. If None, will attempt to infer the renderer
from the value.""")

load_buffer = param.Integer(default=50, bounds=(0, None), doc="""
The number of objects loaded on each side of the visible objects.
When scrolled halfway into the buffer, the feed will automatically
load additional objects while unloading objects on the opposite side.""")

scroll_button_threshold = param.Integer(default=100, bounds=(0, None),doc="""
Min pixel distance from the latest object in the Column to
display the scroll button. Setting to 0
Expand All @@ -182,6 +188,8 @@ class ChatFeed(ListPanel):

_callback_trigger = param.Event(doc="Triggers the callback to respond.")

_post_hook_trigger = param.Event(doc="Triggers the append callback.")

_disabled_stack = param.List(doc="""
The previous disabled state of the feed.""")

Expand Down Expand Up @@ -262,6 +270,7 @@ def __init__(self, *objects, **params):

# handle async callbacks using this trick
self.param.watch(self._prepare_response, '_callback_trigger')
self.param.watch(self._after_append_completed, '_post_hook_trigger')

def _get_model(
self, doc: Document, root: Model | None = None,
Expand Down Expand Up @@ -430,6 +439,7 @@ async def _serialize_response(self, response: Any) -> ChatMessage | None:
response_message = self._upsert_message(await response, response_message)
else:
response_message = self._upsert_message(response, response_message)
self.param.trigger("_post_hook_trigger")
finally:
if response_message:
response_message.show_activity_dot = False
Expand Down Expand Up @@ -484,6 +494,7 @@ async def _handle_callback(self, message, loop: asyncio.BaseEventLoop):
else:
response = await asyncio.to_thread(self.callback, *callback_args)
await self._serialize_response(response)
return response

async def _prepare_response(self, *_) -> None:
"""
Expand Down Expand Up @@ -580,6 +591,7 @@ def send(
value = {"object": value}
message = self._build_message(value, user=user, avatar=avatar)
self.append(message)
self.param.trigger("_post_hook_trigger")
if respond:
self.respond()
return message
Expand Down Expand Up @@ -644,6 +656,8 @@ def stream(
value = {"object": value}
message = self._build_message(value, user=user, avatar=avatar)
self._replace_placeholder(message)

self.param.trigger("_post_hook_trigger")
return message

def respond(self):
Expand Down Expand Up @@ -758,6 +772,19 @@ def _serialize_for_transformers(
serialized_messages.append({"role": role, "content": content})
return serialized_messages

async def _after_append_completed(self, message):
"""
Trigger the append callback after a message is added to the chat feed.
"""
if self.post_hook is None:
return

message = self._chat_log.objects[-1]
if iscoroutinefunction(self.post_hook):
await self.post_hook(message, self)
else:
self.post_hook(message, self)

def serialize(
self,
exclude_users: List[str] | None = None,
Expand Down
82 changes: 82 additions & 0 deletions panel/tests/chat/test_feed.py
Original file line number Diff line number Diff line change
Expand Up @@ -997,3 +997,85 @@ def test_invalid(self):
chat_feed = ChatFeed()
chat_feed.send("I'm a user", user="user")
chat_feed.serialize(format="atransform")


@pytest.mark.xdist_group("chat")
class TestChatFeedPostHook:

def test_return_string(self, chat_feed):
def callback(contents, user, instance):
yield f"Echo: {contents}"

def append_callback(message, instance):
logs.append(message.object)

logs = []
chat_feed.callback = callback
chat_feed.post_hook = append_callback
chat_feed.send("Hello World!")
wait_until(lambda: chat_feed.objects[-1].object == "Echo: Hello World!")
assert logs == ["Hello World!", "Echo: Hello World!"]

def test_yield_string(self, chat_feed):
def callback(contents, user, instance):
yield f"Echo: {contents}"

def append_callback(message, instance):
logs.append(message.object)

logs = []
chat_feed.callback = callback
chat_feed.post_hook = append_callback
chat_feed.send("Hello World!")
wait_until(lambda: chat_feed.objects[-1].object == "Echo: Hello World!")
assert logs == ["Hello World!", "Echo: Hello World!"]

def test_generator(self, chat_feed):
def callback(contents, user, instance):
message = "Echo: "
for char in contents:
message += char
yield message

def append_callback(message, instance):
logs.append(message.object)

logs = []
chat_feed.callback = callback
chat_feed.post_hook = append_callback
chat_feed.send("Hello World!")
wait_until(lambda: chat_feed.objects[-1].object == "Echo: Hello World!")
assert logs == ["Hello World!", "Echo: Hello World!"]

def test_async_generator(self, chat_feed):
async def callback(contents, user, instance):
message = "Echo: "
for char in contents:
message += char
yield message

async def append_callback(message, instance):
logs.append(message.object)

logs = []
chat_feed.callback = callback
chat_feed.post_hook = append_callback
chat_feed.send("Hello World!")
wait_until(lambda: chat_feed.objects[-1].object == "Echo: Hello World!")
assert logs == ["Hello World!", "Echo: Hello World!"]

def test_stream(self, chat_feed):
def callback(contents, user, instance):
message = instance.stream("Echo: ")
for char in contents:
message = instance.stream(char, message=message)

def append_callback(message, instance):
logs.append(message.object)

logs = []
chat_feed.callback = callback
chat_feed.post_hook = append_callback
chat_feed.send("AB")
wait_until(lambda: chat_feed.objects[-1].object == "Echo: AB")
assert logs == ["AB", "Echo: ", "Echo: AB"]
Loading