Skip to content

Commit

Permalink
feat/add ask user action (#473)
Browse files Browse the repository at this point in the history
Co-authored-by: Clément Sirieix <clementsirieix@clements-mbp.home>
  • Loading branch information
clementsirieix and Clément Sirieix authored Oct 13, 2023
1 parent 942a658 commit 9fd2804
Show file tree
Hide file tree
Showing 8 changed files with 176 additions and 23 deletions.
9 changes: 8 additions & 1 deletion backend/chainlit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,13 @@
Video,
)
from chainlit.logger import logger
from chainlit.message import AskFileMessage, AskUserMessage, ErrorMessage, Message
from chainlit.message import (
AskActionMessage,
AskFileMessage,
AskUserMessage,
ErrorMessage,
Message,
)
from chainlit.oauth_providers import get_configured_oauth_providers
from chainlit.sync import make_async, run_sync
from chainlit.telemetry import trace
Expand Down Expand Up @@ -317,6 +323,7 @@ def sleep(duration: int):
"Message",
"ErrorMessage",
"AskUserMessage",
"AskActionMessage",
"AskFileMessage",
"on_chat_start",
"on_chat_end",
Expand Down
84 changes: 82 additions & 2 deletions backend/chainlit/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import uuid
from abc import ABC, abstractmethod
from datetime import datetime, timezone
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Union, cast

from chainlit.action import Action
from chainlit.client.base import MessageDict
Expand All @@ -13,7 +13,14 @@
from chainlit.logger import logger
from chainlit.prompt import Prompt
from chainlit.telemetry import trace_event
from chainlit.types import AskFileResponse, AskFileSpec, AskResponse, AskSpec
from chainlit.types import (
AskActionResponse,
AskActionSpec,
AskFileResponse,
AskFileSpec,
AskResponse,
AskSpec,
)
from syncer import asyncio


Expand Down Expand Up @@ -485,3 +492,76 @@ async def send(self) -> Union[List[AskFileResponse], None]:
return [AskFileResponse(**r) for r in res]
else:
return None


class AskActionMessage(AskMessageBase):
"""
Ask the user to select an action before continuing.
If the user does not answer in time (see timeout), a TimeoutError will be raised or None will be returned depending on raise_on_timeout.
"""

def __init__(
self,
content: str,
actions: List[Action],
author=config.ui.name,
disable_human_feedback=False,
timeout=90,
raise_on_timeout=False,
):
self.content = content
self.actions = actions
self.author = author
self.disable_human_feedback = disable_human_feedback
self.timeout = timeout
self.raise_on_timeout = raise_on_timeout

super().__post_init__()

def to_dict(self):
return {
"id": self.id,
"createdAt": self.created_at,
"content": self.content,
"author": self.author,
"waitForAnswer": True,
"disableHumanFeedback": self.disable_human_feedback,
"timeout": self.timeout,
"raiseOnTimeout": self.raise_on_timeout,
}

async def send(self) -> Union[AskActionResponse, None]:
"""
Sends the question to ask to the UI and waits for the reply
"""
trace_event("send_ask_action")

if self.streaming:
self.streaming = False

if config.code.author_rename:
self.author = await config.code.author_rename(self.author)

msg_dict = await self._create()
action_keys = []

for action in self.actions:
action_keys.append(action.id)
await action.send(for_id=str(msg_dict["id"]))

spec = AskActionSpec(type="action", timeout=self.timeout, keys=action_keys)

res = cast(
Union[AskActionResponse, None],
await context.emitter.send_ask_user(msg_dict, spec, self.raise_on_timeout),
)

for action in self.actions:
await action.remove()
if res is None:
self.content = "Timed out: no action was taken"
else:
self.content = f'**Selected action:** {res["label"]}'
await self.update()

return res
22 changes: 21 additions & 1 deletion backend/chainlit/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,29 @@ class FileSpec(DataClassJsonMixin):
max_size_mb: int


@dataclass
class ActionSpec(DataClassJsonMixin):
keys: List[str]


@dataclass
class AskSpec(DataClassJsonMixin):
"""Specification for asking the user."""

timeout: int
type: Literal["text", "file"]
type: Literal["text", "file", "action"]


@dataclass
class AskFileSpec(FileSpec, AskSpec, DataClassJsonMixin):
"""Specification for asking the user a file."""


@dataclass
class AskActionSpec(ActionSpec, AskSpec, DataClassJsonMixin):
"""Specification for asking the user an action"""


class AskResponse(TypedDict):
content: str
author: str
Expand All @@ -46,6 +56,16 @@ class AskFileResponse:
content: bytes


class AskActionResponse(TypedDict):
name: str
value: str
label: str
description: str
forId: str
id: str
collapsed: bool


class CompletionRequest(BaseModel):
prompt: Prompt
userEnv: Dict[str, str]
Expand Down
21 changes: 21 additions & 0 deletions cypress/e2e/action/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,24 @@ async def main():
message = cl.Message("Hello, this is a test message!", actions=actions)
cl.user_session.set("to_remove", message)
await message.send()

result = await cl.AskActionMessage(
content="Please, pick an action!",
actions=[
cl.Action(
id="first-action",
name="first_action",
value="first-action",
label="First action",
),
cl.Action(
id="second-action",
name="second_action",
value="second-action",
label="Second action",
),
],
).send()

if result != None:
await cl.Message(f"Thanks for pressing: {result['value']}").send()
26 changes: 17 additions & 9 deletions cypress/e2e/action/spec.cy.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,47 +6,55 @@ describe('Action', () => {
});

it('should correctly execute and display actions', () => {
// Click on "first action"
cy.get('#first-action').should('be.visible');
cy.get('#first-action').click();
cy.get('.message').should('have.length', 3);
cy.get('.message')
.eq(2)
.should('contain', 'Thanks for pressing: first-action');

// Click on "test action"
cy.get("[id='test-action']").should('be.visible');
cy.get("[id='test-action']").click();
cy.get('.message').should('have.length', 2);
cy.get('.message').eq(1).should('contain', 'Executed test action!');
cy.get('.message').should('have.length', 4);
cy.get('.message').eq(3).should('contain', 'Executed test action!');
cy.get("[id='test-action']").should('exist');

cy.wait(100);

// Click on "removable action"
cy.get("[id='removable-action']").should('be.visible');
cy.get("[id='removable-action']").click();
cy.get('.message').should('have.length', 3);
cy.get('.message').eq(2).should('contain', 'Executed removable action!');
cy.get('.message').should('have.length', 5);
cy.get('.message').eq(4).should('contain', 'Executed removable action!');
cy.get("[id='removable-action']").should('not.exist');

cy.wait(100);

// Click on "multiple action one" in the action drawer, should remove the correct action button
cy.get("[id='actions-drawer-button']").should('be.visible');
cy.get("[id='actions-drawer-button']").click();
cy.get('.message').should('have.length', 3);
cy.get('.message').should('have.length', 5);

cy.wait(100);

cy.get("[id='multiple-action-one']").should('be.visible');
cy.get("[id='multiple-action-one']").click();
cy.get('.message')
.eq(3)
.eq(5)
.should('contain', 'Action(id=multiple-action-one) has been removed!');
cy.get("[id='multiple-action-one']").should('not.exist');

cy.wait(100);

// Click on "multiple action two", should remove the correct action button
cy.get('.message').should('have.length', 4);
cy.get('.message').should('have.length', 6);
cy.get("[id='actions-drawer-button']").click();
cy.get("[id='multiple-action-two']").should('be.visible');
cy.get("[id='multiple-action-two']").click();
cy.get('.message')
.eq(4)
.eq(6)
.should('contain', 'Action(id=multiple-action-two) has been removed!');
cy.get("[id='multiple-action-two']").should('not.exist');

Expand All @@ -56,7 +64,7 @@ describe('Action', () => {
cy.get("[id='all-actions-removed']").should('be.visible');
cy.get("[id='all-actions-removed']").click();
cy.get('.message')
.eq(5)
.eq(7)
.should('contain', 'All actions have been removed!');
cy.get("[id='all-actions-removed']").should('not.exist');
cy.get("[id='test-action']").should('not.exist');
Expand Down
6 changes: 5 additions & 1 deletion frontend/src/components/organisms/chat/inputBox/input.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,11 @@ const Input = ({ onSubmit, onReply }: Props) => {
const [value, setValue] = useState('');
const [isComposing, setIsComposing] = useState(false);

const disabled = !connected || loading || askUser?.spec.type === 'file';
const disabled =
!connected ||
loading ||
askUser?.spec.type === 'file' ||
askUser?.spec.type === 'action';

useEffect(() => {
if (ref.current && !loading && !disabled) {
Expand Down
19 changes: 13 additions & 6 deletions libs/components/src/messages/components/ActionButton.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,17 @@ interface ActionProps {
}

const ActionButton = ({ action, margin, onClick }: ActionProps) => {
const { loading } = useContext(MessageContext);
const { askUser, loading } = useContext(MessageContext);
const isAskingAction = askUser?.spec.type === 'action';
const isDisabled = isAskingAction && !askUser?.spec.keys?.includes(action.id);
const handleClick = () => {
if (isAskingAction) {
askUser?.callback(action);
} else {
action.onClick();
onClick?.();
}
};

return (
<Tooltip title={action.description} placement="top">
Expand All @@ -25,11 +35,8 @@ const ActionButton = ({ action, margin, onClick }: ActionProps) => {
margin
}}
id={action.id}
onClick={() => {
action.onClick();
onClick?.();
}}
disabled={loading}
onClick={handleClick}
disabled={loading || isDisabled}
>
{action.label || action.name}
</Button>
Expand Down
12 changes: 9 additions & 3 deletions libs/components/src/types/file.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { IAction } from './action';
import { IMessage } from './message';

export interface FileSpec {
Expand All @@ -6,6 +7,10 @@ export interface FileSpec {
max_files?: number;
}

export interface ActionSpec {
keys?: string[];
}

export interface IFileResponse {
name: string;
path?: string;
Expand All @@ -15,9 +20,10 @@ export interface IFileResponse {
}

export interface IAsk {
callback: (payload: IMessage | IFileResponse[]) => void;
callback: (payload: IMessage | IFileResponse[] | IAction) => void;
spec: {
type: 'text' | 'file';
type: 'text' | 'file' | 'action';
timeout: number;
} & FileSpec;
} & FileSpec &
ActionSpec;
}

0 comments on commit 9fd2804

Please sign in to comment.