From fff98d5f2fff7344e31f2f3259fdb47e56884eb3 Mon Sep 17 00:00:00 2001 From: SuperTurk Date: Mon, 16 Oct 2023 09:46:38 +0200 Subject: [PATCH] feat: Add D&D file as attachments (#474) * feat: Add file dropping on chat --------- Co-authored-by: Willy Douhard --- README.md | 4 +- backend/chainlit/__init__.py | 33 +-- backend/chainlit/client/base.py | 1 + backend/chainlit/client/cloud.py | 9 +- backend/chainlit/config.py | 5 +- backend/chainlit/element.py | 44 +++- backend/chainlit/emitter.py | 50 ++--- backend/chainlit/message.py | 4 +- backend/chainlit/server.py | 1 + backend/chainlit/session.py | 5 +- backend/chainlit/socket.py | 24 +-- backend/chainlit/types.py | 8 +- cypress/e2e/ask_file/spec.cy.ts | 42 ++-- cypress/e2e/ask_multiple_files/spec.cy.ts | 26 +-- cypress/e2e/avatar/spec.cy.ts | 2 +- cypress/e2e/cot/main.py | 6 +- cypress/e2e/file_element/main.py | 4 + cypress/e2e/file_element/spec.cy.ts | 12 +- cypress/e2e/file_upload/.chainlit/config.toml | 62 ------ cypress/e2e/file_upload/main.py | 14 -- cypress/e2e/file_upload/spec.cy.ts | 31 --- cypress/e2e/stop_task/main_async.py | 2 +- cypress/e2e/stop_task/main_sync.py | 2 +- .../.chainlit/config.toml | 6 +- cypress/e2e/upload_attachments/main.py | 10 + cypress/e2e/upload_attachments/spec.cy.ts | 72 +++++++ cypress/e2e/user_session/main.py | 4 +- .../organisms/chat/history/index.tsx | 4 +- .../src/components/organisms/chat/index.tsx | 99 ++++++++- .../organisms/chat/inputBox/UploadButton.tsx | 51 +++++ .../organisms/chat/inputBox/index.tsx | 23 ++- .../organisms/chat/inputBox/input.tsx | 194 +++++++++++++----- .../organisms/chat/message/UploadButton.tsx | 47 ----- .../organisms/chat/message/container.tsx | 3 +- frontend/src/state/chat.ts | 8 + frontend/src/state/project.ts | 3 + frontend/src/types/chat.ts | 3 +- frontend/src/types/element.ts | 70 ------- libs/components/contexts/MessageContext.tsx | 1 + libs/components/hooks/useChat/index.ts | 28 +-- libs/components/hooks/useChat/state.ts | 6 - libs/components/hooks/useUpload.tsx | 16 +- libs/components/package.json | 4 +- libs/components/pnpm-lock.yaml | 33 +++ libs/components/src/Attachments.tsx | 48 +++++ libs/components/src/elements/File.tsx | 139 +++++++++++-- libs/components/src/index.ts | 7 +- libs/components/src/messages/Message.tsx | 2 +- .../messages/components/AskUploadButton.tsx | 21 +- .../src/messages/components/ElementRef.tsx | 1 + libs/components/src/types/element.ts | 1 + libs/components/src/types/message.ts | 13 +- libs/components/src/types/messageContext.ts | 1 + 53 files changed, 803 insertions(+), 506 deletions(-) delete mode 100644 cypress/e2e/file_upload/.chainlit/config.toml delete mode 100644 cypress/e2e/file_upload/main.py delete mode 100644 cypress/e2e/file_upload/spec.cy.ts rename cypress/e2e/{ask_multiple_files => upload_attachments}/.chainlit/config.toml (87%) create mode 100644 cypress/e2e/upload_attachments/main.py create mode 100644 cypress/e2e/upload_attachments/spec.cy.ts create mode 100644 frontend/src/components/organisms/chat/inputBox/UploadButton.tsx delete mode 100644 frontend/src/components/organisms/chat/message/UploadButton.tsx create mode 100644 frontend/src/state/chat.ts delete mode 100644 frontend/src/types/element.ts create mode 100644 libs/components/src/Attachments.tsx diff --git a/README.md b/README.md index fea5bb62fd..325d1996b8 100644 --- a/README.md +++ b/README.md @@ -48,7 +48,7 @@ import chainlit as cl @cl.on_message # this function will be called every time a user inputs a message in the UI -async def main(message: str, message_id: str): +async def main(message: cl.Message): """ This function is called every time a user inputs a message in the UI. It sends back an intermediate response from Tool 1, followed by the final answer. @@ -64,7 +64,7 @@ async def main(message: str, message_id: str): await cl.Message( author="Tool 1", content=f"Response from tool1", - parent_id=message_id, + parent_id=message.id, ).send() # Send the final answer. diff --git a/backend/chainlit/__init__.py b/backend/chainlit/__init__.py index ab809fe775..d46599e0d4 100644 --- a/backend/chainlit/__init__.py +++ b/backend/chainlit/__init__.py @@ -130,7 +130,7 @@ def on_message(func: Callable) -> Callable: The decorated function is called every time a new message is received. Args: - func (Callable[[str, str], Any]): The function to be called when a new message is received. Takes the input message and the message id. + func (Callable[[Message], Any]): The function to be called when a new message is received. Takes a cl.Message. Returns: Callable[[str], Any]: The decorated on_message function. @@ -253,37 +253,6 @@ def on_settings_update( return func -def on_file_upload( - accept: Union[List[str], Dict[str, List[str]]], - max_size_mb: int = 2, - max_files: int = 1, -) -> Callable: - """ - A decorator designed for handling spontaneously uploaded files. - This decorator is intended to be used with files that are uploaded on-the-fly. - - Args: - accept (Union[List[str], Dict[str, List[str]]]): A list of accepted file extensions or a dictionary of extension lists per field. - type (Optional[str]): The type of upload, defaults to "file". - max_size_mb (Optional[int]): The maximum file size in megabytes, defaults to 2. - max_files (Optional[int]): The maximum number of files allowed to be uploaded, defaults to 1. - - Returns: - Callable: The decorated function for handling spontaneous file uploads. - """ - - def decorator(func: Callable) -> Callable: - config.code.on_file_upload_config = FileSpec( - accept=accept, - max_size_mb=max_size_mb, - max_files=max_files, - ) - config.code.on_file_upload = wrap_user_function(func) - return func - - return decorator - - def sleep(duration: int): """ Sleep for a given duration. diff --git a/backend/chainlit/client/base.py b/backend/chainlit/client/base.py index 83dfac419f..338231bb0e 100644 --- a/backend/chainlit/client/base.py +++ b/backend/chainlit/client/base.py @@ -78,6 +78,7 @@ class ElementDict(TypedDict): size: Optional[ElementSize] language: Optional[str] forIds: Optional[List[str]] + mime: Optional[str] class ConversationDict(TypedDict): diff --git a/backend/chainlit/client/cloud.py b/backend/chainlit/client/cloud.py index f04d102452..28c875030f 100644 --- a/backend/chainlit/client/cloud.py +++ b/backend/chainlit/client/cloud.py @@ -201,6 +201,7 @@ async def get_conversation(self, conversation_id: str) -> ConversationDict: conversationId type name + mime url display language @@ -370,6 +371,7 @@ async def get_element( conversationId type name + mime url display language @@ -389,8 +391,8 @@ async def get_element( async def create_element(self, variables: ElementDict) -> Optional[ElementDict]: mutation = """ - mutation ($conversationId: ID!, $type: String!, $name: String!, $display: String!, $forIds: [String!]!, $url: String, $objectKey: String, $size: String, $language: String) { - createElement(conversationId: $conversationId, type: $type, url: $url, objectKey: $objectKey, name: $name, display: $display, size: $size, language: $language, forIds: $forIds) { + mutation ($conversationId: ID!, $type: String!, $name: String!, $display: String!, $forIds: [String!]!, $url: String, $objectKey: String, $size: String, $language: String, $mime: String) { + createElement(conversationId: $conversationId, type: $type, url: $url, objectKey: $objectKey, name: $name, display: $display, size: $size, language: $language, forIds: $forIds, mime: $mime) { id, type, url, @@ -399,7 +401,8 @@ async def create_element(self, variables: ElementDict) -> Optional[ElementDict]: display, size, language, - forIds + forIds, + mime } } """ diff --git a/backend/chainlit/config.py b/backend/chainlit/config.py index 503c250762..846b343b04 100644 --- a/backend/chainlit/config.py +++ b/backend/chainlit/config.py @@ -47,6 +47,9 @@ # Show the prompt playground prompt_playground = true +# Authorize users to upload files with messages +multi_modal = true + [UI] # Name of the app and chatbot. name = "Chatbot" @@ -138,6 +141,7 @@ class Theme(DataClassJsonMixin): @dataclass() class FeaturesSettings(DataClassJsonMixin): prompt_playground: bool = True + multi_modal: bool = True @dataclass() @@ -170,7 +174,6 @@ class CodeSettings: on_chat_start: Optional[Callable[[], Any]] = None on_chat_end: Optional[Callable[[], Any]] = None on_message: Optional[Callable[[str], Any]] = None - on_file_upload: Optional[Callable[[str], Any]] = None author_rename: Optional[Callable[[str], str]] = None on_settings_update: Optional[Callable[[Dict[str, Any]], Any]] = None set_chat_profiles: Optional[ diff --git a/backend/chainlit/element.py b/backend/chainlit/element.py index 12bb89901b..50516b8b93 100644 --- a/backend/chainlit/element.py +++ b/backend/chainlit/element.py @@ -46,6 +46,8 @@ class Element: for_ids: List[str] = Field(default_factory=list) # The language, if relevant language: Optional[str] = None + # Mime type, infered based on content if not provided + mime: Optional[str] = None def __post_init__(self) -> None: trace_event(f"init {self.__class__.__name__}") @@ -66,11 +68,35 @@ def to_dict(self) -> ElementDict: "size": getattr(self, "size", None), "language": getattr(self, "language", None), "forIds": getattr(self, "for_ids", None), + "mime": getattr(self, "mime", None), "conversationId": None, } ) return _dict + @classmethod + def from_dict(self, _dict: Dict): + if "image" in _dict.get("mime", ""): + return Image( + id=_dict.get("id", str(uuid.uuid4())), + content=_dict.get("content"), + name=_dict.get("name"), + url=_dict.get("url"), + display=_dict.get("display", "inline"), + mime=_dict.get("mime"), + ) + else: + return File( + id=_dict.get("id", str(uuid.uuid4())), + content=_dict.get("content"), + name=_dict.get("name"), + url=_dict.get("url"), + language=_dict.get("language"), + display=_dict.get("display", "inline"), + size=_dict.get("size"), + mime=_dict.get("mime"), + ) + async def with_conversation_id(self): _dict = self.to_dict() _dict["conversationId"] = await context.session.get_conversation_id() @@ -88,15 +114,11 @@ async def load(self): async def persist(self, client: ChainlitCloudClient) -> Optional[ElementDict]: if not self.url and self.content and not self.persisted: - # Only guess the mime type when the content is binary - mime = ( - mime_types[self.type] - if self.type in mime_types - else filetype.guess_mime(self.content) - ) conversation_id = await context.session.get_conversation_id() upload_res = await client.upload_element( - content=self.content, mime=mime, conversation_id=conversation_id + content=self.content, + mime=self.mime or "", + conversation_id=conversation_id, ) self.url = upload_res["url"] self.object_key = upload_res["object_key"] @@ -132,6 +154,14 @@ async def send(self, for_id: Optional[str] = None): await self.preprocess_content() + if not self.mime: + # Only guess the mime type when the content is binary + self.mime = ( + mime_types[self.type] + if self.type in mime_types + else filetype.guess_mime(self.content) + ) + if for_id and for_id not in self.for_ids: self.for_ids.append(for_id) diff --git a/backend/chainlit/emitter.py b/backend/chainlit/emitter.py index 00b1723965..45189491b6 100644 --- a/backend/chainlit/emitter.py +++ b/backend/chainlit/emitter.py @@ -1,11 +1,13 @@ +import asyncio import uuid -from typing import Any, Dict, cast +from typing import Any, Dict, Optional from chainlit.client.base import MessageDict from chainlit.data import chainlit_client +from chainlit.element import Element from chainlit.message import Message from chainlit.session import BaseSession, WebsocketSession -from chainlit.types import AskSpec, FileSpec +from chainlit.types import AskSpec, UIMessagePayload from socketio.exceptions import TimeoutError @@ -49,17 +51,13 @@ async def clear_ask(self): """Stub method to clear the prompt from the UI.""" pass - async def enable_file_upload(self, spec): - """Stub method to enable uploading file in the UI.""" - pass - async def init_conversation(self, msg_dict: MessageDict): """Signal the UI that a new conversation (with a user message) exists""" pass - async def process_user_message(self, message_dict) -> None: + async def process_user_message(self, payload: UIMessagePayload) -> Message: """Stub method to process user message.""" - pass + return Message(content="") async def send_ask_user(self, msg_dict: dict, spec, raise_on_timeout=False): """Stub method to send a prompt to the UI and wait for a response.""" @@ -146,37 +144,39 @@ def clear_ask(self): return self.emit("clear_ask", {}) - def enable_file_upload(self, spec: FileSpec): - """Enable uploading file in the UI.""" - - return self.emit("enable_file_upload", spec.to_dict()) - def init_conversation(self, message: MessageDict): """Signal the UI that a new conversation (with a user message) exists""" return self.emit("init_conversation", message) - async def process_user_message(self, message_dict: MessageDict): + async def process_user_message(self, payload: UIMessagePayload): + message_dict = payload["message"] + files = payload["files"] # Temporary UUID generated by the frontend should use v4 assert uuid.UUID(message_dict["id"]).version == 4 - if chainlit_client: - message_dict["conversationId"] = await self.session.get_conversation_id() - # We have to update the UI with the actual DB ID - ui_message_update = cast(Dict, message_dict.copy()) - persisted_id = await chainlit_client.create_message(message_dict) - if persisted_id: - message_dict["id"] = persisted_id - ui_message_update["newId"] = message_dict["id"] - await self.update_message(ui_message_update) + message = Message.from_dict(message_dict) + + asyncio.create_task(message._create()) + + if files: + file_elements = [Element.from_dict(file) for file in files] + message.elements = file_elements + + async def send_elements(): + for element in message.elements: + await element.send(for_id=message.id) + + asyncio.create_task(send_elements()) if not self.session.has_user_message: self.session.has_user_message = True - await self.init_conversation(message_dict) + await self.init_conversation(await message.with_conversation_id()) - message = Message.from_dict(message_dict) self.session.root_message = message + return message + async def send_ask_user( self, msg_dict: Dict, spec: AskSpec, raise_on_timeout=False ): diff --git a/backend/chainlit/message.py b/backend/chainlit/message.py index a678575196..893efdbdb5 100644 --- a/backend/chainlit/message.py +++ b/backend/chainlit/message.py @@ -267,10 +267,10 @@ async def send(self) -> str: context.session.root_message = self for action in self.actions: - await action.send(for_id=str(id)) + await action.send(for_id=id) for element in self.elements: - await element.send(for_id=str(id)) + await element.send(for_id=id) return id diff --git a/backend/chainlit/server.py b/backend/chainlit/server.py index 48e5c16c09..28c80aea35 100644 --- a/backend/chainlit/server.py +++ b/backend/chainlit/server.py @@ -448,6 +448,7 @@ async def project_settings( return JSONResponse( content={ "ui": config.ui.to_dict(), + "features": config.features.to_dict(), "userEnv": config.project.user_env, "dataPersistence": config.data_persistence, "markdown": get_markdown_str(config.root), diff --git a/backend/chainlit/session.py b/backend/chainlit/session.py index 02bb806ab8..e0d5aba160 100644 --- a/backend/chainlit/session.py +++ b/backend/chainlit/session.py @@ -50,15 +50,14 @@ async def get_conversation_id(self) -> Optional[str]: else: tags = ["chat"] - if not self.conversation_id: - async with self.lock: + async with self.lock: + if not self.conversation_id: app_user_id = ( self.user.id if isinstance(self.user, PersistedAppUser) else None ) self.conversation_id = await chainlit_client.create_conversation( app_user_id=app_user_id, tags=tags ) - return self.conversation_id diff --git a/backend/chainlit/socket.py b/backend/chainlit/socket.py index 718e3863c4..8cf9a0685a 100644 --- a/backend/chainlit/socket.py +++ b/backend/chainlit/socket.py @@ -13,6 +13,7 @@ from chainlit.server import socket from chainlit.session import WebsocketSession from chainlit.telemetry import trace_event +from chainlit.types import UIMessagePayload from chainlit.user_session import user_sessions @@ -106,9 +107,6 @@ async def connection_successful(sid): if context.session.restored: return - if config.code.on_file_upload: - await context.emitter.enable_file_upload(config.code.on_file_upload_config) - if config.code.on_chat_start: """Call the on_chat_start function provided by the developer.""" await config.code.on_chat_start() @@ -166,16 +164,15 @@ async def stop(sid): await config.code.on_stop() -async def process_message(session: WebsocketSession, message_dict: MessageDict): +async def process_message(session: WebsocketSession, payload: UIMessagePayload): """Process a message from the user.""" try: context = init_ws_context(session) await context.emitter.task_start() if config.code.on_message: - await context.emitter.process_user_message(message_dict) - message = Message.from_dict(message_dict) - await config.code.on_message(message.content.strip(), message.id) + message = await context.emitter.process_user_message(payload) + await config.code.on_message(message) except InterruptedError: pass except Exception as e: @@ -188,12 +185,12 @@ async def process_message(session: WebsocketSession, message_dict: MessageDict): @socket.on("ui_message") -async def message(sid, message): +async def message(sid, payload: UIMessagePayload): """Handle a message sent by the User.""" session = WebsocketSession.require(sid) session.should_stop = False - await process_message(session, message) + await process_message(session, payload) async def process_action(action: Action): @@ -224,12 +221,3 @@ async def change_settings(sid, settings: Dict[str, Any]): if config.code.on_settings_update: await config.code.on_settings_update(settings) - - -@socket.on("file_upload") -async def file_upload(sid, files: Any): - """Handle file upload from the UI.""" - init_ws_context(sid) - - if config.code.on_file_upload: - await config.code.on_file_upload(files) diff --git a/backend/chainlit/types.py b/backend/chainlit/types.py index fe9e704f02..da394d0715 100644 --- a/backend/chainlit/types.py +++ b/backend/chainlit/types.py @@ -1,7 +1,8 @@ from enum import Enum from typing import Dict, List, Literal, Optional, TypedDict, Union -from chainlit.client.base import ConversationFilter, Pagination +from chainlit.client.base import ConversationFilter, MessageDict, Pagination +from chainlit.element import File from chainlit.prompt import Prompt from dataclasses_json import DataClassJsonMixin from pydantic import BaseModel @@ -47,6 +48,11 @@ class AskResponse(TypedDict): author: str +class UIMessagePayload(TypedDict): + message: MessageDict + files: Optional[List[Dict]] + + @dataclass class AskFileResponse: name: str diff --git a/cypress/e2e/ask_file/spec.cy.ts b/cypress/e2e/ask_file/spec.cy.ts index 51556b981d..62a041b3d8 100644 --- a/cypress/e2e/ask_file/spec.cy.ts +++ b/cypress/e2e/ask_file/spec.cy.ts @@ -1,45 +1,45 @@ -import { runTestServer } from "../../support/testUtils"; +import { runTestServer } from '../../support/testUtils'; -describe("Upload file", () => { +describe('Upload file', () => { before(() => { runTestServer(); }); - it("should be able to receive and decode files", () => { - cy.get("#upload-button").should("exist"); + it('should be able to receive and decode files', () => { + cy.get('#ask-upload-button').should('exist'); // Upload a text file - cy.fixture("state_of_the_union.txt", "utf-8").as("txtFile"); - cy.get("input[type=file]").selectFile("@txtFile", { force: true }); + cy.fixture('state_of_the_union.txt', 'utf-8').as('txtFile'); + cy.get('#ask-button-input').selectFile('@txtFile', { force: true }); // Sometimes the loading indicator is not shown because the file upload is too fast - // cy.get("#upload-button-loading").should("exist"); - // cy.get("#upload-button-loading").should("not.exist"); + // cy.get("#ask-upload-button-loading").should("exist"); + // cy.get("#ask-upload-button-loading").should("not.exist"); - cy.get(".message") + cy.get('.message') .eq(1) .should( - "contain", - "Text file state_of_the_union.txt uploaded, it contains" + 'contain', + 'Text file state_of_the_union.txt uploaded, it contains' ); - cy.get("#upload-button").should("exist"); + cy.get('#ask-upload-button').should('exist'); // Expecting a python file, cpp file upload should be rejected - cy.fixture("hello.cpp", "utf-8").as("cppFile"); - cy.get("input[type=file]").selectFile("@cppFile", { force: true }); + cy.fixture('hello.cpp', 'utf-8').as('cppFile'); + cy.get('#ask-button-input').selectFile('@cppFile', { force: true }); - cy.get(".message").should("have.length", 3); + cy.get('.message').should('have.length', 3); // Upload a python file - cy.fixture("hello.py", "utf-8").as("pyFile"); - cy.get("input[type=file]").selectFile("@pyFile", { force: true }); + cy.fixture('hello.py', 'utf-8').as('pyFile'); + cy.get('#ask-button-input').selectFile('@pyFile', { force: true }); - cy.get(".message") - .should("have.length", 4) + cy.get('.message') + .should('have.length', 4) .eq(3) - .should("contain", "Python file hello.py uploaded, it contains"); + .should('contain', 'Python file hello.py uploaded, it contains'); - cy.get("#upload-button").should("not.exist"); + cy.get('#ask-upload-button').should('not.exist'); }); }); diff --git a/cypress/e2e/ask_multiple_files/spec.cy.ts b/cypress/e2e/ask_multiple_files/spec.cy.ts index aa926b456b..bb40cc5e19 100644 --- a/cypress/e2e/ask_multiple_files/spec.cy.ts +++ b/cypress/e2e/ask_multiple_files/spec.cy.ts @@ -1,28 +1,28 @@ -import { runTestServer } from "../../support/testUtils"; +import { runTestServer } from '../../support/testUtils'; -describe("Upload multiple files", () => { +describe('Upload multiple files', () => { before(() => { runTestServer(); }); - it("should be able to receive two files", () => { - cy.get("#upload-button").should("exist"); + it('should be able to receive two files', () => { + cy.get('#ask-upload-button').should('exist'); - cy.fixture("state_of_the_union.txt", "utf-8").as("txtFile"); - cy.fixture("hello.py", "utf-8").as("pyFile"); + cy.fixture('state_of_the_union.txt', 'utf-8').as('txtFile'); + cy.fixture('hello.py', 'utf-8').as('pyFile'); - cy.get("input[type=file]").selectFile(["@txtFile", "@pyFile"], { - force: true, + cy.get('#ask-button-input').selectFile(['@txtFile', '@pyFile'], { + force: true }); // Sometimes the loading indicator is not shown because the file upload is too fast - // cy.get("#upload-button-loading").should("exist"); - // cy.get("#upload-button-loading").should("not.exist"); + // cy.get("#ask-upload-button-loading").should("exist"); + // cy.get("#ask-upload-button-loading").should("not.exist"); - cy.get(".message") + cy.get('.message') .eq(1) - .should("contain", "2 files uploaded: state_of_the_union.txt,hello.py"); + .should('contain', '2 files uploaded: state_of_the_union.txt,hello.py'); - cy.get("#upload-button").should("not.exist"); + cy.get('#ask-upload-button').should('not.exist'); }); }); diff --git a/cypress/e2e/avatar/spec.cy.ts b/cypress/e2e/avatar/spec.cy.ts index 40b509c141..1a43e55c6a 100644 --- a/cypress/e2e/avatar/spec.cy.ts +++ b/cypress/e2e/avatar/spec.cy.ts @@ -5,7 +5,7 @@ describe('Avatar', () => { runTestServer(); }); - it('should be able to display a nested CoT', () => { + it('should be able to display avatars', () => { cy.get('.message').should('have.length', 3); cy.get('.message').eq(0).find('.message-avatar').should('have.length', 0); diff --git a/cypress/e2e/cot/main.py b/cypress/e2e/cot/main.py index 279e69eae2..b8665e5c12 100644 --- a/cypress/e2e/cot/main.py +++ b/cypress/e2e/cot/main.py @@ -2,8 +2,8 @@ @cl.on_message -async def main(message: str, message_id: str): - tool1_msg = cl.Message(content="", author="Tool 1", parent_id=message_id) +async def main(message: cl.Message): + tool1_msg = cl.Message(content="", author="Tool 1", parent_id=message.id) await tool1_msg.send() await cl.sleep(1) @@ -23,7 +23,7 @@ async def main(message: str, message_id: str): await tool2_msg.update() await cl.Message( - content="Response from tool 2", author="Tool 1", parent_id=message_id + content="Response from tool 2", author="Tool 1", parent_id=message.id ).send() await cl.Message( diff --git a/cypress/e2e/file_element/main.py b/cypress/e2e/file_element/main.py index b3b01860b4..b1254e36d2 100644 --- a/cypress/e2e/file_element/main.py +++ b/cypress/e2e/file_element/main.py @@ -8,21 +8,25 @@ async def start(): name="example.mp4", path="../../fixtures/example.mp4", display="inline", + mime="video/mp4", ), cl.File( name="cat.jpeg", path="../../fixtures/cat.jpeg", display="inline", + mime="image/jpg", ), cl.File( name="hello.py", path="../../fixtures/hello.py", display="inline", + mime="plain/py", ), cl.File( name="example.mp3", path="../../fixtures/example.mp3", display="inline", + mime="audio/mp3", ), ] diff --git a/cypress/e2e/file_element/spec.cy.ts b/cypress/e2e/file_element/spec.cy.ts index a85f85b5f2..c661e5fd94 100644 --- a/cypress/e2e/file_element/spec.cy.ts +++ b/cypress/e2e/file_element/spec.cy.ts @@ -9,13 +9,9 @@ describe('file', () => { cy.get('.message').should('have.length', 1); cy.get('.message').eq(0).find('.inline-file').should('have.length', 4); - cy.get('a.inline-file') - .eq(0) - .should('have.attr', 'download', 'example.mp4'); - cy.get('a.inline-file').eq(1).should('have.attr', 'download', 'cat.jpeg'); - cy.get('a.inline-file').eq(2).should('have.attr', 'download', 'hello.py'); - cy.get('a.inline-file') - .eq(3) - .should('have.attr', 'download', 'example.mp3'); + cy.get('.inline-file').eq(0).should('have.attr', 'download', 'example.mp4'); + cy.get('.inline-file').eq(1).should('have.attr', 'download', 'cat.jpeg'); + cy.get('.inline-file').eq(2).should('have.attr', 'download', 'hello.py'); + cy.get('.inline-file').eq(3).should('have.attr', 'download', 'example.mp3'); }); }); diff --git a/cypress/e2e/file_upload/.chainlit/config.toml b/cypress/e2e/file_upload/.chainlit/config.toml deleted file mode 100644 index 0c509af72c..0000000000 --- a/cypress/e2e/file_upload/.chainlit/config.toml +++ /dev/null @@ -1,62 +0,0 @@ -[project] -# Whether to enable telemetry (default: true). No personal data is collected. -enable_telemetry = true - -# List of environment variables to be provided by each user to use the app. -user_env = [] - -# Duration (in seconds) during which the session is saved when the connection is lost -session_timeout = 3600 - -# Enable third parties caching (e.g LangChain cache) -cache = false - -# Follow symlink for asset mount (see https://github.com/Chainlit/chainlit/issues/317) -# follow_symlink = false - -[features] -# Show the prompt playground -prompt_playground = true - -[UI] -# Name of the app and chatbot. -name = "Chatbot" - -# Description of the app and chatbot. This is used for HTML tags. -# description = "" - -# Large size content are by default collapsed for a cleaner ui -default_collapse_content = true - -# The default value for the expand messages settings. -default_expand_messages = false - -# Hide the chain of thought details from the user in the UI. -hide_cot = false - -# Link to your github repo. This will add a github button in the UI's header. -# github = "" - -# Override default MUI light theme. (Check theme.ts) -[UI.theme.light] - #background = "#FAFAFA" - #paper = "#FFFFFF" - - [UI.theme.light.primary] - #main = "#F80061" - #dark = "#980039" - #light = "#FFE7EB" - -# Override default MUI dark theme. (Check theme.ts) -[UI.theme.dark] - #background = "#FAFAFA" - #paper = "#FFFFFF" - - [UI.theme.dark.primary] - #main = "#F80061" - #dark = "#980039" - #light = "#FFE7EB" - - -[meta] -generated_by = "0.6.402" diff --git a/cypress/e2e/file_upload/main.py b/cypress/e2e/file_upload/main.py deleted file mode 100644 index 17e1d0166e..0000000000 --- a/cypress/e2e/file_upload/main.py +++ /dev/null @@ -1,14 +0,0 @@ -import chainlit as cl - - -@cl.on_file_upload(accept={"text/plain": [".py"]}) -async def upload_file(files: any): - for file_data in files: - await cl.Message( - content=f"`{file_data['name']}` uploaded, it contains {len(file_data['content'])} characters!" - ).send() - - -@cl.on_chat_start -async def start(): - await cl.Message(content=f"Try to upload a file").send() diff --git a/cypress/e2e/file_upload/spec.cy.ts b/cypress/e2e/file_upload/spec.cy.ts deleted file mode 100644 index 7e6e824520..0000000000 --- a/cypress/e2e/file_upload/spec.cy.ts +++ /dev/null @@ -1,31 +0,0 @@ -import { runTestServer } from "../../support/testUtils"; - -describe("Upload file", () => { - before(() => { - runTestServer(); - }); - - it("should be able to upload and decode files", () => { - cy.get("#upload-button").should("exist"); - - // Upload a text file - cy.fixture("state_of_the_union.txt", "utf-8").as("txtFile"); - cy.get("input[type=file]").selectFile("@txtFile", { force: true }); - - cy.get(".message") - .eq(1) - .should("contain", "state_of_the_union.txt uploaded, it contains"); - - // Expecting a python file, cpp file upload should be rejected - cy.fixture("hello.cpp", "utf-8").as("cppFile"); - cy.get("input[type=file]").selectFile("@cppFile", { force: true }); - - // Upload a python file - cy.fixture("hello.py", "utf-8").as("pyFile"); - cy.get("input[type=file]").selectFile("@pyFile", { force: true }); - - cy.get(".message") - .eq(2) - .should("contain", "hello.py uploaded, it contains"); - }); -}); diff --git a/cypress/e2e/stop_task/main_async.py b/cypress/e2e/stop_task/main_async.py index fdfabf1dc6..93fba67f43 100644 --- a/cypress/e2e/stop_task/main_async.py +++ b/cypress/e2e/stop_task/main_async.py @@ -9,5 +9,5 @@ async def start(): @cl.on_message -async def message(message: str): +async def message(message: cl.Message): await cl.Message(content="World").send() diff --git a/cypress/e2e/stop_task/main_sync.py b/cypress/e2e/stop_task/main_sync.py index 80cb877ab0..e5cd0fa82d 100644 --- a/cypress/e2e/stop_task/main_sync.py +++ b/cypress/e2e/stop_task/main_sync.py @@ -15,5 +15,5 @@ async def start(): @cl.on_message -async def message(message: str): +async def message(message: cl.Message): await cl.Message(content="World").send() diff --git a/cypress/e2e/ask_multiple_files/.chainlit/config.toml b/cypress/e2e/upload_attachments/.chainlit/config.toml similarity index 87% rename from cypress/e2e/ask_multiple_files/.chainlit/config.toml rename to cypress/e2e/upload_attachments/.chainlit/config.toml index 0c509af72c..dd4a16cec1 100644 --- a/cypress/e2e/ask_multiple_files/.chainlit/config.toml +++ b/cypress/e2e/upload_attachments/.chainlit/config.toml @@ -37,6 +37,10 @@ hide_cot = false # Link to your github repo. This will add a github button in the UI's header. # github = "" +# Specify a CSS file that can be used to customize the user interface. +# The CSS file can be served from the public directory or via an external link. +# custom_css = "/public/test.css" + # Override default MUI light theme. (Check theme.ts) [UI.theme.light] #background = "#FAFAFA" @@ -59,4 +63,4 @@ hide_cot = false [meta] -generated_by = "0.6.402" +generated_by = "0.7.2" diff --git a/cypress/e2e/upload_attachments/main.py b/cypress/e2e/upload_attachments/main.py new file mode 100644 index 0000000000..b065cef9f3 --- /dev/null +++ b/cypress/e2e/upload_attachments/main.py @@ -0,0 +1,10 @@ +import chainlit as cl + + +@cl.on_message +async def main(message: cl.Message): + await cl.Message(content=f"Content: {message.content}").send() + # Check if message.elements is not empty and is a list + for index, item in enumerate(message.elements): + # Send a response for each element + await cl.Message(content=f"Received element {index}: {item.name}").send() diff --git a/cypress/e2e/upload_attachments/spec.cy.ts b/cypress/e2e/upload_attachments/spec.cy.ts new file mode 100644 index 0000000000..7b5f7b3461 --- /dev/null +++ b/cypress/e2e/upload_attachments/spec.cy.ts @@ -0,0 +1,72 @@ +import { runTestServer, submitMessage } from '../../support/testUtils'; + +describe('Upload attachments', () => { + beforeEach(() => { + runTestServer(); + }); + + const shouldHaveInlineAttachments = () => { + submitMessage('Message with attachments'); + cy.get('.message').should('have.length', 5); + cy.get('.message') + .eq(1) + .should('contain', 'Content: Message with attachments'); + cy.get('.message') + .eq(2) + .should('contain', 'Received element 0: state_of_the_union.txt'); + cy.get('.message').eq(3).should('contain', 'Received element 1: hello.cpp'); + cy.get('.message').eq(4).should('contain', 'Received element 2: hello.py'); + + cy.get('.message').eq(0).find('.inline-file').should('have.length', 3); + cy.get('.inline-file') + .eq(0) + .should('have.attr', 'download', 'state_of_the_union.txt'); + cy.get('.inline-file').eq(1).should('have.attr', 'download', 'hello.cpp'); + cy.get('.inline-file').eq(2).should('have.attr', 'download', 'hello.py'); + }; + + it('Should be able to upload file attachments', () => { + cy.fixture('state_of_the_union.txt', 'utf-8').as('txtFile'); + cy.fixture('hello.cpp', 'utf-8').as('cppFile'); + cy.fixture('hello.py', 'utf-8').as('pyFile'); + + /** + * Should be able to upload file from D&D input + */ + cy.get("[id='#upload-drop-input']").should('exist'); + // Upload a text file + cy.get("[id='#upload-drop-input']").selectFile('@txtFile', { force: true }); + // cy.get('#upload-drop-input').selectFile('@txtFile', { force: true }); + cy.get('#attachments').should('contain', 'state_of_the_union.txt'); + + // Upload a C++ file + cy.get("[id='#upload-drop-input']").selectFile('@cppFile', { force: true }); + cy.get('#attachments').should('contain', 'hello.cpp'); + + // Upload a python file + cy.get("[id='#upload-drop-input']").selectFile('@pyFile', { force: true }); + cy.get('#attachments').should('contain', 'hello.py'); + + shouldHaveInlineAttachments(); + + /** + * Should be able to upload file from upload button + */ + cy.reload(); + cy.get('#upload-button').should('exist'); + + // Upload a text file + cy.get('#upload-button-input').selectFile('@txtFile', { force: true }); + cy.get('#attachments').should('contain', 'state_of_the_union.txt'); + + // Upload a C++ file + cy.get('#upload-button-input').selectFile('@cppFile', { force: true }); + cy.get('#attachments').should('contain', 'hello.cpp'); + + // Upload a python file + cy.get('#upload-button-input').selectFile('@pyFile', { force: true }); + cy.get('#attachments').should('contain', 'hello.py'); + + shouldHaveInlineAttachments(); + }); +}); diff --git a/cypress/e2e/user_session/main.py b/cypress/e2e/user_session/main.py index 8c5c1bcbff..b3b722ba9a 100644 --- a/cypress/e2e/user_session/main.py +++ b/cypress/e2e/user_session/main.py @@ -2,7 +2,7 @@ @cl.on_message -async def main(message: str): +async def main(message: cl.Message): prev_msg = cl.user_session.get("prev_msg") await cl.Message(content=f"Prev message: {prev_msg}").send() - cl.user_session.set("prev_msg", message) + cl.user_session.set("prev_msg", message.content) diff --git a/frontend/src/components/organisms/chat/history/index.tsx b/frontend/src/components/organisms/chat/history/index.tsx index e2f0d773eb..f356d870ed 100644 --- a/frontend/src/components/organisms/chat/history/index.tsx +++ b/frontend/src/components/organisms/chat/history/index.tsx @@ -21,6 +21,7 @@ import { chatHistoryState } from 'state/chatHistory'; import { MessageHistory } from 'types/chatHistory'; interface Props { + disabled?: boolean; onClick: (content: string) => void; } @@ -62,7 +63,7 @@ function buildHistory(historyMessages: MessageHistory[]) { return history; } -export default function HistoryButton({ onClick }: Props) { +export default function HistoryButton({ disabled, onClick }: Props) { const [chatHistory, setChatHistory] = useRecoilState(chatHistoryState); const ref = useRef(); @@ -239,6 +240,7 @@ export default function HistoryButton({ onClick }: Props) { toggleChatHistoryMenu(!chatHistory.open)} ref={ref} > diff --git a/frontend/src/components/organisms/chat/index.tsx b/frontend/src/components/organisms/chat/index.tsx index 4578b5d1ee..27f643a971 100644 --- a/frontend/src/components/organisms/chat/index.tsx +++ b/frontend/src/components/organisms/chat/index.tsx @@ -1,10 +1,19 @@ import { useCallback, useEffect, useState } from 'react'; -import { useRecoilValue, useSetRecoilState } from 'recoil'; +import toast from 'react-hot-toast'; +import { useRecoilState, useRecoilValue, useSetRecoilState } from 'recoil'; import { v4 as uuidv4 } from 'uuid'; -import { Alert, Box } from '@mui/material'; +import { UploadFile } from '@mui/icons-material'; +import { Alert, Box, Stack, Typography } from '@mui/material'; -import { ErrorBoundary, IMessage, useChat } from '@chainlit/components'; +import { + ErrorBoundary, + IFileElement, + IFileResponse, + IMessage, + useChat, + useUpload +} from '@chainlit/components'; import SideView from 'components/atoms/element/sideView'; import { Logo } from 'components/atoms/logo'; @@ -13,6 +22,7 @@ import TaskList from 'components/molecules/tasklist'; import { useAuth } from 'hooks/auth'; +import { attachmentsState } from 'state/chat'; import { chatHistoryState } from 'state/chatHistory'; import { conversationsHistoryState } from 'state/conversations'; import { projectSettingsState, sideViewState } from 'state/project'; @@ -21,11 +31,13 @@ import InputBox from './inputBox'; import MessageContainer from './message/container'; const Chat = () => { - const { user } = useAuth(); const pSettings = useRecoilValue(projectSettingsState); - const sideViewElement = useRecoilValue(sideViewState); + const setAttachments = useSetRecoilState(attachmentsState); const setChatHistory = useSetRecoilState(chatHistoryState); const setConversations = useSetRecoilState(conversationsHistoryState); + const sideViewElement = useRecoilValue(sideViewState); + + const { user } = useAuth(); const [autoScroll, setAutoScroll] = useState(true); const { @@ -39,9 +51,32 @@ const Chat = () => { elements, askUser, avatars, - loading + loading, + disabled } = useChat(); + const fileSpec = { max_size_mb: 20 }; + const onFileUpload = (payloads: IFileResponse[]) => { + const fileElements = payloads.map((file) => ({ + id: uuidv4(), + type: 'file' as const, + display: 'inline' as const, + name: file.name, + mime: file.type, + content: file.content + })); + setAttachments((prev) => prev.concat(fileElements)); + }; + + const onFileUploadError = (error: string) => toast.error(error); + + const upload = useUpload({ + spec: fileSpec, + onResolved: onFileUpload, + onError: onFileUploadError, + options: { noClick: true } + }); + useEffect(() => { setConversations((prev) => ({ ...prev, @@ -50,7 +85,7 @@ const Chat = () => { }, []); const onSubmit = useCallback( - async (msg: string) => { + async (msg: string, files?: IFileElement[]) => { const message: IMessage = { id: uuidv4(), author: user?.username || 'User', @@ -77,7 +112,7 @@ const Chat = () => { }); setAutoScroll(true); - sendMessage(message); + sendMessage(message, files); }, [user, pSettings, sendMessage] ); @@ -99,11 +134,47 @@ const Chat = () => { ); const tasklist = tasklists.at(-1); + const enableMultiModalUpload = !disabled && pSettings?.features.multi_modal; return ( - + + {upload ? ( + <> + + {upload?.isDragActive ? ( + theme.palette.primary.main, + color: 'white', + height: '100%', + width: '100%', + opacity: 0.9, + zIndex: 10, + alignItems: 'center', + justifyContent: 'center' + }} + > + + Drop your files here! + + ) : null} + + ) : null} + {error && ( @@ -123,7 +194,13 @@ const Chat = () => { callAction={callAction} setAutoScroll={setAutoScroll} /> - + { objectFit: 'contain', position: 'absolute', pointerEvents: 'none', - top: '45%', + top: '40%', left: '50%', transform: 'translate(-50%, -50%)', filter: 'grayscale(100%)', diff --git a/frontend/src/components/organisms/chat/inputBox/UploadButton.tsx b/frontend/src/components/organisms/chat/inputBox/UploadButton.tsx new file mode 100644 index 0000000000..cd8fa72856 --- /dev/null +++ b/frontend/src/components/organisms/chat/inputBox/UploadButton.tsx @@ -0,0 +1,51 @@ +import { useRecoilValue } from 'recoil'; + +import Add from '@mui/icons-material/Add'; +import { IconButton, Tooltip } from '@mui/material'; + +import { FileSpec, IFileResponse } from '@chainlit/components'; +import { useUpload } from '@chainlit/components'; + +import { projectSettingsState } from 'state/project'; + +type Props = { + disabled?: boolean; + fileSpec: FileSpec; + onFileUpload: (files: IFileResponse[]) => void; + onFileUploadError: (error: string) => void; +}; + +const UploadButton = ({ + disabled, + fileSpec, + onFileUpload, + onFileUploadError +}: Props) => { + const pSettings = useRecoilValue(projectSettingsState); + + const upload = useUpload({ + spec: fileSpec, + onResolved: (payloads: IFileResponse[]) => onFileUpload(payloads), + onError: onFileUploadError, + options: { noDrag: true } + }); + + if (!upload || !pSettings?.features.multi_modal) return null; + const { getRootProps, getInputProps, uploading } = upload; + + return ( + + + + + + + ); +}; + +export default UploadButton; diff --git a/frontend/src/components/organisms/chat/inputBox/index.tsx b/frontend/src/components/organisms/chat/inputBox/index.tsx index 354a8dc748..f5a090781c 100644 --- a/frontend/src/components/organisms/chat/inputBox/index.tsx +++ b/frontend/src/components/organisms/chat/inputBox/index.tsx @@ -1,15 +1,26 @@ import { Box } from '@mui/material'; +import { FileSpec, IFileElement, IFileResponse } from '@chainlit/components'; + import StopButton from '../stopButton'; import Input from './input'; import WaterMark from './waterMark'; interface Props { - onSubmit: (message: string) => void; + fileSpec: FileSpec; + onFileUpload: (payload: IFileResponse[]) => void; + onFileUploadError: (error: string) => void; + onSubmit: (message: string, files?: IFileElement[]) => void; onReply: (message: string) => void; } -export default function InputBox({ onSubmit, onReply }: Props) { +export default function InputBox({ + fileSpec, + onFileUpload, + onFileUploadError, + onSubmit, + onReply +}: Props) { // const tokenCount = useRecoilValue(tokenCountState); return ( @@ -29,7 +40,13 @@ export default function InputBox({ onSubmit, onReply }: Props) { > - + {/* {tokenCount > 0 && ( */} {/* void; + fileSpec: FileSpec; + onFileUpload: (payload: IFileResponse[]) => void; + onFileUploadError: (error: string) => void; + onSubmit: (message: string, files?: IFileElement[]) => void; onReply: (message: string) => void; } @@ -28,20 +39,64 @@ function getLineCount(el: HTMLDivElement) { return lines.length; } -const Input = ({ onSubmit, onReply }: Props) => { - const ref = useRef(null); +const Input = ({ + fileSpec, + onFileUpload, + onFileUploadError, + onSubmit, + onReply +}: Props) => { + const [fileElements, setFileElements] = useRecoilState(attachmentsState); const setChatHistory = useSetRecoilState(chatHistoryState); const setChatSettingsOpen = useSetRecoilState(chatSettingsOpenState); - const { connected, loading, askUser, chatSettingsInputs } = useChat(); + + const ref = useRef(null); + const { loading, askUser, chatSettingsInputs, disabled } = useChat(); const [value, setValue] = useState(''); const [isComposing, setIsComposing] = useState(false); - const disabled = - !connected || - loading || - askUser?.spec.type === 'file' || - askUser?.spec.type === 'action'; + useEffect(() => { + const pasteEvent = (event: ClipboardEvent) => { + if (event.clipboardData && event.clipboardData.items) { + const items = Array.from(event.clipboardData.items); + items.forEach((item) => { + if (item.kind === 'file') { + const file = item.getAsFile(); + if (file) { + const reader = new FileReader(); + reader.onload = function (e) { + const content = e.target?.result as ArrayBuffer; + if (content) { + onFileUpload([ + { + name: file.name, + type: file.type, + content, + size: file.size + } + ]); + } + }; + reader.readAsArrayBuffer(file); + } + } + }); + } + }; + + if (!ref.current) { + return; + } + + const input = ref.current; + + input.addEventListener('paste', pasteEvent); + + return () => { + input.removeEventListener('paste', pasteEvent); + }; + }, []); useEffect(() => { if (ref.current && !loading && !disabled) { @@ -56,10 +111,19 @@ const Input = ({ onSubmit, onReply }: Props) => { if (askUser) { onReply(value); } else { - onSubmit(value); + onSubmit(value, fileElements); } + setFileElements([]); setValue(''); - }, [value, disabled, setValue, askUser, onSubmit]); + }, [ + value, + disabled, + setValue, + askUser, + fileElements, + setFileElements, + onSubmit + ]); const handleCompositionStart = () => { setIsComposing(true); @@ -94,6 +158,7 @@ const Input = ({ onSubmit, onReply }: Props) => { const startAdornment = ( <> + {chatSettingsInputs.length > 0 && ( { )} - - + ); @@ -116,46 +185,12 @@ const Input = ({ onSubmit, onReply }: Props) => { ); return ( - setValue(e.target.value)} - onKeyDown={handleKeyDown} - onCompositionStart={handleCompositionStart} - onCompositionEnd={handleCompositionEnd} - value={value} - fullWidth - InputProps={{ - disableUnderline: true, - startAdornment: ( - - {startAdornment} - - ), - endAdornment: ( - - {endAdornment} - - ) - }} + `1px solid ${theme.palette.divider}`, boxShadow: 'box-shadow: 0px 2px 4px 0px #0000000D', - textarea: { height: '34px', maxHeight: '30vh', @@ -167,7 +202,58 @@ const Input = ({ onSubmit, onReply }: Props) => { lineHeight: '24px' } }} - /> + > + {fileElements.length > 0 ? ( + + + + ) : null} + + setValue(e.target.value)} + onKeyDown={handleKeyDown} + onCompositionStart={handleCompositionStart} + onCompositionEnd={handleCompositionEnd} + value={value} + fullWidth + InputProps={{ + disableUnderline: true, + startAdornment: ( + + {startAdornment} + + ), + endAdornment: ( + + {endAdornment} + + ) + }} + /> + ); }; diff --git a/frontend/src/components/organisms/chat/message/UploadButton.tsx b/frontend/src/components/organisms/chat/message/UploadButton.tsx deleted file mode 100644 index 1af530d003..0000000000 --- a/frontend/src/components/organisms/chat/message/UploadButton.tsx +++ /dev/null @@ -1,47 +0,0 @@ -import Add from '@mui/icons-material/Add'; -import { LoadingButton } from '@mui/lab'; -import { Tooltip } from '@mui/material'; - -import { FileSpec, IFileResponse, useChat } from '@chainlit/components'; -import { useUpload } from '@chainlit/components'; - -type UploadChildProps = { - fileSpec: FileSpec; - uploadFiles: (files: IFileResponse[]) => void; -}; - -const UploadChildButton = ({ fileSpec, uploadFiles }: UploadChildProps) => { - const upload = useUpload({ - spec: fileSpec, - onResolved: (payloads: IFileResponse[]) => uploadFiles(payloads) - }); - - if (!upload) return null; - const { getRootProps, getInputProps, uploading } = upload; - - return ( - - - - - - - ); -}; - -export default function UploadButton() { - const { fileSpec, uploadFiles } = useChat(); - - if (!fileSpec) return null; - - return ; -} diff --git a/frontend/src/components/organisms/chat/message/container.tsx b/frontend/src/components/organisms/chat/message/container.tsx index 3d2d7494d3..d6aeffa659 100644 --- a/frontend/src/components/organisms/chat/message/container.tsx +++ b/frontend/src/components/organisms/chat/message/container.tsx @@ -141,7 +141,8 @@ const MessageContainer = ({ uiName: projectSettings?.ui?.name || '', onPlaygroundButtonClick, onFeedbackUpdated, - onElementRefClick + onElementRefClick, + onError: (error) => toast.error(error) }} /> ); diff --git a/frontend/src/state/chat.ts b/frontend/src/state/chat.ts new file mode 100644 index 0000000000..8d4532b206 --- /dev/null +++ b/frontend/src/state/chat.ts @@ -0,0 +1,8 @@ +import { atom } from 'recoil'; + +import { IFileElement } from '@chainlit/components'; + +export const attachmentsState = atom({ + key: 'Attachments', + default: [] +}); diff --git a/frontend/src/state/project.ts b/frontend/src/state/project.ts index 3111286ee8..0a2f81842e 100644 --- a/frontend/src/state/project.ts +++ b/frontend/src/state/project.ts @@ -18,6 +18,9 @@ export interface IProjectSettings { default_expand_messages?: boolean; github?: string; }; + features: { + multi_modal?: boolean; + }; userEnv: string[]; dataPersistence: boolean; chatProfiles: ChatProfile[]; diff --git a/frontend/src/types/chat.ts b/frontend/src/types/chat.ts index c78ed0610b..141a946ba5 100644 --- a/frontend/src/types/chat.ts +++ b/frontend/src/types/chat.ts @@ -1,8 +1,7 @@ import { Socket } from 'socket.io-client'; -import { IMessage } from '@chainlit/components'; +import { IMessage, IMessageElement } from '@chainlit/components'; -import { IMessageElement } from './element'; import { IAppUser } from './user'; export interface IChat { diff --git a/frontend/src/types/element.ts b/frontend/src/types/element.ts deleted file mode 100644 index 8a2569e0e0..0000000000 --- a/frontend/src/types/element.ts +++ /dev/null @@ -1,70 +0,0 @@ -export type IElement = - | IImageElement - | ITextElement - | IPdfElement - | IAvatarElement - | ITasklistElement - | IAudioElement - | IVideoElement - | IFileElement; - -export type IMessageElement = - | IImageElement - | ITextElement - | IPdfElement - | IAudioElement - | IVideoElement - | IFileElement; - -export type ElementType = IElement['type']; -export type IElementSize = 'small' | 'medium' | 'large'; - -interface TElement { - id: string; - type: T; - conversationId?: string; - forIds?: string[]; - url?: string; -} - -interface TMessageElement extends TElement { - name: string; - display: 'inline' | 'side' | 'page'; -} - -export interface IImageElement extends TMessageElement<'image'> { - content?: ArrayBuffer; - size?: IElementSize; -} - -export interface IAvatarElement extends TElement<'avatar'> { - name: string; - content?: ArrayBuffer; -} - -export interface ITextElement extends TMessageElement<'text'> { - content?: string; - language?: string; -} - -export interface IPdfElement extends TMessageElement<'pdf'> { - content?: string; -} - -export interface IAudioElement extends TMessageElement<'audio'> { - content?: ArrayBuffer; -} - -export interface IVideoElement extends TMessageElement<'video'> { - content?: ArrayBuffer; - size?: IElementSize; -} - -export interface IFileElement extends TMessageElement<'file'> { - type: 'file'; - content?: ArrayBuffer; -} - -export interface ITasklistElement extends TElement<'tasklist'> { - content?: string; -} diff --git a/libs/components/contexts/MessageContext.tsx b/libs/components/contexts/MessageContext.tsx index 213e75bccc..a557858f9e 100644 --- a/libs/components/contexts/MessageContext.tsx +++ b/libs/components/contexts/MessageContext.tsx @@ -13,6 +13,7 @@ const defaultMessageContext = { onFeedbackUpdated: undefined, onPlaygroundButtonClick: undefined, showFeedbackButtons: true, + onError: () => undefined, uiName: '' }; diff --git a/libs/components/hooks/useChat/index.ts b/libs/components/hooks/useChat/index.ts index b856a110a9..a83de2650b 100644 --- a/libs/components/hooks/useChat/index.ts +++ b/libs/components/hooks/useChat/index.ts @@ -8,7 +8,7 @@ import { } from 'recoil'; import io from 'socket.io-client'; import { TFormInput } from 'src/inputs'; -import { IAction, IElement, IFileResponse, IMessage } from 'src/types'; +import { IAction, IElement, IFileElement, IMessage } from 'src/types'; import { nestMessages } from 'utils/message'; import { @@ -19,7 +19,6 @@ import { chatSettingsInputsState, chatSettingsValueState, elementState, - fileSpecState, firstUserMessageState, loadingState, messagesState, @@ -52,7 +51,6 @@ const useChat = () => { const [session, setSession] = useRecoilState(sessionState); const [loading, setLoading] = useRecoilState(loadingState); const [rawMessages, setMessages] = useRecoilState(messagesState); - const [fileSpec, setFileSpecState] = useRecoilState(fileSpecState); const [askUser, setAskUser] = useRecoilState(askUserState); const [elements, setElements] = useRecoilState(elementState); const [avatars, setAvatars] = useRecoilState(avatarState); @@ -204,11 +202,6 @@ const useChat = () => { setLoading(false); }); - socket.on('enable_file_upload', (spec) => { - setFileSpecState(spec); - setLoading(false); - }); - socket.on('ask_timeout', () => { setAskUser(undefined); setLoading(false); @@ -306,9 +299,9 @@ const useChat = () => { }, [session]); const sendMessage = useCallback( - (message: IMessage) => { + (message: IMessage, files?: IFileElement[]) => { setMessages((oldMessages) => [...oldMessages, message]); - session?.socket.emit('ui_message', message); + session?.socket.emit('ui_message', { message, files }); }, [session] ); @@ -342,15 +335,13 @@ const useChat = () => { [session] ); - const uploadFiles = useCallback( - (payloads: IFileResponse[]) => { - session?.socket.emit('file_upload', payloads); - }, - [session] - ); - const messages = nestMessages(rawMessages); const connected = session?.socket.connected && !session?.error; + const disabled = + !connected || + loading || + askUser?.spec.type === 'file' || + askUser?.spec.type === 'action'; return { connect, @@ -360,9 +351,9 @@ const useChat = () => { updateChatSettings, stopTask, callAction, - uploadFiles, replyMessage, connected, + disabled, error: session?.error, loading, messages, @@ -373,7 +364,6 @@ const useChat = () => { chatSettingsInputs, chatSettingsValue, chatSettingsDefaultValue, - fileSpec, askUser, firstUserMessage }; diff --git a/libs/components/hooks/useChat/state.ts b/libs/components/hooks/useChat/state.ts index a1ba754a45..078fd54988 100644 --- a/libs/components/hooks/useChat/state.ts +++ b/libs/components/hooks/useChat/state.ts @@ -2,7 +2,6 @@ import { DefaultValue, atom, selector } from 'recoil'; import { Socket } from 'socket.io-client'; import { TFormInput } from 'src/inputs'; import { - FileSpec, IAction, IAsk, IAvatarElement, @@ -56,11 +55,6 @@ export const loadingState = atom({ default: false }); -export const fileSpecState = atom({ - key: 'FileSpec', - default: undefined -}); - export const askUserState = atom({ key: 'AskUser', default: undefined diff --git a/libs/components/hooks/useUpload.tsx b/libs/components/hooks/useUpload.tsx index 75765f54c2..98dded8f71 100644 --- a/libs/components/hooks/useUpload.tsx +++ b/libs/components/hooks/useUpload.tsx @@ -6,15 +6,16 @@ import { useDropzone } from 'react-dropzone'; -import { FileSpec, IFileResponse } from '../src/types/file'; +import { FileSpec, IFileResponse } from 'src/types/file'; interface useUploadProps { - onResolved: (payloads: IFileResponse[]) => void; onError?: (error: string) => void; + onResolved: (payloads: IFileResponse[]) => void; + options?: DropzoneOptions; spec: FileSpec; } -const useUpload = ({ onResolved, spec, onError }: useUploadProps) => { +const useUpload = ({ onError, onResolved, options, spec }: useUploadProps) => { const [uploading, setUploading] = useState(false); const onDrop: DropzoneOptions['onDrop'] = useCallback( @@ -62,8 +63,6 @@ const useUpload = ({ onResolved, spec, onError }: useUploadProps) => { [spec] ); - if (!spec.accept || !spec.max_size_mb) return null; - let dzAccept: Record = {}; const accept = spec.accept; @@ -77,14 +76,15 @@ const useUpload = ({ onResolved, spec, onError }: useUploadProps) => { dzAccept = accept; } - const { getRootProps, getInputProps } = useDropzone({ + const { getRootProps, getInputProps, isDragActive } = useDropzone({ onDrop, maxFiles: spec.max_files || 1, accept: dzAccept, - maxSize: spec.max_size_mb * 1000000 + maxSize: (spec.max_size_mb || 2) * 1000000, + ...options }); - return { getRootProps, getInputProps, uploading }; + return { getInputProps, getRootProps, isDragActive, uploading }; }; export { useUpload }; diff --git a/libs/components/package.json b/libs/components/package.json index 44b6816f30..a76ca2b02e 100644 --- a/libs/components/package.json +++ b/libs/components/package.json @@ -1,7 +1,7 @@ { "name": "@chainlit/components", "description": "Reusable components of the Chainlit UI.", - "version": "0.0.8", + "version": "0.0.9", "scripts": { "build": "tsup src/index.ts --clean --format esm,cjs --dts --external react --legacy-output --minify --sourcemap --treeshake", "build:watch": "tsup src/index.ts --watch --clean --format esm,cjs --dts --external react --legacy-output --minify --sourcemap --treeshake", @@ -34,6 +34,7 @@ "@testing-library/react": "^14.0.0", "@types/draft-js": "^0.11.10", "@types/lodash": "^4.14.199", + "@types/react-file-icon": "^1.0.2", "@types/react-resizable": "^3.0.4", "@types/react-syntax-highlighter": "^15.5.7", "@types/uuid": "^9.0.3", @@ -67,6 +68,7 @@ "lodash": "^4.17.21", "mui-chips-input": "2.0.0", "react-dropzone": "^14.2.3", + "react-file-icon": "^1.3.0", "react-markdown": "^8.0.7", "react-password-checklist": "^1.4.3", "react-resizable": "^3.0.5", diff --git a/libs/components/pnpm-lock.yaml b/libs/components/pnpm-lock.yaml index 2fa0f4a74f..d45af66823 100644 --- a/libs/components/pnpm-lock.yaml +++ b/libs/components/pnpm-lock.yaml @@ -44,6 +44,9 @@ dependencies: react-dropzone: specifier: ^14.2.3 version: 14.2.3(react@18.2.0) + react-file-icon: + specifier: ^1.3.0 + version: 1.3.0(react-dom@18.2.0)(react@18.2.0) react-markdown: specifier: ^8.0.7 version: 8.0.7(@types/react@18.2.0)(react@18.2.0) @@ -94,6 +97,9 @@ devDependencies: '@types/lodash': specifier: ^4.14.199 version: 4.14.199 + '@types/react-file-icon': + specifier: ^1.0.2 + version: 1.0.2 '@types/react-resizable': specifier: ^3.0.4 version: 3.0.4 @@ -1287,6 +1293,12 @@ packages: '@types/react': 18.2.0 dev: true + /@types/react-file-icon@1.0.2: + resolution: {integrity: sha512-mK6+83GNuFasalMpuwRvYzusczMlqkw/O1hRMXv3NikpSa3PeFIT9Yee2mZfA2ukZ/MZZ0VIWc0fqn/ocqPmRQ==} + dependencies: + '@types/react': 18.2.0 + dev: true + /@types/react-resizable@3.0.4: resolution: {integrity: sha512-+QguN9CDfC1lthq+4noG1fkxh8cqkV2Fv/Mu3mdknCCBiwwNLecnBdk1MmNNN7uJpT23Nx/aVkYsbt5NuWouFw==} dependencies: @@ -1721,6 +1733,10 @@ packages: resolution: {integrity: sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==} dev: true + /colord@2.9.3: + resolution: {integrity: sha512-jeC1axXpnb0/2nn/Y1LPuLdgXBLH7aDcHu4KEKfqw3CUhX7ZpfBSlPKyqXE6btIgEzfWtrX3/tyBCaCvXvMkOw==} + dev: false + /combined-stream@1.0.8: resolution: {integrity: sha512-FQN4MRfuJeHf7cBbBMJFXhKSDq+2kAArBlmRBvcvFE5BB1HZKXtSFASDhdlz9zOYwxh8lDdnvmMOe/+5cdoEdg==} engines: {node: '>= 0.8'} @@ -2759,6 +2775,10 @@ packages: resolution: {integrity: sha512-HDWXG8isMntAyRF5vZ7xKuEvOhT4AhlRt/3czTSjvGUxjYCBVRQY48ViDHyfYz9VIoBkW4TMGQNapx+l3RUwdA==} dev: true + /lodash.uniqueid@4.0.1: + resolution: {integrity: sha512-GQQWaIeGlL6DIIr06kj1j6sSmBxyNMwI8kaX9aKpHR/XsMTiaXDVPNPAkiboOTK9OJpTJF/dXT3xYoFQnj386Q==} + dev: false + /lodash@4.17.21: resolution: {integrity: sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg==} @@ -3584,6 +3604,19 @@ packages: resolution: {integrity: sha512-suNP+J1VU1MWFKcyt7RtjiSWUjvidmQSlqu+eHslq+342xCbGTYmC0mEhPCOHxlW0CywylOC1u2DFAT+bv4dBw==} dev: false + /react-file-icon@1.3.0(react-dom@18.2.0)(react@18.2.0): + resolution: {integrity: sha512-wxl/WwSX5twQKVXloPHbS71iZQUKO84KgZ44Kh7vYZGu1qH2kagx+RSTNfk/+IHtXfjPWPNIHPGi2Y8S94N1CQ==} + peerDependencies: + react: ^18.0.0 || ^17.0.0 || ^16.2.0 + react-dom: ^18.0.0 || ^17.0.0 || ^16.2.0 + dependencies: + colord: 2.9.3 + lodash.uniqueid: 4.0.1 + prop-types: 15.8.1 + react: 18.2.0 + react-dom: 18.2.0(react@18.2.0) + dev: false + /react-is@16.13.1: resolution: {integrity: sha512-24e6ynE2H+OKt4kqsOvNd8kBpV65zoxbA4BVsEOB3ARVWQki/DHzaUoC5KuON/BiccDaCCTZBuOcfZs70kR8bQ==} dev: false diff --git a/libs/components/src/Attachments.tsx b/libs/components/src/Attachments.tsx new file mode 100644 index 0000000000..e4a96562f9 --- /dev/null +++ b/libs/components/src/Attachments.tsx @@ -0,0 +1,48 @@ +import React from 'react'; + +import Stack from '@mui/material/Stack'; + +import { FileElement } from './elements'; +import { IFileElement } from './types'; + +interface AttachmentsProps { + fileElements: IFileElement[]; + setFileElements: React.Dispatch>; +} + +const Attachments = ({ + fileElements, + setFileElements +}: AttachmentsProps): JSX.Element => { + if (fileElements.length === 0) return <>; + + const onRemove = (index: number) => { + setFileElements((prev) => + prev.filter((_, prevIndex) => index !== prevIndex) + ); + }; + + return ( + + {fileElements.map((fileElement, index) => { + return ( + onRemove(index)} + /> + ); + })} + + ); +}; + +export { Attachments }; diff --git a/libs/components/src/elements/File.tsx b/libs/components/src/elements/File.tsx index 3cde13690c..d2a2050c3c 100644 --- a/libs/components/src/elements/File.tsx +++ b/libs/components/src/elements/File.tsx @@ -1,29 +1,138 @@ -import { GreyButton } from 'src/buttons/GreyButton'; +import { useState } from 'react'; +import { FileIcon, defaultStyles } from 'react-file-icon'; -import AttachFile from '@mui/icons-material/AttachFile'; +import Close from '@mui/icons-material/Close'; +import Box from '@mui/material/Box'; +import IconButton from '@mui/material/IconButton'; import Link from '@mui/material/Link'; +import Stack from '@mui/material/Stack'; +import Typography from '@mui/material/Typography'; import { IFileElement } from 'src/types/element'; -const FileElement = ({ element }: { element: IFileElement }) => { +const FileElement = ({ + element, + onRemove +}: { + element: IFileElement; + onRemove?: () => void; +}) => { + const [isHovered, setIsHovered] = useState(false); + if (!element.url && !element.content) { return null; } - return ( - } - href={element.url || URL.createObjectURL(new Blob([element.content!]))} - LinkComponent={({ ...props }) => ( - - )} + let children; + const mime = element.mime ? element.mime.split('/').pop()! : 'file'; + + if (element.mime?.includes('image') && !element.mime?.includes('svg')) { + children = ( + + ); + } else { + children = ( + `1px solid ${theme.palette.primary.main}`, + color: (theme) => + theme.palette.mode === 'light' + ? theme.palette.primary.main + : theme.palette.text.primary, + background: (theme) => + theme.palette.mode === 'light' + ? theme.palette.primary.light + : theme.palette.primary.dark + }} + > + svg': { + height: '30px' + } + }} + > + + + + {element.name} + + + ); + } + + const fileElement = ( + setIsHovered(true)} + onMouseLeave={() => setIsHovered(false)} + height={50} > - {element.name} - + {isHovered && onRemove ? ( + `1px solid ${theme.palette.divider}`, + '&:hover': { + backgroundColor: 'background.default' + } + }} + onClick={onRemove} + > + + + ) : null} + {children} + ); + + if (!onRemove) { + return ( + + {fileElement} + + ); + } else { + return fileElement; + } }; export { FileElement }; diff --git a/libs/components/src/index.ts b/libs/components/src/index.ts index 82e5f71945..b3a71ff4a5 100644 --- a/libs/components/src/index.ts +++ b/libs/components/src/index.ts @@ -8,9 +8,10 @@ export * from './types'; export * from 'theme/index'; export * from 'hooks/index'; -export { nestMessages } from 'utils/message'; -export { ErrorBoundary } from './ErrorBoundary'; +export { Attachments } from './Attachments'; export { ClipboardCopy } from './ClipboardCopy'; export { Code } from './Code'; -export { NotificationCount } from './NotificationCount'; export { Collapse } from './Collapse'; +export { ErrorBoundary } from './ErrorBoundary'; +export { nestMessages } from 'utils/message'; +export { NotificationCount } from './NotificationCount'; diff --git a/libs/components/src/messages/Message.tsx b/libs/components/src/messages/Message.tsx index 038dc7c939..e398ce2e4e 100644 --- a/libs/components/src/messages/Message.tsx +++ b/libs/components/src/messages/Message.tsx @@ -104,7 +104,7 @@ const Message = ({ loading={isRunning} /> {!isRunning && isLast && message.waitForAnswer && ( - + )} diff --git a/libs/components/src/messages/components/AskUploadButton.tsx b/libs/components/src/messages/components/AskUploadButton.tsx index fc541dcd37..4cc411ed35 100644 --- a/libs/components/src/messages/components/AskUploadButton.tsx +++ b/libs/components/src/messages/components/AskUploadButton.tsx @@ -10,10 +10,17 @@ import { useUpload } from 'hooks/useUpload'; import { IAsk, IFileResponse } from 'src/types/file'; -const AskUploadChildButton = ({ askUser }: { askUser: IAsk }) => { +const AskUploadChildButton = ({ + askUser, + onError +}: { + askUser: IAsk; + onError: (error: string) => void; +}) => { const upload = useUpload({ spec: askUser.spec, - onResolved: (payloads: IFileResponse[]) => askUser?.callback(payloads) + onResolved: (payloads: IFileResponse[]) => askUser?.callback(payloads), + onError: (error: string) => onError(error) }); if (!upload) return null; @@ -32,7 +39,7 @@ const AskUploadChildButton = ({ askUser }: { askUser: IAsk }) => { padding={2} {...getRootProps({ className: 'dropzone' })} > - + Drag and drop files here @@ -41,7 +48,7 @@ const AskUploadChildButton = ({ askUser }: { askUser: IAsk }) => { { ); }; -const AskUploadButton = () => { +const AskUploadButton = ({ onError }: { onError: (error: string) => void }) => { const messageContext = useContext(MessageContext); if (messageContext.askUser?.spec.type !== 'file') return null; - return ; + return ( + + ); }; export { AskUploadButton }; diff --git a/libs/components/src/messages/components/ElementRef.tsx b/libs/components/src/messages/components/ElementRef.tsx index 9801679ea2..78f0252096 100644 --- a/libs/components/src/messages/components/ElementRef.tsx +++ b/libs/components/src/messages/components/ElementRef.tsx @@ -20,6 +20,7 @@ const ElementRef = ({ element }: Props) => { onElementRefClick && onElementRefClick(element)} > {element.name} diff --git a/libs/components/src/types/element.ts b/libs/components/src/types/element.ts index 8a2569e0e0..9915d1dc32 100644 --- a/libs/components/src/types/element.ts +++ b/libs/components/src/types/element.ts @@ -62,6 +62,7 @@ export interface IVideoElement extends TMessageElement<'video'> { export interface IFileElement extends TMessageElement<'file'> { type: 'file'; + mime?: string; content?: ArrayBuffer; } diff --git a/libs/components/src/types/message.ts b/libs/components/src/types/message.ts index 008dbb9341..a7f9dd9915 100644 --- a/libs/components/src/types/message.ts +++ b/libs/components/src/types/message.ts @@ -1,4 +1,4 @@ -import { IMessageElement } from './element'; +import { IFileElement, IMessageElement } from './element'; interface IBaseTemplate { template?: string; @@ -24,21 +24,22 @@ export interface IPrompt extends IBaseTemplate { } export interface IMessage { - id: string; author: string; authorIsUser?: boolean; - waitForAnswer?: boolean; content?: string; createdAt: number | string; + disableHumanFeedback?: boolean; + elements?: IFileElement[]; humanFeedback?: number; humanFeedbackComment?: string; - disableHumanFeedback?: boolean; - language?: string; + id: string; indent?: number; - parentId?: string; isError?: boolean; + language?: string; + parentId?: string; prompt?: IPrompt; streaming?: boolean; + waitForAnswer?: boolean; } export interface INestedMessage extends IMessage { diff --git a/libs/components/src/types/messageContext.ts b/libs/components/src/types/messageContext.ts index 71db3d4248..f0bbb3318a 100644 --- a/libs/components/src/types/messageContext.ts +++ b/libs/components/src/types/messageContext.ts @@ -20,6 +20,7 @@ interface IMessageContext { onSuccess: () => void, feedbackComment?: string ) => void; + onError: (error: string) => void; } export type { IMessageContext };