Skip to content

Commit

Permalink
feat: Add D&D file as attachments (#474)
Browse files Browse the repository at this point in the history
* feat: Add file dropping on chat

---------

Co-authored-by: Willy Douhard <willy.douhard@gmail.com>
  • Loading branch information
alimtunc and willydouhard authored Oct 16, 2023
1 parent 9fd2804 commit fff98d5
Show file tree
Hide file tree
Showing 53 changed files with 803 additions and 506 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ import chainlit as cl


@cl.on_message # this function will be called every time a user inputs a message in the UI
async def main(message: str, message_id: str):
async def main(message: cl.Message):
"""
This function is called every time a user inputs a message in the UI.
It sends back an intermediate response from Tool 1, followed by the final answer.
Expand All @@ -64,7 +64,7 @@ async def main(message: str, message_id: str):
await cl.Message(
author="Tool 1",
content=f"Response from tool1",
parent_id=message_id,
parent_id=message.id,
).send()

# Send the final answer.
Expand Down
33 changes: 1 addition & 32 deletions backend/chainlit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def on_message(func: Callable) -> Callable:
The decorated function is called every time a new message is received.
Args:
func (Callable[[str, str], Any]): The function to be called when a new message is received. Takes the input message and the message id.
func (Callable[[Message], Any]): The function to be called when a new message is received. Takes a cl.Message.
Returns:
Callable[[str], Any]: The decorated on_message function.
Expand Down Expand Up @@ -253,37 +253,6 @@ def on_settings_update(
return func


def on_file_upload(
accept: Union[List[str], Dict[str, List[str]]],
max_size_mb: int = 2,
max_files: int = 1,
) -> Callable:
"""
A decorator designed for handling spontaneously uploaded files.
This decorator is intended to be used with files that are uploaded on-the-fly.
Args:
accept (Union[List[str], Dict[str, List[str]]]): A list of accepted file extensions or a dictionary of extension lists per field.
type (Optional[str]): The type of upload, defaults to "file".
max_size_mb (Optional[int]): The maximum file size in megabytes, defaults to 2.
max_files (Optional[int]): The maximum number of files allowed to be uploaded, defaults to 1.
Returns:
Callable: The decorated function for handling spontaneous file uploads.
"""

def decorator(func: Callable) -> Callable:
config.code.on_file_upload_config = FileSpec(
accept=accept,
max_size_mb=max_size_mb,
max_files=max_files,
)
config.code.on_file_upload = wrap_user_function(func)
return func

return decorator


def sleep(duration: int):
"""
Sleep for a given duration.
Expand Down
1 change: 1 addition & 0 deletions backend/chainlit/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ class ElementDict(TypedDict):
size: Optional[ElementSize]
language: Optional[str]
forIds: Optional[List[str]]
mime: Optional[str]


class ConversationDict(TypedDict):
Expand Down
9 changes: 6 additions & 3 deletions backend/chainlit/client/cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ async def get_conversation(self, conversation_id: str) -> ConversationDict:
conversationId
type
name
mime
url
display
language
Expand Down Expand Up @@ -370,6 +371,7 @@ async def get_element(
conversationId
type
name
mime
url
display
language
Expand All @@ -389,8 +391,8 @@ async def get_element(

async def create_element(self, variables: ElementDict) -> Optional[ElementDict]:
mutation = """
mutation ($conversationId: ID!, $type: String!, $name: String!, $display: String!, $forIds: [String!]!, $url: String, $objectKey: String, $size: String, $language: String) {
createElement(conversationId: $conversationId, type: $type, url: $url, objectKey: $objectKey, name: $name, display: $display, size: $size, language: $language, forIds: $forIds) {
mutation ($conversationId: ID!, $type: String!, $name: String!, $display: String!, $forIds: [String!]!, $url: String, $objectKey: String, $size: String, $language: String, $mime: String) {
createElement(conversationId: $conversationId, type: $type, url: $url, objectKey: $objectKey, name: $name, display: $display, size: $size, language: $language, forIds: $forIds, mime: $mime) {
id,
type,
url,
Expand All @@ -399,7 +401,8 @@ async def create_element(self, variables: ElementDict) -> Optional[ElementDict]:
display,
size,
language,
forIds
forIds,
mime
}
}
"""
Expand Down
5 changes: 4 additions & 1 deletion backend/chainlit/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@
# Show the prompt playground
prompt_playground = true
# Authorize users to upload files with messages
multi_modal = true
[UI]
# Name of the app and chatbot.
name = "Chatbot"
Expand Down Expand Up @@ -138,6 +141,7 @@ class Theme(DataClassJsonMixin):
@dataclass()
class FeaturesSettings(DataClassJsonMixin):
prompt_playground: bool = True
multi_modal: bool = True


@dataclass()
Expand Down Expand Up @@ -170,7 +174,6 @@ class CodeSettings:
on_chat_start: Optional[Callable[[], Any]] = None
on_chat_end: Optional[Callable[[], Any]] = None
on_message: Optional[Callable[[str], Any]] = None
on_file_upload: Optional[Callable[[str], Any]] = None
author_rename: Optional[Callable[[str], str]] = None
on_settings_update: Optional[Callable[[Dict[str, Any]], Any]] = None
set_chat_profiles: Optional[
Expand Down
44 changes: 37 additions & 7 deletions backend/chainlit/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ class Element:
for_ids: List[str] = Field(default_factory=list)
# The language, if relevant
language: Optional[str] = None
# Mime type, infered based on content if not provided
mime: Optional[str] = None

def __post_init__(self) -> None:
trace_event(f"init {self.__class__.__name__}")
Expand All @@ -66,11 +68,35 @@ def to_dict(self) -> ElementDict:
"size": getattr(self, "size", None),
"language": getattr(self, "language", None),
"forIds": getattr(self, "for_ids", None),
"mime": getattr(self, "mime", None),
"conversationId": None,
}
)
return _dict

@classmethod
def from_dict(self, _dict: Dict):
if "image" in _dict.get("mime", ""):
return Image(
id=_dict.get("id", str(uuid.uuid4())),
content=_dict.get("content"),
name=_dict.get("name"),
url=_dict.get("url"),
display=_dict.get("display", "inline"),
mime=_dict.get("mime"),
)
else:
return File(
id=_dict.get("id", str(uuid.uuid4())),
content=_dict.get("content"),
name=_dict.get("name"),
url=_dict.get("url"),
language=_dict.get("language"),
display=_dict.get("display", "inline"),
size=_dict.get("size"),
mime=_dict.get("mime"),
)

async def with_conversation_id(self):
_dict = self.to_dict()
_dict["conversationId"] = await context.session.get_conversation_id()
Expand All @@ -88,15 +114,11 @@ async def load(self):

async def persist(self, client: ChainlitCloudClient) -> Optional[ElementDict]:
if not self.url and self.content and not self.persisted:
# Only guess the mime type when the content is binary
mime = (
mime_types[self.type]
if self.type in mime_types
else filetype.guess_mime(self.content)
)
conversation_id = await context.session.get_conversation_id()
upload_res = await client.upload_element(
content=self.content, mime=mime, conversation_id=conversation_id
content=self.content,
mime=self.mime or "",
conversation_id=conversation_id,
)
self.url = upload_res["url"]
self.object_key = upload_res["object_key"]
Expand Down Expand Up @@ -132,6 +154,14 @@ async def send(self, for_id: Optional[str] = None):

await self.preprocess_content()

if not self.mime:
# Only guess the mime type when the content is binary
self.mime = (
mime_types[self.type]
if self.type in mime_types
else filetype.guess_mime(self.content)
)

if for_id and for_id not in self.for_ids:
self.for_ids.append(for_id)

Expand Down
50 changes: 25 additions & 25 deletions backend/chainlit/emitter.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import asyncio
import uuid
from typing import Any, Dict, cast
from typing import Any, Dict, Optional

from chainlit.client.base import MessageDict
from chainlit.data import chainlit_client
from chainlit.element import Element
from chainlit.message import Message
from chainlit.session import BaseSession, WebsocketSession
from chainlit.types import AskSpec, FileSpec
from chainlit.types import AskSpec, UIMessagePayload
from socketio.exceptions import TimeoutError


Expand Down Expand Up @@ -49,17 +51,13 @@ async def clear_ask(self):
"""Stub method to clear the prompt from the UI."""
pass

async def enable_file_upload(self, spec):
"""Stub method to enable uploading file in the UI."""
pass

async def init_conversation(self, msg_dict: MessageDict):
"""Signal the UI that a new conversation (with a user message) exists"""
pass

async def process_user_message(self, message_dict) -> None:
async def process_user_message(self, payload: UIMessagePayload) -> Message:
"""Stub method to process user message."""
pass
return Message(content="")

async def send_ask_user(self, msg_dict: dict, spec, raise_on_timeout=False):
"""Stub method to send a prompt to the UI and wait for a response."""
Expand Down Expand Up @@ -146,37 +144,39 @@ def clear_ask(self):

return self.emit("clear_ask", {})

def enable_file_upload(self, spec: FileSpec):
"""Enable uploading file in the UI."""

return self.emit("enable_file_upload", spec.to_dict())

def init_conversation(self, message: MessageDict):
"""Signal the UI that a new conversation (with a user message) exists"""

return self.emit("init_conversation", message)

async def process_user_message(self, message_dict: MessageDict):
async def process_user_message(self, payload: UIMessagePayload):
message_dict = payload["message"]
files = payload["files"]
# Temporary UUID generated by the frontend should use v4
assert uuid.UUID(message_dict["id"]).version == 4

if chainlit_client:
message_dict["conversationId"] = await self.session.get_conversation_id()
# We have to update the UI with the actual DB ID
ui_message_update = cast(Dict, message_dict.copy())
persisted_id = await chainlit_client.create_message(message_dict)
if persisted_id:
message_dict["id"] = persisted_id
ui_message_update["newId"] = message_dict["id"]
await self.update_message(ui_message_update)
message = Message.from_dict(message_dict)

asyncio.create_task(message._create())

if files:
file_elements = [Element.from_dict(file) for file in files]
message.elements = file_elements

async def send_elements():
for element in message.elements:
await element.send(for_id=message.id)

asyncio.create_task(send_elements())

if not self.session.has_user_message:
self.session.has_user_message = True
await self.init_conversation(message_dict)
await self.init_conversation(await message.with_conversation_id())

message = Message.from_dict(message_dict)
self.session.root_message = message

return message

async def send_ask_user(
self, msg_dict: Dict, spec: AskSpec, raise_on_timeout=False
):
Expand Down
4 changes: 2 additions & 2 deletions backend/chainlit/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,10 +267,10 @@ async def send(self) -> str:
context.session.root_message = self

for action in self.actions:
await action.send(for_id=str(id))
await action.send(for_id=id)

for element in self.elements:
await element.send(for_id=str(id))
await element.send(for_id=id)

return id

Expand Down
1 change: 1 addition & 0 deletions backend/chainlit/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,7 @@ async def project_settings(
return JSONResponse(
content={
"ui": config.ui.to_dict(),
"features": config.features.to_dict(),
"userEnv": config.project.user_env,
"dataPersistence": config.data_persistence,
"markdown": get_markdown_str(config.root),
Expand Down
5 changes: 2 additions & 3 deletions backend/chainlit/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,14 @@ async def get_conversation_id(self) -> Optional[str]:
else:
tags = ["chat"]

if not self.conversation_id:
async with self.lock:
async with self.lock:
if not self.conversation_id:
app_user_id = (
self.user.id if isinstance(self.user, PersistedAppUser) else None
)
self.conversation_id = await chainlit_client.create_conversation(
app_user_id=app_user_id, tags=tags
)

return self.conversation_id


Expand Down
Loading

0 comments on commit fff98d5

Please sign in to comment.