Skip to content

Commit 46f7ee4

Browse files
authored
Follow up to #1453: allow user roles when normalizing a dictionary (#1495)
1 parent 36a5308 commit 46f7ee4

File tree

3 files changed

+43
-10
lines changed

3 files changed

+43
-10
lines changed

shiny/ui/_chat_normalize.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,15 @@ def can_normalize_chunk(self, chunk: Any) -> bool:
5757

5858
class DictNormalizer(BaseMessageNormalizer):
5959
def normalize(self, message: Any) -> ChatMessage:
60-
x = self._check_dict(message)
60+
x = cast("dict[str, Any]", message)
61+
if "content" not in x:
62+
raise ValueError("Message must have 'content' key")
6163
return ChatMessage(content=x["content"], role=x.get("role", "assistant"))
6264

6365
def normalize_chunk(self, chunk: Any) -> ChatMessage:
64-
x = self._check_dict(chunk)
66+
x = cast("dict[str, Any]", chunk)
67+
if "content" not in x:
68+
raise ValueError("Message must have 'content' key")
6569
return ChatMessage(content=x["content"], role=x.get("role", "assistant"))
6670

6771
def can_normalize(self, message: Any) -> bool:
@@ -70,14 +74,6 @@ def can_normalize(self, message: Any) -> bool:
7074
def can_normalize_chunk(self, chunk: Any) -> bool:
7175
return isinstance(chunk, dict)
7276

73-
@staticmethod
74-
def _check_dict(x: Any) -> "dict[str, Any]":
75-
if "content" not in x:
76-
raise ValueError("Message must have 'content' key")
77-
if "role" in x and x["role"] not in ["assistant", "system"]:
78-
raise ValueError("Role must be 'assistant' or 'system")
79-
return x
80-
8177

8278
class LangChainNormalizer(BaseMessageNormalizer):
8379
def normalize(self, message: Any) -> ChatMessage:
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from shiny import reactive
2+
from shiny.express import render, ui
3+
4+
chat = ui.Chat(id="chat")
5+
chat.ui()
6+
7+
8+
@reactive.effect
9+
async def _():
10+
await chat.append_message({"content": "A user message", "role": "user"})
11+
12+
13+
"chat.messages():"
14+
15+
16+
@render.code
17+
def message_state():
18+
return str(chat.messages())
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from playwright.sync_api import Page, expect
2+
3+
from shiny.playwright import controller
4+
from shiny.run import ShinyAppProc
5+
6+
7+
def test_validate_chat_append_user_message(page: Page, local_app: ShinyAppProc) -> None:
8+
page.goto(local_app.url)
9+
10+
chat = controller.Chat(page, "chat")
11+
12+
# Verify starting state
13+
expect(chat.loc).to_be_visible()
14+
chat.expect_latest_message("A user message")
15+
16+
# Verify that the message state is as expected
17+
message_state = controller.OutputCode(page, "message_state")
18+
message_state_expected = ({"content": "A user message", "role": "user"},)
19+
message_state.expect_value(str(message_state_expected))

0 commit comments

Comments
 (0)